http_client and http_async_client bugfix (#1454)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Mapping, Optional, Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -99,8 +101,8 @@ class BaseLlmConfig(BaseConfig):
|
||||
base_url: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
http_client: Optional[Any] = None,
|
||||
http_async_client: Optional[Any] = None,
|
||||
http_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
http_async_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
local: Optional[bool] = False,
|
||||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
@@ -149,6 +151,11 @@ class BaseLlmConfig(BaseConfig):
|
||||
:type callbacks: Optional[list], optional
|
||||
:param query_type: The type of query to use, defaults to None
|
||||
:type query_type: Optional[str], optional
|
||||
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
|
||||
:type http_client_proxies: Optional[Dict | str], optional
|
||||
:param http_async_client_proxies: The proxy server settings for async calls used to create
|
||||
self.http_async_client, defaults to None
|
||||
:type http_async_client_proxies: Optional[Dict | str], optional
|
||||
:param local: If True, the model will be run locally, defaults to False (for huggingface provider)
|
||||
:type local: Optional[bool], optional
|
||||
:param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
|
||||
@@ -181,8 +188,10 @@ class BaseLlmConfig(BaseConfig):
|
||||
self.base_url = base_url
|
||||
self.endpoint = endpoint
|
||||
self.model_kwargs = model_kwargs
|
||||
self.http_client = http_client
|
||||
self.http_async_client = http_async_client
|
||||
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
|
||||
self.http_async_client = (
|
||||
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
|
||||
)
|
||||
self.local = local
|
||||
self.default_headers = default_headers
|
||||
self.online = online
|
||||
|
||||
@@ -56,7 +56,13 @@ class OpenAILlm(BaseLlm):
|
||||
http_async_client=config.http_async_client,
|
||||
)
|
||||
else:
|
||||
chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
|
||||
chat = ChatOpenAI(
|
||||
**kwargs,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
http_client=config.http_client,
|
||||
http_async_client=config.http_async_client,
|
||||
)
|
||||
if self.tools:
|
||||
return self._query_function_call(chat, self.tools, messages)
|
||||
|
||||
@@ -69,8 +75,7 @@ class OpenAILlm(BaseLlm):
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
||||
from langchain_core.utils.function_calling import \
|
||||
convert_to_openai_tool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
openai_tools = [convert_to_openai_tool(tools)]
|
||||
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
|
||||
|
||||
@@ -442,6 +442,8 @@ def validate_config(config_data):
|
||||
Optional("base_url"): str,
|
||||
Optional("default_headers"): dict,
|
||||
Optional("api_version"): Or(str, datetime.date),
|
||||
Optional("http_client_proxies"): Or(str, dict),
|
||||
Optional("http_async_client_proxies"): Or(str, dict),
|
||||
},
|
||||
},
|
||||
Optional("vectordb"): {
|
||||
|
||||
Reference in New Issue
Block a user