From 81b4431c9b616f33beda6c07f8c15272c804e8b7 Mon Sep 17 00:00:00 2001
From: Mitul Kataria <53906546+kmitul@users.noreply.github.com>
Date: Sun, 4 Aug 2024 00:01:43 +0900
Subject: [PATCH] Support Azure OpenAI LLM (#1581)
---
docs/components/llms.mdx | 19 +++----
mem0/llms/azure_openai.py | 80 ++++++++++++++++++++++++++++
mem0/llms/configs.py | 2 +-
mem0/utils/factory.py | 1 +
tests/llms/test_azure_openai.py | 94 +++++++++++++++++++++++++++++++++
5 files changed, 184 insertions(+), 12 deletions(-)
create mode 100644 mem0/llms/azure_openai.py
create mode 100644 tests/llms/test_azure_openai.py
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index b0a1c484..e113e959 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -16,7 +16,7 @@ Mem0 includes built-in support for various popular large language models. Memory
-
+
## OpenAI
@@ -263,26 +263,23 @@ m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
-## OpenAI Azure
+## Azure OpenAI
-To use Azure AI models, you have to set the `AZURE_API_KEY`, `AZURE_API_BASE`, and `AZURE_API_VERSION` environment variables. You can obtain the Azure API key from the [Azure](https://azure.microsoft.com/).
+To use Azure OpenAI models, you have to set the `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, and `OPENAI_API_VERSION` environment variables. You can obtain the Azure OpenAI API key from the [Azure portal](https://azure.microsoft.com/).
```python
import os
from mem0 import Memory
-
-os.environ["AZURE_API_KEY"] = "your-api-key"
-
-# Needed to use custom models
-os.environ["AZURE_API_BASE"] = "your-api-base-url"
-os.environ["AZURE_API_VERSION"] = "version-to-use"
+os.environ["AZURE_OPENAI_API_KEY"] = "your-api-key"
+os.environ["AZURE_OPENAI_ENDPOINT"] = "your-api-base-url"
+os.environ["OPENAI_API_VERSION"] = "version-to-use"
config = {
"llm": {
- "provider": "litellm",
+ "provider": "azure_openai",
"config": {
- "model": "azure_ai/command-r-plus",
+ "model": "your-deployment-name",
"temperature": 0.1,
"max_tokens": 2000,
}
diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py
new file mode 100644
index 00000000..bb249cff
--- /dev/null
+++ b/mem0/llms/azure_openai.py
@@ -0,0 +1,80 @@
+import json
+from typing import Dict, List, Optional
+
+from openai import AzureOpenAI
+
+from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
+
+class AzureOpenAILLM(LLMBase):
+ def __init__(self, config: Optional[BaseLlmConfig] = None):
+ super().__init__(config)
+
+ # Model name should match the custom deployment name chosen for it.
+ if not self.config.model:
+ self.config.model="gpt-4o"
+ self.client = AzureOpenAI()
+
+ def _parse_response(self, response, tools):
+ """
+ Process the response based on whether tools are used or not.
+
+ Args:
+ response: The raw response from API.
+ tools: The list of tools provided in the request.
+
+ Returns:
+ str or dict: The processed response.
+ """
+ if tools:
+ processed_response = {
+ "content": response.choices[0].message.content,
+ "tool_calls": []
+ }
+
+ if response.choices[0].message.tool_calls:
+ for tool_call in response.choices[0].message.tool_calls:
+ processed_response["tool_calls"].append({
+ "name": tool_call.function.name,
+ "arguments": json.loads(tool_call.function.arguments)
+ })
+
+ return processed_response
+ else:
+ return response.choices[0].message.content
+
+
+ def generate_response(
+ self,
+ messages: List[Dict[str, str]],
+ response_format=None,
+ tools: Optional[List[Dict]] = None,
+ tool_choice: str = "auto",
+ ):
+ """
+ Generate a response based on the given messages using Azure OpenAI.
+
+ Args:
+ messages (list): List of message dicts containing 'role' and 'content'.
+        response_format (str or object, optional): Format of the response. Defaults to None.
+ tools (list, optional): List of tools that the model can call. Defaults to None.
+ tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+ Returns:
+        str or dict: The generated response; a dict with "content" and "tool_calls" when tools are provided.
+ """
+ params = {
+ "model": self.config.model,
+ "messages": messages,
+ "temperature": self.config.temperature,
+ "max_tokens": self.config.max_tokens,
+ "top_p": self.config.top_p
+ }
+ if response_format:
+ params["response_format"] = response_format
+ if tools:
+ params["tools"] = tools
+ params["tool_choice"] = tool_choice
+
+ response = self.client.chat.completions.create(**params)
+ return self._parse_response(response, tools)
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
index beae356f..1bef9b36 100644
--- a/mem0/llms/configs.py
+++ b/mem0/llms/configs.py
@@ -14,7 +14,7 @@ class LlmConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
- if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm"):
+ if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
return v
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 9c72efd6..f518b917 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -18,6 +18,7 @@ class LlmFactory:
"aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
"litellm": "mem0.llms.litellm.LiteLLM",
"ollama": "mem0.llms.ollama.OllamaLLM",
+ "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
}
@classmethod
diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py
new file mode 100644
index 00000000..95b73d80
--- /dev/null
+++ b/tests/llms/test_azure_openai.py
@@ -0,0 +1,94 @@
+import pytest
+from unittest.mock import Mock, patch
+from mem0.llms.azure_openai import AzureOpenAILLM
+from mem0.configs.llms.base import BaseLlmConfig
+
+MODEL = "gpt-4o" # or your custom deployment name
+TEMPERATURE = 0.7
+MAX_TOKENS = 100
+TOP_P = 1.0
+
+@pytest.fixture
+def mock_openai_client():
+ with patch('mem0.llms.azure_openai.AzureOpenAI') as mock_openai:
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+ yield mock_client
+
+def test_generate_response_without_tools(mock_openai_client):
+ config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
+ llm = AzureOpenAILLM(config)
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"}
+ ]
+
+ mock_response = Mock()
+ mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
+ mock_openai_client.chat.completions.create.return_value = mock_response
+
+ response = llm.generate_response(messages)
+
+ mock_openai_client.chat.completions.create.assert_called_once_with(
+ model=MODEL,
+ messages=messages,
+ temperature=TEMPERATURE,
+ max_tokens=MAX_TOKENS,
+ top_p=TOP_P
+ )
+ assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_openai_client):
+ config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
+ llm = AzureOpenAILLM(config)
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ ]
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "add_memory",
+ "description": "Add a memory",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "data": {"type": "string", "description": "Data to add to memory"}
+ },
+ "required": ["data"],
+ },
+ },
+ }
+ ]
+
+ mock_response = Mock()
+ mock_message = Mock()
+ mock_message.content = "I've added the memory for you."
+
+ mock_tool_call = Mock()
+ mock_tool_call.function.name = "add_memory"
+ mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+ mock_message.tool_calls = [mock_tool_call]
+ mock_response.choices = [Mock(message=mock_message)]
+ mock_openai_client.chat.completions.create.return_value = mock_response
+
+ response = llm.generate_response(messages, tools=tools)
+
+ mock_openai_client.chat.completions.create.assert_called_once_with(
+ model=MODEL,
+ messages=messages,
+ temperature=TEMPERATURE,
+ max_tokens=MAX_TOKENS,
+ top_p=TOP_P,
+ tools=tools,
+ tool_choice="auto"
+ )
+
+ assert response["content"] == "I've added the memory for you."
+ assert len(response["tool_calls"]) == 1
+ assert response["tool_calls"][0]["name"] == "add_memory"
+ assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
+
\ No newline at end of file