http_client and http_async_client bugfix (#1454)
@@ -30,6 +30,7 @@ llm:
     response_format:
       type: json_object
     api_version: 2024-02-01
+    http_client_proxies: http://testproxy.mem0.net:8000
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -89,7 +90,8 @@ cache:
             "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
             "api_key": "sk-xxx",
             "model_kwargs": {"response_format": {"type": "json_object"}},
-            "api_version": "2024-02-01"
+            "api_version": "2024-02-01",
+            "http_client_proxies": "http://testproxy.mem0.net:8000",
         }
     },
     "vectordb": {
@@ -150,7 +152,8 @@ config = {
                 "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
             ),
             'api_key': 'sk-xxx',
-            "model_kwargs": {"response_format": {"type": "json_object"}}
+            "model_kwargs": {"response_format": {"type": "json_object"}},
+            "http_client_proxies": "http://testproxy.mem0.net:8000",
         }
     },
     'vectordb': {
@@ -211,6 +214,8 @@ Alright, let's dive into what each key means in the yaml config above:
   - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
   - `api_key` (String): The API key for the language model.
   - `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used for `aws_bedrock` provider, since it requires different arguments for each model.
+  - `http_client_proxies` (Dict | String): The proxy server settings used to create `self.http_client` using `httpx.Client(proxies=http_client_proxies)`
+  - `http_async_client_proxies` (Dict | String): The proxy server settings for async calls used to create `self.http_async_client` using `httpx.AsyncClient(proxies=http_async_client_proxies)`
 3. `vectordb` Section:
   - `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases).
   - `config`:
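For orientation, a minimal usage sketch of the two new settings (not part of this commit; it simply combines the `BaseLlmConfig` signature and `OpenAILlm` usage changed below, reusing the placeholder proxy URL from this diff):

```python
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm

# Proxy settings accept either a single proxy URL or an httpx-style mapping,
# matching the Optional[Union[Dict, str]] type introduced in BaseLlmConfig.
config = BaseLlmConfig(
    model="gpt-3.5-turbo",
    http_client_proxies="http://testproxy.mem0.net:8000",
    http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
)

# BaseLlmConfig builds httpx.Client / httpx.AsyncClient from these settings,
# and OpenAILlm forwards them to ChatOpenAI as http_client / http_async_client.
llm = OpenAILlm(config)
```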
@@ -1,7 +1,9 @@
 import logging
 import re
 from string import Template
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping, Optional, Dict, Union
 
+import httpx
+
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -99,8 +101,8 @@ class BaseLlmConfig(BaseConfig):
         base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
-        http_client: Optional[Any] = None,
-        http_async_client: Optional[Any] = None,
+        http_client_proxies: Optional[Union[Dict, str]] = None,
+        http_async_client_proxies: Optional[Union[Dict, str]] = None,
         local: Optional[bool] = False,
         default_headers: Optional[Mapping[str, str]] = None,
         api_version: Optional[str] = None,
@@ -149,6 +151,11 @@ class BaseLlmConfig(BaseConfig):
         :type callbacks: Optional[list], optional
         :param query_type: The type of query to use, defaults to None
         :type query_type: Optional[str], optional
+        :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
+        :type http_client_proxies: Optional[Dict | str], optional
+        :param http_async_client_proxies: The proxy server settings for async calls used to create
+            self.http_async_client, defaults to None
+        :type http_async_client_proxies: Optional[Dict | str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
         :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
@@ -181,8 +188,10 @@ class BaseLlmConfig(BaseConfig):
         self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
-        self.http_client = http_client
-        self.http_async_client = http_async_client
+        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
+        self.http_async_client = (
+            httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
+        )
         self.local = local
         self.default_headers = default_headers
         self.online = online
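As a side note on the construction above (not part of the diff): httpx releases contemporary with this change accept a `proxies` argument on `httpx.Client` / `httpx.AsyncClient` that can be either a single proxy URL or a mapping of URL patterns to proxies, which is why the new settings are typed `Optional[Union[Dict, str]]`. A minimal sketch, assuming such an httpx version:

```python
import httpx

# Either form works with the `proxies` argument used above (newer httpx
# releases have since moved to the `proxy` / `mounts` arguments instead).
client_from_url = httpx.Client(proxies="http://testproxy.mem0.net:8000")
client_from_map = httpx.Client(proxies={"http://": "http://testproxy.mem0.net:8000"})
```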
@@ -56,7 +56,13 @@ class OpenAILlm(BaseLlm):
                 http_async_client=config.http_async_client,
             )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
+            chat = ChatOpenAI(
+                **kwargs,
+                api_key=api_key,
+                base_url=base_url,
+                http_client=config.http_client,
+                http_async_client=config.http_async_client,
+            )
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
 
@@ -69,8 +75,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
@@ -442,6 +442,8 @@ def validate_config(config_data):
                 Optional("base_url"): str,
                 Optional("default_headers"): dict,
                 Optional("api_version"): Or(str, datetime.date),
+                Optional("http_client_proxies"): Or(str, dict),
+                Optional("http_async_client_proxies"): Or(str, dict),
             },
         },
         Optional("vectordb"): {
@@ -1,5 +1,6 @@
 import os
 
+import httpx
 import pytest
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
@@ -7,15 +8,27 @@ from embedchain.config import BaseLlmConfig
 from embedchain.llm.openai import OpenAILlm
 
 
-@pytest.fixture
-def config():
+@pytest.fixture()
+def env_config():
     os.environ["OPENAI_API_KEY"] = "test_api_key"
     os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
+    yield
+    os.environ.pop("OPENAI_API_KEY")
+
+
+@pytest.fixture
+def config(env_config):
     config = BaseLlmConfig(
-        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_client_proxies=None,
+        http_async_client_proxies=None,
     )
     yield config
-    os.environ.pop("OPENAI_API_KEY")
 
 
 def test_get_llm_model_answer(config, mocker):
@@ -75,6 +88,8 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         model_kwargs={"top_p": config.top_p},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -93,6 +108,8 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
         default_headers={"test": "test"},
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -110,6 +127,8 @@ def test_get_llm_model_answer_with_model_kwargs(config, mocker):
         model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -136,8 +155,78 @@ def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
         model_kwargs={"top_p": config.top_p},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
     mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
     mocked_json_output_tools_parser.assert_called_once()
 
     assert answer == expected
+
+
+def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mock_http_client = mocker.Mock(spec=httpx.Client)
+    mock_http_client_instance = mocker.Mock(spec=httpx.Client)
+    mock_http_client.return_value = mock_http_client_instance
+
+    mocker.patch("httpx.Client", new=mock_http_client)
+
+    config = BaseLlmConfig(
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_client_proxies="http://testproxy.mem0.net:8000",
+    )
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        http_client=mock_http_client_instance,
+        http_async_client=None,
+    )
+    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
+
+
+def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
+    mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
+    mock_http_async_client.return_value = mock_http_async_client_instance
+
+    mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
+
+    config = BaseLlmConfig(
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
+    )
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=mock_http_async_client_instance,
+    )
+    mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})