AzureOpenai access from behind company proxies. (#1459)

This commit is contained in:
Pranav Puranik
2024-08-01 13:53:38 -05:00
committed by GitHub
parent 563a130141
commit c197a5fe93
8 changed files with 173 additions and 6 deletions

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
import httpx
from embedchain.helpers.json_serializable import register_deserializable
@@ -14,6 +16,8 @@ class BaseEmbedderConfig:
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
http_client_proxies: Optional[Union[Dict, str]] = None,
http_async_client_proxies: Optional[Union[Dict, str]] = None,
):
"""
Initialize a new instance of an embedder config class.
@@ -32,6 +36,11 @@ class BaseEmbedderConfig:
:type api_base: Optional[str], optional
:param model_kwargs: key-value arguments for the embedding model, defaults a dict inside init.
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init.
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional
:param http_async_client_proxies: The proxy server settings for async calls used to create
self.http_async_client, defaults to None
:type http_async_client_proxies: Optional[Dict | str], optional
"""
self.model = model
self.deployment_name = deployment_name
@@ -40,3 +49,7 @@ class BaseEmbedderConfig:
self.api_key = api_key
self.api_base = api_base
self.model_kwargs = model_kwargs or {}
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
self.http_async_client = (
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from langchain_community.embeddings import AzureOpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
@@ -14,7 +14,11 @@ class AzureOpenAIEmbedder(BaseEmbedder):
if self.config.model is None:
self.config.model = "text-embedding-ada-002"
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
embeddings = AzureOpenAIEmbeddings(
deployment=self.config.deployment_name,
http_client=self.config.http_client,
http_async_client=self.config.http_async_client,
)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)

View File

@@ -30,6 +30,8 @@ class AzureOpenAILlm(BaseLlm):
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
http_client=config.http_client,
http_async_client=config.http_async_client,
)
if config.top_p and config.top_p != 1:

View File

@@ -479,6 +479,8 @@ def validate_config(config_data):
Optional("base_url"): str,
Optional("endpoint"): str,
Optional("model_kwargs"): dict,
Optional("http_client_proxies"): Or(str, dict),
Optional("http_async_client_proxies"): Or(str, dict),
},
},
Optional("embedding_model"): {