AzureOpenai access from behind company proxies. (#1459)
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -14,6 +16,8 @@ class BaseEmbedderConfig:
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
http_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
http_async_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of an embedder config class.
|
||||
@@ -32,6 +36,11 @@ class BaseEmbedderConfig:
|
||||
:type api_base: Optional[str], optional
|
||||
:param model_kwargs: key-value arguments for the embedding model, defaults a dict inside init.
|
||||
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init.
|
||||
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
|
||||
:type http_client_proxies: Optional[Dict | str], optional
|
||||
:param http_async_client_proxies: The proxy server settings for async calls used to create
|
||||
self.http_async_client, defaults to None
|
||||
:type http_async_client_proxies: Optional[Dict | str], optional
|
||||
"""
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
@@ -40,3 +49,7 @@ class BaseEmbedderConfig:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
|
||||
self.http_async_client = (
|
||||
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain_community.embeddings import AzureOpenAIEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
@@ -14,7 +14,11 @@ class AzureOpenAIEmbedder(BaseEmbedder):
|
||||
if self.config.model is None:
|
||||
self.config.model = "text-embedding-ada-002"
|
||||
|
||||
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
|
||||
embeddings = AzureOpenAIEmbeddings(
|
||||
deployment=self.config.deployment_name,
|
||||
http_client=self.config.http_client,
|
||||
http_async_client=self.config.http_async_client,
|
||||
)
|
||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
@@ -30,6 +30,8 @@ class AzureOpenAILlm(BaseLlm):
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
streaming=config.stream,
|
||||
http_client=config.http_client,
|
||||
http_async_client=config.http_async_client,
|
||||
)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
|
||||
@@ -479,6 +479,8 @@ def validate_config(config_data):
|
||||
Optional("base_url"): str,
|
||||
Optional("endpoint"): str,
|
||||
Optional("model_kwargs"): dict,
|
||||
Optional("http_client_proxies"): Or(str, dict),
|
||||
Optional("http_async_client_proxies"): Or(str, dict),
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
|
||||
Reference in New Issue
Block a user