AzureOpenai access from behind company proxies. (#1459)
This commit is contained in:
@@ -55,6 +55,7 @@ embedder:
|
||||
config:
|
||||
model: 'text-embedding-ada-002'
|
||||
api_key: sk-xxx
|
||||
http_client_proxies: http://testproxy.mem0.net:8000
|
||||
|
||||
chunker:
|
||||
chunk_size: 2000
|
||||
@@ -106,7 +107,8 @@ cache:
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-ada-002",
|
||||
"api_key": "sk-xxx"
|
||||
"api_key": "sk-xxx",
|
||||
"http_client_proxies": "http://testproxy.mem0.net:8000",
|
||||
}
|
||||
},
|
||||
"chunker": {
|
||||
@@ -168,7 +170,8 @@ config = {
|
||||
'provider': 'openai',
|
||||
'config': {
|
||||
'model': 'text-embedding-ada-002',
|
||||
'api_key': 'sk-xxx'
|
||||
'api_key': 'sk-xxx',
|
||||
"http_client_proxies": "http://testproxy.mem0.net:8000",
|
||||
}
|
||||
},
|
||||
'chunker': {
|
||||
@@ -236,6 +239,8 @@ Alright, let's dive into what each key means in the yaml config above:
|
||||
- `title` (String): The title for the embedding model for Google Embedder.
|
||||
- `task_type` (String): The task type for the embedding model for Google Embedder.
|
||||
- `model_kwargs` (Dict): Used to pass extra arguments to embedders.
|
||||
- `http_client_proxies` (Dict | String): The proxy server settings used to create `self.http_client` using `httpx.Client(proxies=http_client_proxies)`
|
||||
- `http_async_client_proxies` (Dict | String): The proxy server settings for async calls used to create `self.http_async_client` using `httpx.AsyncClient(proxies=http_async_client_proxies)`
|
||||
5. `chunker` Section:
|
||||
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
||||
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -14,6 +16,8 @@ class BaseEmbedderConfig:
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
http_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
http_async_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of an embedder config class.
|
||||
@@ -32,6 +36,11 @@ class BaseEmbedderConfig:
|
||||
:type api_base: Optional[str], optional
|
||||
:param model_kwargs: key-value arguments for the embedding model, defaults a dict inside init.
|
||||
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init.
|
||||
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
|
||||
:type http_client_proxies: Optional[Dict | str], optional
|
||||
:param http_async_client_proxies: The proxy server settings for async calls used to create
|
||||
self.http_async_client, defaults to None
|
||||
:type http_async_client_proxies: Optional[Dict | str], optional
|
||||
"""
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
@@ -40,3 +49,7 @@ class BaseEmbedderConfig:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
|
||||
self.http_async_client = (
|
||||
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain_community.embeddings import AzureOpenAIEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
@@ -14,7 +14,11 @@ class AzureOpenAIEmbedder(BaseEmbedder):
|
||||
if self.config.model is None:
|
||||
self.config.model = "text-embedding-ada-002"
|
||||
|
||||
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
|
||||
embeddings = AzureOpenAIEmbeddings(
|
||||
deployment=self.config.deployment_name,
|
||||
http_client=self.config.http_client,
|
||||
http_async_client=self.config.http_async_client,
|
||||
)
|
||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
@@ -30,6 +30,8 @@ class AzureOpenAILlm(BaseLlm):
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
streaming=config.stream,
|
||||
http_client=config.http_client,
|
||||
http_async_client=config.http_async_client,
|
||||
)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
|
||||
@@ -479,6 +479,8 @@ def validate_config(config_data):
|
||||
Optional("base_url"): str,
|
||||
Optional("endpoint"): str,
|
||||
Optional("model_kwargs"): dict,
|
||||
Optional("http_client_proxies"): Or(str, dict),
|
||||
Optional("http_async_client_proxies"): Or(str, dict),
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
|
||||
12
embedchain/poetry.lock
generated
12
embedchain/poetry.lock
generated
@@ -3255,6 +3255,7 @@ description = "Nvidia JIT LTO Library"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
|
||||
]
|
||||
@@ -4614,6 +4615,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||
@@ -4621,8 +4623,16 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||
@@ -4639,6 +4649,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||
@@ -4646,6 +4657,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||
|
||||
52
embedchain/tests/embedder/test_azure_openai_embedder.py
Normal file
52
embedchain/tests/embedder/test_azure_openai_embedder.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
import httpx
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.azure_openai import AzureOpenAIEmbedder
|
||||
|
||||
|
||||
def test_azure_openai_embedder_with_http_client(monkeypatch):
|
||||
mock_http_client = Mock(spec=httpx.Client)
|
||||
mock_http_client_instance = Mock(spec=httpx.Client)
|
||||
mock_http_client.return_value = mock_http_client_instance
|
||||
|
||||
with patch("embedchain.embedder.azure_openai.AzureOpenAIEmbeddings") as mock_embeddings, patch(
|
||||
"httpx.Client", new=mock_http_client
|
||||
) as mock_http_client:
|
||||
config = BaseEmbedderConfig(
|
||||
deployment_name="text-embedding-ada-002",
|
||||
http_client_proxies="http://testproxy.mem0.net:8000",
|
||||
)
|
||||
|
||||
_ = AzureOpenAIEmbedder(config=config)
|
||||
|
||||
mock_embeddings.assert_called_once_with(
|
||||
deployment="text-embedding-ada-002",
|
||||
http_client=mock_http_client_instance,
|
||||
http_async_client=None,
|
||||
)
|
||||
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
|
||||
|
||||
|
||||
def test_azure_openai_embedder_with_http_async_client(monkeypatch):
|
||||
mock_http_async_client = Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client_instance = Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client.return_value = mock_http_async_client_instance
|
||||
|
||||
with patch("embedchain.embedder.azure_openai.AzureOpenAIEmbeddings") as mock_embeddings, patch(
|
||||
"httpx.AsyncClient", new=mock_http_async_client
|
||||
) as mock_http_async_client:
|
||||
config = BaseEmbedderConfig(
|
||||
deployment_name="text-embedding-ada-002",
|
||||
http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
|
||||
)
|
||||
|
||||
_ = AzureOpenAIEmbedder(config=config)
|
||||
|
||||
mock_embeddings.assert_called_once_with(
|
||||
deployment="text-embedding-ada-002",
|
||||
http_client=None,
|
||||
http_async_client=mock_http_async_client_instance,
|
||||
)
|
||||
mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
|
||||
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
@@ -43,6 +44,8 @@ def test_get_answer(azure_openai_llm):
|
||||
temperature=azure_openai_llm.config.temperature,
|
||||
max_tokens=azure_openai_llm.config.max_tokens,
|
||||
streaming=azure_openai_llm.config.stream,
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,4 +87,78 @@ def test_with_api_version():
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
streaming=False,
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_http_client_proxies():
|
||||
mock_http_client = Mock(spec=httpx.Client)
|
||||
mock_http_client_instance = Mock(spec=httpx.Client)
|
||||
mock_http_client.return_value = mock_http_client_instance
|
||||
|
||||
with patch("langchain_openai.AzureChatOpenAI") as mock_chat, patch(
|
||||
"httpx.Client", new=mock_http_client
|
||||
) as mock_http_client:
|
||||
mock_chat.return_value.invoke.return_value.content = "Mocked response"
|
||||
|
||||
config = BaseLlmConfig(
|
||||
deployment_name="azure_deployment",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="gpt-3.5-turbo",
|
||||
http_client_proxies="http://testproxy.mem0.net:8000",
|
||||
)
|
||||
|
||||
llm = AzureOpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mock_chat.assert_called_once_with(
|
||||
deployment_name="azure_deployment",
|
||||
openai_api_version="2024-02-01",
|
||||
model_name="gpt-3.5-turbo",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
streaming=False,
|
||||
http_client=mock_http_client_instance,
|
||||
http_async_client=None,
|
||||
)
|
||||
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_http_async_client_proxies():
|
||||
mock_http_async_client = Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client_instance = Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client.return_value = mock_http_async_client_instance
|
||||
|
||||
with patch("langchain_openai.AzureChatOpenAI") as mock_chat, patch(
|
||||
"httpx.AsyncClient", new=mock_http_async_client
|
||||
) as mock_http_async_client:
|
||||
mock_chat.return_value.invoke.return_value.content = "Mocked response"
|
||||
|
||||
config = BaseLlmConfig(
|
||||
deployment_name="azure_deployment",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="gpt-3.5-turbo",
|
||||
http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
|
||||
)
|
||||
|
||||
llm = AzureOpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mock_chat.assert_called_once_with(
|
||||
deployment_name="azure_deployment",
|
||||
openai_api_version="2024-02-01",
|
||||
model_name="gpt-3.5-turbo",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
streaming=False,
|
||||
http_client=None,
|
||||
http_async_client=mock_http_async_client_instance,
|
||||
)
|
||||
mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
|
||||
|
||||
Reference in New Issue
Block a user