diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index cc91ecbe..c9ecb79b 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -29,6 +29,7 @@ llm: model_kwargs: response_format: type: json_object + api_version: 2024-02-01 prompt: | Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. @@ -87,7 +88,8 @@ cache: "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:", "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.", "api_key": "sk-xxx", - "model_kwargs": {"response_format": {"type": "json_object"}} + "model_kwargs": {"response_format": {"type": "json_object"}}, + "api_version": "2024-02-01" } }, "vectordb": { diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx index 3bd88b61..14f4ada2 100644 --- a/docs/components/llms.mdx +++ b/docs/components/llms.mdx @@ -193,8 +193,8 @@ import os from embedchain import App os.environ["OPENAI_API_TYPE"] = "azure" -os.environ["OPENAI_API_BASE"] = "https://xxx.openai.azure.com/" -os.environ["OPENAI_API_KEY"] = "xxx" +os.environ["AZURE_OPENAI_ENDPOINT"] = "https://xxx.openai.azure.com/" +os.environ["AZURE_OPENAI_API_KEY"] = "xxx" os.environ["OPENAI_API_VERSION"] = "xxx" app = App.from_config(config_path="config.yaml") diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py index 39df717e..ac570b1c 100644 --- a/embedchain/config/llm/base.py +++ b/embedchain/config/llm/base.py @@ -103,6 +103,7 @@ class BaseLlmConfig(BaseConfig): http_async_client: Optional[Any] = None, local: Optional[bool] = False, default_headers: Optional[Mapping[str, str]] = None, + 
api_version: Optional[str] = None, ): """ Initializes a configuration class instance for the LLM. @@ -185,6 +186,7 @@ class BaseLlmConfig(BaseConfig): self.local = local self.default_headers = default_headers self.online = online + self.api_version = api_version if isinstance(prompt, str): prompt = Template(prompt) diff --git a/embedchain/llm/azure_openai.py b/embedchain/llm/azure_openai.py index b2542b42..6c5a03f1 100644 --- a/embedchain/llm/azure_openai.py +++ b/embedchain/llm/azure_openai.py @@ -25,7 +25,7 @@ class AzureOpenAILlm(BaseLlm): chat = AzureChatOpenAI( deployment_name=config.deployment_name, - openai_api_version="2023-05-15", + openai_api_version=str(config.api_version) if config.api_version else "2023-05-15", model_name=config.model or "gpt-3.5-turbo", temperature=config.temperature, max_tokens=config.max_tokens, diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 31fd12ac..03db506a 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -1,3 +1,4 @@ +import datetime import itertools import json import logging @@ -439,6 +440,7 @@ def validate_config(config_data): Optional("local"): bool, Optional("base_url"): str, Optional("default_headers"): dict, + Optional("api_version"): Or(str, datetime.date) }, }, Optional("vectordb"): { diff --git a/tests/llm/test_azure_openai.py b/tests/llm/test_azure_openai.py index ab5a8d93..b936a302 100644 --- a/tests/llm/test_azure_openai.py +++ b/tests/llm/test_azure_openai.py @@ -64,3 +64,27 @@ def test_when_no_deployment_name_provided(): with pytest.raises(ValueError): llm = AzureOpenAILlm(config) llm.get_llm_model_answer("Test Prompt") + +def test_with_api_version(): + config = BaseLlmConfig( + deployment_name="azure_deployment", + temperature=0.7, + model="gpt-3.5-turbo", + max_tokens=50, + system_prompt="System Prompt", + api_version="2024-02-01", + ) + + with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat: + + llm = AzureOpenAILlm(config) + 
llm.get_llm_model_answer("Test Prompt") + + mock_chat.assert_called_once_with( + deployment_name="azure_deployment", + openai_api_version="2024-02-01", + model_name="gpt-3.5-turbo", + temperature=0.7, + max_tokens=50, + streaming=False, + ) \ No newline at end of file