Add OpenAI proxy (#1503)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Dev Khant
2024-08-02 20:14:27 +05:30
committed by GitHub
parent 51092b0b64
commit 419dc6598c
18 changed files with 637 additions and 135 deletions

View File

@@ -92,6 +92,33 @@ history = m.history(memory_id=<memory_id_1>)
# Logs corresponding to memory_id_1 --> {'prev_value': 'Working on improving tennis skills and interested in online courses for tennis.', 'new_value': 'Likes to play tennis on weekends' }
```
### Mem0 Platform
```python
from mem0 import MemoryClient
client = MemoryClient(api_key="your-api-key") # get api_key from https://app.mem0.ai/
# Store messages
messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
result = client.add(messages, user_id="alex")
print(result)
# Retrieve memories
all_memories = client.get_all(user_id="alex")
print(all_memories)
# Search memories
query = "What do you know about me?"
related_memories = client.search(query, user_id="alex")
# Get memory history
history = client.history(memory_id="m1")
print(history)
```
> [!TIP]
> If you are looking for a hosted version and don't want to set up the infrastructure yourself, check out [Mem0 Platform Docs](https://docs.mem0.ai/platform/quickstart) to get started in minutes.

View File

@@ -0,0 +1,22 @@
---
title: OpenAI Compatibility
---
Mem0 provides an OpenAI-compatible API, making it easy to integrate into existing projects.
## Mem0 Params for Chat Completion
- `user_id` (Optional[str]): Identifier for the user.
- `agent_id` (Optional[str]): Identifier for the agent.
- `run_id` (Optional[str]): Identifier for the run.
- `metadata` (Optional[dict]): Additional metadata to be stored with the memory.
- `filters` (Optional[dict]): Filters to apply when searching for relevant memories.
- `limit` (Optional[int]): Maximum number of relevant memories to retrieve. Default is 10.
The remaining parameters follow OpenAI's chat completion API, so Mem0 can be dropped into existing applications with minimal changes.
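For example, a request that scopes memories to a single user might look like the sketch below (the API key, user ID, and metadata values are placeholders):
```python
from mem0 import Mem0

client = Mem0(api_key="m0-xxx")  # or Mem0() / Mem0(config={...}) for a local setup

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Suggest a quick vegetarian dinner."}],
    # Mem0-specific parameters
    user_id="alice",
    metadata={"category": "food"},
    limit=5,
    # Standard OpenAI-style parameters
    temperature=0.7,
)
print(response.choices[0].message.content)
```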

View File

@@ -67,6 +67,12 @@
"components/vectordb.mdx"
]
},
{
"group": "Features",
"pages":[
"features/openai_compatibility"
]
},
{
"group": "Integrations",
"pages": [

View File

@@ -60,28 +60,26 @@ m = Memory.from_config(config)
### Store a Memory
```python
<CodeGroup>
```python Code
# For a user
result = m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
print(result)
```
Output:
```python
```json Output
{'message': 'ok'}
```
</CodeGroup>
### Retrieve Memories
```python
<CodeGroup>
```python Code
# Get all memories
all_memories = m.get_all()
print(all_memories)
```
Output:
```python
```json Output
[
{
"id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
@@ -94,15 +92,15 @@ Output:
}
]
```
</CodeGroup>
```python
<CodeGroup>
```python Code
# Get a single memory by ID
specific_memory = m.get("m1")
print(specific_memory)
```
Output:
```python
```json Output
{
"id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
"memory":"Likes to play cricket on weekends",
@@ -113,17 +111,16 @@ Output:
"user_id":"alice"
}
```
</CodeGroup>
### Search Memories
```python
<CodeGroup>
```python Code
related_memories = m.search(query="What are Alice's hobbies?", user_id="alice")
print(related_memories)
```
Output:
```python
```json Output
[
{
"id":"ea925981-272f-40dd-b576-be64e4871429",
@@ -139,28 +136,28 @@ Output:
}
]
```
</CodeGroup>
### Update a Memory
```python
<CodeGroup>
```python Code
result = m.update(memory_id="m1", data="Likes to play tennis on weekends")
print(result)
```
Output:
```python
```json Output
{'message': 'Memory updated successfully!'}
```
</CodeGroup>
### Memory History
```python
<CodeGroup>
```python Code
history = m.history(memory_id="m1")
print(history)
```
Output:
```python
```json Output
[
{
"id":"4e0e63d6-a9c6-43c0-b11c-a1bad3fc7abb",
@@ -182,6 +179,7 @@ Output:
}
]
```
</CodeGroup>
### Delete Memory
@@ -197,6 +195,103 @@ m.delete_all(user_id="alice") # Delete all memories
m.reset() # Reset all memories
```
## Chat Completion
Mem0 can be easily integrated into chat applications to enhance conversational agents with structured memory. Mem0's APIs are designed to be compatible with OpenAI's, with the goal of making it easy to leverage Mem0 in applications you may have already built.
If you have a Mem0 API key, you can use it to initialize the client. Alternatively, you can initialize Mem0 without an API key if you're running it locally.
Mem0 supports a range of large language models (LLMs) through LiteLLM's supported [providers](https://litellm.vercel.app/docs/providers).
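As a quick sketch of both initialization modes (the API key below is a placeholder; the local client falls back to the default `Memory` configuration):
```python
from mem0 import Mem0

# Hosted platform: pass your Mem0 API key
client = Mem0(api_key="m0-xxx")

# Local / OSS: no API key needed; optionally pass a config dict
local_client = Mem0()
```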
### Use Mem0 Platform
```python
from mem0 import Mem0
client = Mem0(api_key="m0-xxx")
# First interaction: Storing user preferences
messages = [
{
"role": "user",
"content": "I love indian food but I cannot eat pizza since allergic to cheese."
},
]
user_id = "deshraj"
chat_completion = client.chat.completions.create(messages=messages, model="gpt-4o-mini", user_id=user_id)
# Memory saved after this will look like: "Loves Indian food. Allergic to cheese and cannot eat pizza."
# Second interaction: Leveraging stored memory
messages = [
{
"role": "user",
"content": "Suggest restaurants in San Francisco to eat.",
}
]
chat_completion = client.chat.completions.create(messages=messages, model="gpt-4o-mini", user_id=user_id)
print(chat_completion.choices[0].message.content)
# Answer: You might enjoy Indian restaurants in San Francisco, such as Amber India, Dosa, or Curry Up Now, which offer delicious options without cheese.
```
In this example, you can see how the second response is tailored based on the information provided in the first interaction. Mem0 remembers the user's preference for Indian food and their cheese allergy, using this information to provide more relevant and personalized restaurant suggestions in San Francisco.
### Use Mem0 OSS
```python
from mem0 import Mem0

config = {
"vector_store": {
"provider": "qdrant",
"config": {
"host": "localhost",
"port": 6333,
}
},
}
client = Mem0(config=config)
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": "What's the capital of France?",
}
],
model="gpt-4o",
)
```
## APIs
Get started with the Mem0 APIs in your applications. For more details, refer to the [Platform docs](/platform/quickstart.mdx).
Here is an example of how to use the Mem0 APIs:
```python
from mem0 import MemoryClient
client = MemoryClient(api_key="your-api-key") # get api_key from https://app.mem0.ai/
# Store messages
messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
result = client.add(messages, user_id="alex")
print(result)
# Retrieve memories
all_memories = client.get_all(user_id="alex")
print(all_memories)
# Search memories
query = "What do you know about me?"
related_memories = client.search(query, user_id="alex")
# Get memory history
history = client.history(memory_id="m1")
print(history)
```
If you have any questions, please feel free to reach out to us using one of the following methods:

View File

@@ -4,3 +4,4 @@ __version__ = importlib.metadata.version("mem0ai")
from mem0.memory.main import Memory # noqa
from mem0.client.main import MemoryClient # noqa
from mem0.proxy.main import Mem0 #noqa

View File

@@ -0,0 +1,39 @@
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from mem0.memory.setup import mem0_dir
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig
class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data")
memory: str = Field(..., description="The memory deduced from the text data") # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
class MemoryConfig(BaseModel):
vector_store: VectorStoreConfig = Field(
description="Configuration for the vector store",
default_factory=VectorStoreConfig,
)
llm: LlmConfig = Field(
description="Configuration for the language model",
default_factory=LlmConfig,
)
embedder: EmbedderConfig = Field(
description="Configuration for the embedding model",
default_factory=EmbedderConfig,
)
history_db_path: str = Field(
description="Path to the history database",
default=os.path.join(mem0_dir, "history.db"),
)

View File

@@ -0,0 +1,32 @@
from abc import ABC
from typing import Optional
class BaseEmbedderConfig(ABC):
"""
Config for Embeddings.
"""
def __init__(
self,
model: Optional[str] = None,
embedding_dims: Optional[int] = None,
# Ollama specific
base_url: Optional[str] = None
):
"""
Initializes a configuration class instance for the Embeddings.
:param model: Embedding model to use, defaults to None
:type model: Optional[str], optional
:param embedding_dims: The number of dimensions in the embedding, defaults to None
:type embedding_dims: Optional[int], optional
:param base_url: Base URL for the Ollama API, defaults to None
:type base_url: Optional[str], optional
"""
self.model = model
self.embedding_dims = embedding_dims
# Ollama specific
self.base_url = base_url

View File

@@ -29,3 +29,14 @@ Constraint for deducing facts, preferences, and memories:
Deduced facts, preferences, and memories:
"""
MEMORY_ANSWER_PROMPT = """
You are an expert at answering questions based on the provided memories. Your task is to provide accurate and concise answers to the questions by leveraging the information given in the memories.
Guidelines:
- Extract relevant information from the memories based on the question.
- If no relevant information is found, make sure you don't say no information is found. Instead, accept the question and provide a general response.
- Ensure that the answers are clear, concise, and directly address the question.
Here are the details of the task:
"""

View File

@@ -15,51 +15,17 @@ from mem0.llms.utils.tools import (
)
from mem0.configs.prompts import MEMORY_DEDUCTION_PROMPT
from mem0.memory.base import MemoryBase
from mem0.memory.setup import mem0_dir, setup_config
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_update_memory_messages
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig
# Setup user config
setup_config()
class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data")
memory: str = Field(..., description="The memory deduced from the text data") # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
class MemoryConfig(BaseModel):
vector_store: VectorStoreConfig = Field(
description="Configuration for the vector store",
default_factory=VectorStoreConfig,
)
llm: LlmConfig = Field(
description="Configuration for the language model",
default_factory=LlmConfig,
)
embedder: EmbedderConfig = Field(
description="Configuration for the embedding model",
default_factory=EmbedderConfig,
)
history_db_path: str = Field(
description="Path to the history database",
default=os.path.join(mem0_dir, "history.db"),
)
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config

View File

@@ -17,6 +17,25 @@ class SQLiteManager:
table_exists = cursor.fetchone() is not None
if table_exists:
# Get the current schema of the history table
cursor.execute("PRAGMA table_info(history)")
current_schema = {row[1]: row[2] for row in cursor.fetchall()}
# Define the expected schema
expected_schema = {
'id': 'TEXT',
'memory_id': 'TEXT',
'old_memory': 'TEXT',
'new_memory': 'TEXT',
'new_value': 'TEXT',
'event': 'TEXT',
'created_at': 'DATETIME',
'updated_at': 'DATETIME',
'is_deleted': 'INTEGER'
}
# Check if the schemas are the same
if current_schema != expected_schema:
# Rename the old table
cursor.execute("ALTER TABLE history RENAME TO old_history")

0 mem0/proxy/__init__.py Normal file
View File

155 mem0/proxy/main.py Normal file
View File

@@ -0,0 +1,155 @@
import httpx
from typing import Optional, List, Union
import threading
import litellm
from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
class Mem0:
def __init__(
self,
config: Optional[dict] = None,
api_key: Optional[str] = None,
host: Optional[str] = None
):
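# Use the hosted platform client when an API key is provided; otherwise fall back to the local Memory store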
if api_key:
self.mem0_client = MemoryClient(api_key, host)
else:
self.mem0_client = Memory.from_config(config) if config else Memory()
self.chat = Chat(self.mem0_client)
class Chat:
def __init__(self, mem0_client):
self.completions = Completions(mem0_client)
class Completions:
def __init__(self, mem0_client):
self.mem0_client = mem0_client
def create(
self,
model: str,
messages: List = [],
# Mem0 arguments
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
metadata: Optional[dict] = None,
filters: Optional[dict] = None,
limit: Optional[int] = 10,
# LLM arguments
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
# openai v1.0+ new params
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
deployment_id=None,
extra_headers: Optional[dict] = None,
# soon to be deprecated params by OpenAI
functions: Optional[List] = None,
function_call: Optional[str] = None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base, keys, etc.
):
if not any([user_id, agent_id, run_id]):
raise ValueError("One of user_id, agent_id, run_id must be provided")
if not litellm.supports_function_calling(model):
raise ValueError(f"Model '{model}' does not support function calling. Please use a model that supports function calling.")
prepared_messages = self._prepare_messages(messages)
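# If the latest message is from the user, store it asynchronously and enrich the query with relevant memories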
if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
response = litellm.completion(
model=model,
messages=prepared_messages,
temperature=temperature,
top_p=top_p,
n=n,
timeout=timeout,
stream=stream,
stream_options=stream_options,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
response_format=response_format,
seed=seed,
tools=tools,
tool_choice=tool_choice,
logprobs=logprobs,
top_logprobs=top_logprobs,
parallel_tool_calls=parallel_tool_calls,
deployment_id=deployment_id,
extra_headers=extra_headers,
functions=functions,
function_call=function_call,
base_url=base_url,
api_version=api_version,
api_key=api_key,
model_list=model_list
)
return response
def _prepare_messages(self, messages: List[dict]) -> List[dict]:
if not messages or messages[0]["role"] != "system":
return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
messages[0]["content"] = MEMORY_ANSWER_PROMPT
return messages
def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
def add_task():
self.mem0_client.add(
messages=messages,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata,
filters=filters,
)
threading.Thread(target=add_task, daemon=True).start()
def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
# Currently, only the last 6 messages are passed to the search API to keep the query short
message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
# TODO: Make it better by summarizing the past conversation
return self.mem0_client.search(
query="\n".join(message_input),
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
filters=filters,
limit=limit,
)
def _format_query_with_memories(self, messages, relevant_memories):
memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"

View File

@@ -55,6 +55,8 @@ class VectorStoreFactory:
def create(cls, provider_name, config):
class_type = cls.provider_to_class.get(provider_name)
if class_type:
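# Accept either a plain dict or a pydantic config object (e.g. QdrantConfig) and normalize to a dict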
if not isinstance(config, dict):
config = config.model_dump()
vector_store_instance = load_class(class_type)
return vector_store_instance(**config)
else:

View File

@@ -21,20 +21,20 @@ class OutputData(BaseModel):
class ChromaDB(VectorStoreBase):
def __init__(
self,
collection_name="mem0",
client=None,
host=None,
port=None,
path=None
collection_name,
client,
host,
port,
path
):
"""
Initialize the ChromaDB vector store.
Args:
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
host (str, optional): Host address for Qdrant server. Defaults to None.
port (int, optional): Port for Qdrant server. Defaults to None.
path (str, optional): Path for local Qdrant database. Defaults to None.
client (chromadb.Client, optional): Existing ChromaDB client instance.
host (str, optional): Host address for the ChromaDB server.
port (int, optional): Port for the ChromaDB server.
path (str, optional): Path for the local ChromaDB database.
"""
if client:
self.client = client
@@ -95,7 +95,7 @@ class ChromaDB(VectorStoreBase):
Args:
name (str): Name of the collection.
embedding_fn (function): Embedding function to use.
embedding_fn (function): Embedding function to use. Defaults to None.
"""
# Skip creating collection if already exists
collections = self.list_cols()
@@ -213,7 +213,7 @@ class ChromaDB(VectorStoreBase):
Args:
name (str): Name of the collection.
filters (dict, optional): Filters to apply to the list. Defaults to None.
filters (dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:

View File

@@ -3,14 +3,23 @@ from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
def create_default_config(provider: str):
"""Create a default configuration based on the provider."""
if provider == "qdrant":
return QdrantConfig(path="/tmp/qdrant")
elif provider == "chromadb":
return ChromaDbConfig(path="/tmp/chromadb")
else:
raise ValueError(f"Unsupported vector store provider: {provider}")
class QdrantConfig(BaseModel):
collection_name: str = Field(default="mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(
default=1536, description="Dimensions of the embedding model"
)
collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[str] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field(None, description="Path for local Qdrant database")
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
@@ -31,18 +40,11 @@ class QdrantConfig(BaseModel):
class ChromaDbConfig(BaseModel):
collection_name: str = Field(
default="mem0", description="Default name for the collection"
)
path: Optional[str] = Field(
default=None, description="Path to the database directory"
)
host: Optional[str] = Field(
default=None, description="Database connection remote host"
)
port: Optional[str] = Field(
default=None, description="Database connection remote port"
)
collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[str] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[str] = Field(None, description="Database connection remote port")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):
@@ -59,15 +61,37 @@ class VectorStoreConfig(BaseModel):
)
config: Optional[dict] = Field(
description="Configuration for the specific vector store",
default={},
default=None
)
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if v is None:
return create_default_config(provider)
if isinstance(v, dict):
if provider == "qdrant":
return QdrantConfig(**v.model_dump())
return QdrantConfig(**v)
elif provider == "chromadb":
return ChromaDbConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported vector store provider: {provider}")
return ChromaDbConfig(**v)
return v
@model_validator(mode="after")
def ensure_config_type(cls, values):
provider = values.provider
config = values.config
if config is None:
values.config = create_default_config(provider)
elif isinstance(config, dict):
if provider == "qdrant":
values.config = QdrantConfig(**config)
elif provider == "chromadb":
values.config = ChromaDbConfig(**config)
elif not isinstance(config, (QdrantConfig, ChromaDbConfig)):
raise ValueError(f"Invalid config type for provider {provider}")
return values

View File

@@ -20,25 +20,25 @@ from mem0.vector_stores.base import VectorStoreBase
class Qdrant(VectorStoreBase):
def __init__(
self,
collection_name="mem0",
embedding_model_dims=1536,
client=None,
host="localhost",
port=6333,
path=None,
url=None,
api_key=None,
collection_name,
embedding_model_dims,
client,
host,
port,
path,
url,
api_key,
):
"""
Initialize the Qdrant vector store.
Args:
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
host (str, optional): Host address for Qdrant server. Defaults to "localhost".
port (int, optional): Port for Qdrant server. Defaults to 6333.
path (str, optional): Path for local Qdrant database. Defaults to None.
url (str, optional): Full URL for Qdrant server. Defaults to None.
api_key (str, optional): API key for Qdrant server. Defaults to None.
client (QdrantClient, optional): Existing Qdrant client instance.
host (str, optional): Host address for Qdrant server.
port (int, optional): Port for Qdrant server.
path (str, optional): Path for local Qdrant database.
url (str, optional): Full URL for Qdrant server.
api_key (str, optional): API key for Qdrant server.
"""
if client:
self.client = client

103 tests/test_proxy.py Normal file
View File

@@ -0,0 +1,103 @@
import pytest
from unittest.mock import Mock, patch
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0 import Memory, MemoryClient, Mem0
from mem0.proxy.main import Chat, Completions
@pytest.fixture
def mock_memory_client():
return Mock(spec=MemoryClient)
@pytest.fixture
def mock_openai_embedding_client():
with patch('mem0.embeddings.openai.OpenAI') as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
@pytest.fixture
def mock_openai_llm_client():
with patch('mem0.llms.openai.OpenAI') as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
@pytest.fixture
def mock_litellm():
with patch('mem0.proxy.main.litellm') as mock:
yield mock
def test_mem0_initialization_with_api_key():
with patch('mem0.proxy.main.MemoryClient') as mock_memory_client_cls:
mem0 = Mem0(api_key="test-api-key")
mock_memory_client_cls.assert_called_once_with("test-api-key", None)
assert isinstance(mem0.chat, Chat)
def test_mem0_initialization_with_config():
config = {"some_config": "value"}
with patch('mem0.Memory.from_config') as mock_from_config:
mem0 = Mem0(config=config)
mock_from_config.assert_called_once_with(config)
assert isinstance(mem0.chat, Chat)
def test_mem0_initialization_without_params(mock_openai_embedding_client, mock_openai_llm_client):
mem0 = Mem0()
assert isinstance(mem0.mem0_client, Memory)
assert isinstance(mem0.chat, Chat)
def test_chat_initialization(mock_memory_client):
chat = Chat(mock_memory_client)
assert isinstance(chat.completions, Completions)
def test_completions_create(mock_memory_client, mock_litellm):
completions = Completions(mock_memory_client)
messages = [
{"role": "user", "content": "Hello, how are you?"}
]
mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}]
mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
response = completions.create(
model="gpt-3.5-turbo",
messages=messages,
user_id="test_user",
temperature=0.7
)
mock_memory_client.add.assert_called_once()
mock_memory_client.search.assert_called_once()
mock_litellm.completion.assert_called_once()
call_args = mock_litellm.completion.call_args[1]
assert call_args['model'] == "gpt-3.5-turbo"
assert len(call_args['messages']) == 2
assert call_args['temperature'] == 0.7
assert response == {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
def test_completions_create_with_system_message(mock_memory_client, mock_litellm):
completions = Completions(mock_memory_client)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}]
mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
response = completions.create(
model="gpt-3.5-turbo",
messages=messages,
user_id="test_user"
)
call_args = mock_litellm.completion.call_args[1]
assert call_args['messages'][0]['role'] == "system"
assert call_args['messages'][0]['content'] == MEMORY_ANSWER_PROMPT