AzureOpenAI Embedding Model and LLM Model Initialisation from Config. (#1773)

This commit is contained in:
k10
2024-09-01 02:09:00 +05:30
committed by GitHub
parent ad233034ef
commit 077d0c47f9
10 changed files with 88 additions and 22 deletions

View File

@@ -51,6 +51,7 @@ Here's a comprehensive list of all parameters that can be used across different
| `http_client_proxies` | Allow proxy server settings | | `http_client_proxies` | Allow proxy server settings |
| `ollama_base_url` | Base URL for the Ollama embedding model | | `ollama_base_url` | Base URL for the Ollama embedding model |
| `model_kwargs` | Key-Value arguments for the Huggingface embedding model | | `model_kwargs` | Key-Value arguments for the Huggingface embedding model |
| `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model |
| `openai_base_url` | Base URL for OpenAI API | OpenAI | | `openai_base_url` | Base URL for OpenAI API | OpenAI |

View File

@@ -2,7 +2,7 @@
title: Azure OpenAI title: Azure OpenAI
--- ---
To use Azure OpenAI embedding models, set the `AZURE_OPENAI_API_KEY` environment variable. You can obtain the Azure OpenAI API key from Azure. To use Azure OpenAI embedding models, set the `EMBEDDING_AZURE_OPENAI_API_KEY`, `EMBEDDING_AZURE_DEPLOYMENT`, `EMBEDDING_AZURE_ENDPOINT` and `EMBEDDING_AZURE_API_VERSION` environment variables. You can obtain the Azure OpenAI API key from Azure.
### Usage ### Usage
@@ -10,8 +10,10 @@ To use Azure OpenAI embedding models, set the `AZURE_OPENAI_API_KEY` environment
import os import os
from mem0 import Memory from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "your_api_key" os.environ["EMBEDDING_AZURE_OPENAI_API_KEY"] = "your-api-key"
os.environ["AZURE_OPENAI_API_KEY"] = "your_api_key" os.environ["EMBEDDING_AZURE_DEPLOYMENT"] = "your-deployment-name"
os.environ["EMBEDDING_AZURE_ENDPOINT"] = "your-api-base-url"
os.environ["EMBEDDING_AZURE_API_VERSION"] = "version-to-use"
config = { config = {
@@ -19,6 +21,12 @@ config = {
"provider": "azure_openai", "provider": "azure_openai",
"config": { "config": {
"model": "text-embedding-3-large" "model": "text-embedding-3-large"
"azure_kwargs" : {
"api_version" : "",
"azure_deployment" : "",
"azure_endpoint" : "",
"api_key": ""
}
} }
} }
} }
@@ -35,4 +43,4 @@ Here are the parameters available for configuring Azure OpenAI embedder:
| --- | --- | --- | | --- | --- | --- |
| `model` | The name of the embedding model to use | `text-embedding-3-small` | | `model` | The name of the embedding model to use | `text-embedding-3-small` |
| `embedding_dims` | Dimensions of the embedding model | `1536` | | `embedding_dims` | Dimensions of the embedding model | `1536` |
| `api_key` | The Azure OpenAI API key | `None` | | `azure_kwargs` | The Azure OpenAI configs | `config_keys` |

View File

@@ -61,6 +61,7 @@ Here's the table based on the provided parameters:
| `app_name` | Application name | Openrouter | | `app_name` | Application name | Openrouter |
| `ollama_base_url` | Base URL for Ollama API | Ollama | | `ollama_base_url` | Base URL for Ollama API | Ollama |
| `openai_base_url` | Base URL for OpenAI API | OpenAI | | `openai_base_url` | Base URL for OpenAI API | OpenAI |
| `azure_kwargs` | Azure LLM args for initialization | AzureOpenAI |
## Supported LLMs ## Supported LLMs

View File

@@ -2,7 +2,7 @@
title: Azure OpenAI title: Azure OpenAI
--- ---
To use Azure OpenAI models, you have to set the `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, and `OPENAI_API_VERSION` environment variables. You can obtain the Azure API key from [Azure](https://azure.microsoft.com/). To use Azure OpenAI models, you have to set the `LLM_AZURE_OPENAI_API_KEY`, `LLM_AZURE_ENDPOINT`, `LLM_AZURE_DEPLOYMENT` and `LLM_AZURE_API_VERSION` environment variables. You can obtain the Azure API key from [Azure](https://azure.microsoft.com/).
## Usage ## Usage
@@ -10,9 +10,10 @@ To use Azure OpenAI models, you have to set the `AZURE_OPENAI_API_KEY`, `AZURE_O
import os import os
from mem0 import Memory from mem0 import Memory
os.environ["AZURE_OPENAI_API_KEY"] = "your-api-key" os.environ["LLM_AZURE_OPENAI_API_KEY"] = "your-api-key"
os.environ["AZURE_OPENAI_ENDPOINT"] = "your-api-base-url" os.environ["LLM_AZURE_DEPLOYMENT"] = "your-deployment-name"
os.environ["OPENAI_API_VERSION"] = "version-to-use" os.environ["LLM_AZURE_ENDPOINT"] = "your-api-base-url"
os.environ["LLM_AZURE_API_VERSION"] = "version-to-use"
config = { config = {
"llm": { "llm": {
@@ -21,6 +22,12 @@ config = {
"model": "your-deployment-name", "model": "your-deployment-name",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000, "max_tokens": 2000,
"azure_kwargs" : {
"azure_deployment" : "",
"api_version" : "",
"azure_endpoint" : "",
"api_key" : ""
}
} }
} }
} }

View File

@@ -55,4 +55,20 @@ class MemoryConfig(BaseModel):
description="The version of the API", description="The version of the API",
default="v1.0", default="v1.0",
) )
class AzureConfig(BaseModel):
"""
Configuration settings for Azure.
Args:
api_key (str): The API key used for authenticating with the Azure service.
azure_deployment (str): The name of the Azure deployment.
azure_endpoint (str): The endpoint URL for the Azure service.
api_version (str): The version of the Azure API being used.
"""
api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None)
azure_deployment : str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version : str = Field(description="The version of the Azure API being used.", default=None)

View File

@@ -1,4 +1,5 @@
from abc import ABC from abc import ABC
from mem0.configs.base import AzureConfig
from typing import Optional, Union, Dict from typing import Optional, Union, Dict
import httpx import httpx
@@ -21,6 +22,7 @@ class BaseEmbedderConfig(ABC):
# Huggingface specific # Huggingface specific
model_kwargs: Optional[dict] = None, model_kwargs: Optional[dict] = None,
# AzureOpenAI specific # AzureOpenAI specific
azure_kwargs: Optional[AzureConfig] = {},
http_client_proxies: Optional[Union[Dict, str]] = None, http_client_proxies: Optional[Union[Dict, str]] = None,
): ):
""" """
@@ -38,6 +40,8 @@ class BaseEmbedderConfig(ABC):
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init :type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
:param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1" :param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1"
:type openai_base_url: Optional[str], optional :type openai_base_url: Optional[str], optional
:param azure_kwargs: key-value arguments for the AzureOpenAI embedding model, defaults to a dict inside init
:type azure_kwargs: Optional[Dict[str, Any]], defaults to a dict inside init
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional :type http_client_proxies: Optional[Dict | str], optional
""" """
@@ -55,3 +59,6 @@ class BaseEmbedderConfig(ABC):
# Huggingface specific # Huggingface specific
self.model_kwargs = model_kwargs or {} self.model_kwargs = model_kwargs or {}
# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}

View File

@@ -1,4 +1,5 @@
from abc import ABC from abc import ABC
from mem0.configs.base import AzureConfig
from typing import Optional, Union, Dict from typing import Optional, Union, Dict
import httpx import httpx
@@ -27,7 +28,8 @@ class BaseLlmConfig(ABC):
app_name: Optional[str] = None, app_name: Optional[str] = None,
# Ollama specific # Ollama specific
ollama_base_url: Optional[str] = None, ollama_base_url: Optional[str] = None,
# AzureOpenAI specific
azure_kwargs: Optional[AzureConfig] = {},
# AzureOpenAI specific # AzureOpenAI specific
http_client_proxies: Optional[Union[Dict, str]] = None, http_client_proxies: Optional[Union[Dict, str]] = None,
): ):
@@ -62,6 +64,8 @@ class BaseLlmConfig(ABC):
:type ollama_base_url: Optional[str], optional :type ollama_base_url: Optional[str], optional
:param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1" :param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1"
:type openai_base_url: Optional[str], optional :type openai_base_url: Optional[str], optional
:param azure_kwargs: key-value arguments for the AzureOpenAI LLM model, defaults to a dict inside init
:type azure_kwargs: Optional[Dict[str, Any]], defaults to a dict inside init
:param http_client_proxies: The proxy server(s) settings used to create self.http_client, defaults to None :param http_client_proxies: The proxy server(s) settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional :type http_client_proxies: Optional[Dict | str], optional
""" """
@@ -86,3 +90,6 @@ class BaseLlmConfig(ABC):
# Ollama specific # Ollama specific
self.ollama_base_url = ollama_base_url self.ollama_base_url = ollama_base_url
# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}

View File

@@ -11,13 +11,18 @@ class AzureOpenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None): def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config) super().__init__(config)
if self.config.model is None: api_key = os.getenv("EMBEDDING_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
self.config.model = "text-embedding-3-small" azure_deployment = os.getenv("EMBEDDING_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
if self.config.embedding_dims is None: azure_endpoint = os.getenv("EMBEDDING_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
self.config.embedding_dims = 1536 api_version = os.getenv("EMBEDDING_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
api_key = os.getenv("AZURE_OPENAI_API_KEY") or self.config.api_key self.client = AzureOpenAI(
self.client = AzureOpenAI(api_key=api_key, http_client=self.config.http_client) azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
def embed(self, text): def embed(self, text):
""" """

View File

@@ -15,10 +15,20 @@ class AzureOpenAILLM(LLMBase):
# Model name should match the custom deployment name chosen for it. # Model name should match the custom deployment name chosen for it.
if not self.config.model: if not self.config.model:
self.config.model = "gpt-4o" self.config.model = "gpt-4o"
api_key = os.getenv("AZURE_OPENAI_API_KEY") or self.config.api_key api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
self.client = AzureOpenAI(api_key=api_key, http_client=self.config.http_client) azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
def _parse_response(self, response, tools): def _parse_response(self, response, tools):
""" """
Process the response based on whether tools are used or not. Process the response based on whether tools are used or not.

View File

@@ -103,12 +103,16 @@ def test_generate_with_http_proxies():
with (patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai, with (patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai,
patch("httpx.Client", new=mock_http_client) as mock_http_client): patch("httpx.Client", new=mock_http_client) as mock_http_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P, config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P,
api_key="test", http_client_proxies="http://testproxy.mem0.net:8000") api_key="test", http_client_proxies="http://testproxy.mem0.net:8000",
azure_kwargs= {"api_key" : "test"})
_ = AzureOpenAILLM(config) _ = AzureOpenAILLM(config)
mock_azure_openai.assert_called_once_with( mock_azure_openai.assert_called_once_with(
api_key="test", api_key="test",
http_client=mock_http_client_instance http_client=mock_http_client_instance,
azure_deployment=None,
azure_endpoint=None,
api_version=None
) )
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")