http_client and http_async_client bugfix (#1454)

Author: Pranav Puranik
Date: 2024-07-02 18:13:33 -05:00
Committed by: GitHub
Parent: b305d674de
Commit: 5258fd91ea
5 changed files with 124 additions and 14 deletions


@@ -1,7 +1,9 @@
 import logging
 import re
 from string import Template
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping, Optional, Dict, Union
+import httpx
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -99,8 +101,8 @@ class BaseLlmConfig(BaseConfig):
         base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
-        http_client: Optional[Any] = None,
-        http_async_client: Optional[Any] = None,
+        http_client_proxies: Optional[Union[Dict, str]] = None,
+        http_async_client_proxies: Optional[Union[Dict, str]] = None,
         local: Optional[bool] = False,
         default_headers: Optional[Mapping[str, str]] = None,
         api_version: Optional[str] = None,
@@ -149,6 +151,11 @@ class BaseLlmConfig(BaseConfig):
         :type callbacks: Optional[list], optional
         :param query_type: The type of query to use, defaults to None
         :type query_type: Optional[str], optional
+        :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
+        :type http_client_proxies: Optional[Dict | str], optional
+        :param http_async_client_proxies: The proxy server settings for async calls used to create
+            self.http_async_client, defaults to None
+        :type http_async_client_proxies: Optional[Dict | str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
         :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
@@ -181,8 +188,10 @@ class BaseLlmConfig(BaseConfig):
         self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
-        self.http_client = http_client
-        self.http_async_client = http_async_client
+        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
+        self.http_async_client = (
+            httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
+        )
         self.local = local
         self.default_headers = default_headers
         self.online = online
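
With this change, BaseLlmConfig no longer takes pre-built client objects; callers pass proxy settings (a URL string or an httpx-style mapping) and the config constructs the httpx clients itself. A minimal sketch of the new call pattern, assuming BaseLlmConfig is importable from embedchain.config and an httpx release that still accepts the proxies keyword; the proxy URLs are placeholders:

```python
from embedchain.config import BaseLlmConfig

config = BaseLlmConfig(
    http_client_proxies="http://localhost:8080",                      # string form
    http_async_client_proxies={"https://": "http://localhost:8080"},  # mapping form
)

# The config now carries ready-made clients for downstream LLM wrappers.
print(type(config.http_client))        # httpx.Client
print(type(config.http_async_client))  # httpx.AsyncClient
```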


@@ -56,7 +56,13 @@ class OpenAILlm(BaseLlm):
                 http_async_client=config.http_async_client,
             )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
+            chat = ChatOpenAI(
+                **kwargs,
+                api_key=api_key,
+                base_url=base_url,
+                http_client=config.http_client,
+                http_async_client=config.http_async_client,
+            )
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
@@ -69,8 +75,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

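The OpenAILlm change closes the gap where only the endpoint branch forwarded the configured clients: the plain ChatOpenAI branch now receives them too. It relies on langchain's ChatOpenAI accepting pre-built httpx clients, roughly as in this sketch (model name, key, and proxy URL are placeholders, and it again assumes an httpx version that supports the proxies keyword):

```python
import httpx
from langchain_openai import ChatOpenAI

# Externally built sync client that routes traffic through a proxy.
proxied = httpx.Client(proxies="http://localhost:8080")

chat = ChatOpenAI(
    model="gpt-4o-mini",
    api_key="sk-placeholder",
    http_client=proxied,  # synchronous requests now go through the proxy
)
```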

@@ -442,6 +442,8 @@ def validate_config(config_data):
                     Optional("base_url"): str,
                     Optional("default_headers"): dict,
                     Optional("api_version"): Or(str, datetime.date),
+                    Optional("http_client_proxies"): Or(str, dict),
+                    Optional("http_async_client_proxies"): Or(str, dict),
                 },
             },
             Optional("vectordb"): {