http_client and http_async_client bugfix (#1454)

@@ -30,6 +30,7 @@ llm:
       response_format:
         type: json_object
     api_version: 2024-02-01
+    http_client_proxies: http://testproxy.mem0.net:8000
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.

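For context, a minimal sketch of how such a YAML file would be used, assuming the standard `App.from_config` entry point and that the snippet above is saved as `config.yaml` (the file name and the query are placeholders):

```python
from embedchain import App

# Sketch: with http_client_proxies set in config.yaml, the LLM's HTTP traffic
# is routed through the configured proxy.
app = App.from_config(config_path="config.yaml")
print(app.query("What does this app know?"))
```
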
@@ -89,7 +90,8 @@ cache:
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
       "api_key": "sk-xxx",
       "model_kwargs": {"response_format": {"type": "json_object"}},
-      "api_version": "2024-02-01"
+      "api_version": "2024-02-01",
+      "http_client_proxies": "http://testproxy.mem0.net:8000",
     }
   },
   "vectordb": {

@@ -150,7 +152,8 @@ config = {
             "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
         ),
         'api_key': 'sk-xxx',
-        "model_kwargs": {"response_format": {"type": "json_object"}}
+        "model_kwargs": {"response_format": {"type": "json_object"}},
+        "http_client_proxies": "http://testproxy.mem0.net:8000",
     }
     },
     'vectordb': {

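The dict form can be passed directly as well; a minimal sketch, again assuming `App.from_config` (provider and model below are illustrative placeholders, the proxy key is the one added above):

```python
from embedchain import App

config = {
    "llm": {
        "provider": "openai",  # placeholder provider/model, for illustration only
        "config": {
            "model": "gpt-3.5-turbo",
            "api_key": "sk-xxx",
            "http_client_proxies": "http://testproxy.mem0.net:8000",
        },
    },
}
app = App.from_config(config=config)
```
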
@@ -211,6 +214,8 @@ Alright, let's dive into what each key means in the yaml config above:
 - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
 - `api_key` (String): The API key for the language model.
 - `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used for `aws_bedrock` provider, since it requires different arguments for each model.
+- `http_client_proxies` (Dict | String): The proxy server settings used to create `self.http_client` using `httpx.Client(proxies=http_client_proxies)`
+- `http_async_client_proxies` (Dict | String): The proxy server settings for async calls used to create `self.http_async_client` using `httpx.AsyncClient(proxies=http_async_client_proxies)`
 3. `vectordb` Section:
 - `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases).
 - `config`:

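Both keys feed directly into `httpx` client construction, as described in the two bullets above. A standalone sketch of the accepted string and dict shapes (the proxy URL is the placeholder used throughout this PR):

```python
import httpx

# String form: send all traffic through a single proxy.
sync_client = httpx.Client(proxies="http://testproxy.mem0.net:8000")

# Dict form: map URL schemes/patterns to proxies.
async_client = httpx.AsyncClient(proxies={"http://": "http://testproxy.mem0.net:8000"})

# Note: recent httpx releases deprecate `proxies` in favour of `proxy`/`mounts`;
# the keyword used here matches what this PR passes through.
```
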
@@ -1,7 +1,9 @@
 import logging
 import re
 from string import Template
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping, Optional, Dict, Union
 
+import httpx
+
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable

@@ -99,8 +101,8 @@ class BaseLlmConfig(BaseConfig):
         base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
-        http_client: Optional[Any] = None,
-        http_async_client: Optional[Any] = None,
+        http_client_proxies: Optional[Union[Dict, str]] = None,
+        http_async_client_proxies: Optional[Union[Dict, str]] = None,
         local: Optional[bool] = False,
         default_headers: Optional[Mapping[str, str]] = None,
         api_version: Optional[str] = None,

@@ -149,6 +151,11 @@ class BaseLlmConfig(BaseConfig):
         :type callbacks: Optional[list], optional
         :param query_type: The type of query to use, defaults to None
         :type query_type: Optional[str], optional
+        :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
+        :type http_client_proxies: Optional[Dict | str], optional
+        :param http_async_client_proxies: The proxy server settings for async calls used to create
+            self.http_async_client, defaults to None
+        :type http_async_client_proxies: Optional[Dict | str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
         :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI

@@ -181,8 +188,10 @@ class BaseLlmConfig(BaseConfig):
         self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
-        self.http_client = http_client
-        self.http_async_client = http_async_client
+        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
+        self.http_async_client = (
+            httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
+        )
         self.local = local
         self.default_headers = default_headers
         self.online = online

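Net effect: callers no longer hand pre-built clients to `BaseLlmConfig`; they pass proxy settings and the config builds the `httpx` clients itself, or leaves them `None`. A minimal sketch of that behaviour, using the placeholder proxy URL from this PR:

```python
from embedchain.config import BaseLlmConfig

# Proxies configured: the config exposes ready-made httpx clients.
config = BaseLlmConfig(
    model="gpt-3.5-turbo",
    http_client_proxies="http://testproxy.mem0.net:8000",                     # -> httpx.Client
    http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},  # -> httpx.AsyncClient
)
print(type(config.http_client), type(config.http_async_client))

# No proxies configured: both attributes stay None and downstream code uses library defaults.
plain = BaseLlmConfig(model="gpt-3.5-turbo")
assert plain.http_client is None and plain.http_async_client is None
```
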
@@ -56,7 +56,13 @@ class OpenAILlm(BaseLlm):
                 http_async_client=config.http_async_client,
             )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
+            chat = ChatOpenAI(
+                **kwargs,
+                api_key=api_key,
+                base_url=base_url,
+                http_client=config.http_client,
+                http_async_client=config.http_async_client,
+            )
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)

@@ -69,8 +75,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

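With this change the `else` branch forwards the (possibly `None`) clients exactly like the Azure branch above it. Standalone, the pattern looks roughly like the sketch below; it assumes `langchain_openai`'s `ChatOpenAI`, which accepts `http_client`/`http_async_client` and hands them to the underlying OpenAI SDK clients:

```python
import httpx
from langchain_openai import ChatOpenAI  # assumption: the langchain_openai backend

# Route synchronous OpenAI calls through the proxy; leave the async client at its default.
proxied = httpx.Client(proxies="http://testproxy.mem0.net:8000")
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    api_key="sk-xxx",          # placeholder key, as in the docs above
    http_client=proxied,
    http_async_client=None,    # None simply means no custom async client
)
```
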
@@ -442,6 +442,8 @@ def validate_config(config_data):
                 Optional("base_url"): str,
                 Optional("default_headers"): dict,
                 Optional("api_version"): Or(str, datetime.date),
+                Optional("http_client_proxies"): Or(str, dict),
+                Optional("http_async_client_proxies"): Or(str, dict),
             },
         },
         Optional("vectordb"): {

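So the validator accepts either a plain proxy URL or a mapping for both keys. A quick standalone check, assuming the `schema` package that the `Schema`/`Optional`/`Or` calls above come from:

```python
from schema import Optional, Or, Schema  # assumption: the `schema` package behind validate_config

proxy_schema = Schema(
    {
        Optional("http_client_proxies"): Or(str, dict),
        Optional("http_async_client_proxies"): Or(str, dict),
    }
)

# Both accepted shapes validate cleanly; anything else raises SchemaError.
proxy_schema.validate({"http_client_proxies": "http://testproxy.mem0.net:8000"})
proxy_schema.validate({"http_async_client_proxies": {"http://": "http://testproxy.mem0.net:8000"}})
```
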
@@ -1,5 +1,6 @@
 import os
 
+import httpx
 import pytest
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 

@@ -7,15 +8,27 @@ from embedchain.config import BaseLlmConfig
 from embedchain.llm.openai import OpenAILlm
 
 
-@pytest.fixture
-def config():
+@pytest.fixture()
+def env_config():
     os.environ["OPENAI_API_KEY"] = "test_api_key"
     os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
+    yield
+    os.environ.pop("OPENAI_API_KEY")
+
+
+@pytest.fixture
+def config(env_config):
     config = BaseLlmConfig(
-        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_client_proxies=None,
+        http_async_client_proxies=None,
     )
     yield config
-    os.environ.pop("OPENAI_API_KEY")
 
 
 def test_get_llm_model_answer(config, mocker):

@@ -75,6 +88,8 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         model_kwargs={"top_p": config.top_p},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -93,6 +108,8 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
         default_headers={"test": "test"},
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -110,6 +127,8 @@ def test_get_llm_model_answer_with_model_kwargs(config, mocker):
         model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
 
 
@@ -136,8 +155,78 @@ def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
         model_kwargs={"top_p": config.top_p},
         api_key=os.environ["OPENAI_API_KEY"],
         base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=None,
     )
     mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
     mocked_json_output_tools_parser.assert_called_once()
 
     assert answer == expected
+
+
+def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mock_http_client = mocker.Mock(spec=httpx.Client)
+    mock_http_client_instance = mocker.Mock(spec=httpx.Client)
+    mock_http_client.return_value = mock_http_client_instance
+
+    mocker.patch("httpx.Client", new=mock_http_client)
+
+    config = BaseLlmConfig(
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_client_proxies="http://testproxy.mem0.net:8000",
+    )
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        http_client=mock_http_client_instance,
+        http_async_client=None,
+    )
+    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
+
+
+def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
+    mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
+    mock_http_async_client.return_value = mock_http_async_client_instance
+
+    mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
+
+    config = BaseLlmConfig(
+        temperature=0.7,
+        max_tokens=50,
+        top_p=0.8,
+        stream=False,
+        system_prompt="System prompt",
+        model="gpt-3.5-turbo",
+        http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
+    )
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        http_client=None,
+        http_async_client=mock_http_async_client_instance,
+    )
+    mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})