Fix CI issues related to missing dependency (#3096)

This commit is contained in:
Deshraj Yadav
2025-07-03 18:52:50 -07:00
committed by GitHub
parent 2c496e6376
commit 7484eed4b2
32 changed files with 6150 additions and 828 deletions

View File

@@ -58,11 +58,10 @@ jobs:
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install -e ".[test,graph,vector_stores,llms,extras]" pip install -e ".[test,graph,vector_stores,llms,extras]"
pip install ruff
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true' if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
- name: Run Formatting - name: Run Linting
run: | run: make lint
mkdir -p .ruff_cache && chmod -R 777 .ruff_cache
hatch run format
- name: Run tests and generate coverage report - name: Run tests and generate coverage report
run: make test run: make test

View File

@@ -13,7 +13,7 @@ install:
install_all: install_all:
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \ pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \ google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo
# Format code with ruff # Format code with ruff
format: format:

View File

@@ -110,7 +110,7 @@ def main():
print("All categories accuracy:") print("All categories accuracy:")
for cat, results in LLM_JUDGE.items(): for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})") print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
print("------------------------------------------") print("------------------------------------------")
index += 1 index += 1

View File

@@ -68,7 +68,7 @@ class RAGManager:
def clean_chat_history(self, chat_history): def clean_chat_history(self, chat_history):
cleaned_chat_history = "" cleaned_chat_history = ""
for c in chat_history: for c in chat_history:
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n" cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"
return cleaned_chat_history return cleaned_chat_history

View File

@@ -68,9 +68,7 @@
"\n", "\n",
"import os\n", "import os\n",
"\n", "\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n", "os.environ[\"OPENAI_API_KEY\"] = \"\""
" \"\"\n",
")"
] ]
}, },
{ {
@@ -164,7 +162,7 @@
" \"role\": \"assistant\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n", " \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n", " },\n",
"]\n" "]"
] ]
}, },
{ {
@@ -185,9 +183,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Store inferred memories (default behavior)\n", "# Store inferred memories (default behavior)\n",
"result = m.add(\n", "result = m.add(messages, user_id=\"alice\")"
" messages, user_id=\"alice\"\n",
")"
] ]
}, },
{ {

View File

@@ -45,11 +45,7 @@ def get_food_recommendation(user_query: str, user_id):
"""Get food recommendation with memory context""" """Get food recommendation with memory context"""
# Search memory for relevant food preferences # Search memory for relevant food preferences
memories_result = memory_client.search( memories_result = memory_client.search(query=user_query, user_id=user_id, limit=5)
query=user_query,
user_id=user_id,
limit=5
)
# Add memory context to the message # Add memory context to the message
memories = [f"- {result['memory']}" for result in memories_result] memories = [f"- {result['memory']}" for result in memories_result]
@@ -71,6 +67,7 @@ def get_food_recommendation(user_query: str, user_id):
# Save audio file # Save audio file
if response.audio: if response.audio:
import time import time
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"food_recommendation_{timestamp}.mp3" filename = f"food_recommendation_{timestamp}.mp3"
write_audio_to_file( write_audio_to_file(
@@ -118,7 +115,11 @@ def initialize_food_memory(user_id):
# Initialize the memory for the user once in order for the agent to learn the user preference # Initialize the memory for the user once in order for the agent to learn the user preference
initialize_food_memory(user_id=USER_ID) initialize_food_memory(user_id=USER_ID)
print(get_food_recommendation("Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID)) print(
get_food_recommendation(
"Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID
)
)
# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3 # OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
# For dinner tonight, considering your love for healthy spic optionsy, you could try a nice Thai, Indian, or Mexican restaurant. # For dinner tonight, considering your love for healthy spic optionsy, you could try a nice Thai, Indian, or Mexican restaurant.
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner! # You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!

View File

@@ -1,4 +1,4 @@
from agents import Agent, Runner, function_tool, handoffs, enable_verbose_stdout_logging from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from dotenv import load_dotenv from dotenv import load_dotenv
from mem0 import MemoryClient from mem0 import MemoryClient
@@ -35,7 +35,7 @@ travel_agent = Agent(
understand the user's travel preferences and history before making recommendations. understand the user's travel preferences and history before making recommendations.
After providing your response, use store_conversation to save important details.""", After providing your response, use store_conversation to save important details.""",
tools=[search_memory, save_memory], tools=[search_memory, save_memory],
model="gpt-4o" model="gpt-4o",
) )
health_agent = Agent( health_agent = Agent(
@@ -44,7 +44,7 @@ health_agent = Agent(
understand the user's health goals and dietary preferences. understand the user's health goals and dietary preferences.
After providing advice, use store_conversation to save relevant information.""", After providing advice, use store_conversation to save relevant information.""",
tools=[search_memory, save_memory], tools=[search_memory, save_memory],
model="gpt-4o" model="gpt-4o",
) )
# Triage agent with handoffs # Triage agent with handoffs
@@ -55,7 +55,7 @@ triage_agent = Agent(
For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor. For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
For general questions, you can handle them directly using available tools.""", For general questions, you can handle them directly using available tools.""",
handoffs=[travel_agent, health_agent], handoffs=[travel_agent, health_agent],
model="gpt-4o" model="gpt-4o",
) )
@@ -74,10 +74,7 @@ def chat_with_handoffs(user_input: str, user_id: str) -> str:
result = Runner.run_sync(triage_agent, user_input) result = Runner.run_sync(triage_agent, user_input)
# Store the original conversation in memory # Store the original conversation in memory
conversation = [ conversation = [{"role": "user", "content": user_input}, {"role": "assistant", "content": result.final_output}]
{"role": "user", "content": user_input},
{"role": "assistant", "content": result.final_output}
]
mem0.add(conversation, user_id=user_id) mem0.add(conversation, user_id=user_id)
return result.final_output return result.final_output

View File

@@ -34,24 +34,16 @@ config = {
"api_key": "vllm-api-key", "api_key": "vllm-api-key",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 100, "max_tokens": 100,
}
}, },
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small"
}
}, },
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
"vector_store": { "vector_store": {
"provider": "qdrant", "provider": "qdrant",
"config": { "config": {"collection_name": "vllm_memories", "host": "localhost", "port": 6333},
"collection_name": "vllm_memories", },
"host": "localhost",
"port": 6333
}
}
} }
def main(): def main():
""" """
Demonstrate vLLM integration with mem0 Demonstrate vLLM integration with mem0
@@ -68,34 +60,40 @@ def main():
{ {
"messages": [ "messages": [
{"role": "user", "content": "I love playing chess on weekends"}, {"role": "user", "content": "I love playing chess on weekends"},
{"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."} {
"role": "assistant",
"content": "That's great! Chess is an excellent strategic game that helps improve critical thinking.",
},
], ],
"user_id": "user_123" "user_id": "user_123",
}, },
{ {
"messages": [ "messages": [
{"role": "user", "content": "I'm learning Python programming"}, {"role": "user", "content": "I'm learning Python programming"},
{"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"} {
"role": "assistant",
"content": "Python is a fantastic language for beginners! What specific areas are you focusing on?",
},
], ],
"user_id": "user_123" "user_id": "user_123",
}, },
{ {
"messages": [ "messages": [
{"role": "user", "content": "I prefer working late at night, I'm more productive then"}, {"role": "user", "content": "I prefer working late at night, I'm more productive then"},
{"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."} {
"role": "assistant",
"content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you.",
},
], ],
"user_id": "user_123" "user_id": "user_123",
} },
] ]
print("\n--> Adding memories using vLLM...") print("\n--> Adding memories using vLLM...")
# Add memories - now powered by vLLM's high-performance inference # Add memories - now powered by vLLM's high-performance inference
for i, conversation in enumerate(conversations, 1): for i, conversation in enumerate(conversations, 1):
result = memory.add( result = memory.add(messages=conversation["messages"], user_id=conversation["user_id"])
messages=conversation["messages"],
user_id=conversation["user_id"]
)
print(f"Memory {i} added: {result}") print(f"Memory {i} added: {result}")
print("\n🔍 Searching memories...") print("\n🔍 Searching memories...")
@@ -104,15 +102,12 @@ def main():
search_queries = [ search_queries = [
"What does the user like to do on weekends?", "What does the user like to do on weekends?",
"What is the user learning?", "What is the user learning?",
"When is the user most productive?" "When is the user most productive?",
] ]
for query in search_queries: for query in search_queries:
print(f"\nQuery: {query}") print(f"\nQuery: {query}")
memories = memory.search( memories = memory.search(query=query, user_id="user_123")
query=query,
user_id="user_123"
)
for memory_item in memories: for memory_item in memories:
print(f" - {memory_item['memory']}") print(f" - {memory_item['memory']}")

View File

@@ -89,9 +89,7 @@ class MemoryClient:
self.user_id = get_user_id() self.user_id = get_user_id()
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
"Mem0 API Key not provided. Please provide an API Key."
)
# Create MD5 hash of API key for user_id # Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -174,9 +172,7 @@ class MemoryClient:
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event( capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -195,9 +191,7 @@ class MemoryClient:
params = self._prepare_params() params = self._prepare_params()
response = self.client.get(f"/v1/memories/{memory_id}/", params=params) response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
"client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -224,13 +218,9 @@ class MemoryClient:
"page": params.pop("page"), "page": params.pop("page"),
"page_size": params.pop("page_size"), "page_size": params.pop("page_size"),
} }
response = self.client.post( response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
f"/{version}/memories/", json=params, params=query_params
)
else: else:
response = self.client.post( response = self.client.post(f"/{version}/memories/", json=params)
f"/{version}/memories/", json=params
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -246,9 +236,7 @@ class MemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
def search( def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
"""Search memories based on a query. """Search memories based on a query.
Args: Args:
@@ -266,9 +254,7 @@ class MemoryClient:
payload = {"query": query} payload = {"query": query}
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
payload.update(params) payload.update(params)
response = self.client.post( response = self.client.post(f"/{version}/memories/search/", json=payload)
f"/{version}/memories/search/", json=payload
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -308,13 +294,9 @@ class MemoryClient:
if metadata is not None: if metadata is not None:
payload["metadata"] = metadata payload["metadata"] = metadata
capture_client_event( capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
"client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
)
params = self._prepare_params() params = self._prepare_params()
response = self.client.put( response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
f"/v1/memories/{memory_id}/", json=payload, params=params
)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -332,13 +314,9 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
""" """
params = self._prepare_params() params = self._prepare_params()
response = self.client.delete( response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
f"/v1/memories/{memory_id}/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
"client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -379,13 +357,9 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
""" """
params = self._prepare_params() params = self._prepare_params()
response = self.client.get( response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
f"/v1/memories/{memory_id}/history/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
"client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -432,10 +406,7 @@ class MemoryClient:
else: else:
entities = self.users() entities = self.users()
# Filter entities based on provided IDs using list comprehension # Filter entities based on provided IDs using list comprehension
to_delete = [ to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
params = self._prepare_params() params = self._prepare_params()
@@ -444,9 +415,7 @@ class MemoryClient:
# Delete entities and check response immediately # Delete entities and check response immediately
for entity in to_delete: for entity in to_delete:
response = self.client.delete( response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event(
@@ -484,9 +453,7 @@ class MemoryClient:
self.delete_users() self.delete_users()
capture_client_event("client.reset", self, {"sync_type": "sync"}) capture_client_event("client.reset", self, {"sync_type": "sync"})
return { return {"message": "Client reset successful. All users and memories deleted."}
"message": "Client reset successful. All users and memories deleted."
}
@api_error_handler @api_error_handler
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]: def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -507,9 +474,7 @@ class MemoryClient:
response = self.client.put("/v1/batch/", json={"memories": memories}) response = self.client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.batch_update", self, {"sync_type": "sync"})
"client.batch_update", self, {"sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -527,14 +492,10 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
response = self.client.request( response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
"DELETE", "/v1/batch/", json={"memories": memories}
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
"client.batch_delete", self, {"sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -574,9 +535,7 @@ class MemoryClient:
Returns: Returns:
Dict containing the exported data Dict containing the exported data
""" """
response = self.client.post( response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
"/v1/exports/get/", json=self._prepare_params(kwargs)
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event(
"client.get_memory_export", "client.get_memory_export",
@@ -586,9 +545,7 @@ class MemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
def get_summary( def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self, filters: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Get the summary of a memory export. """Get the summary of a memory export.
Args: Args:
@@ -598,17 +555,13 @@ class MemoryClient:
Dict containing the export status and summary data Dict containing the export status and summary data
""" """
response = self.client.post( response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
"/v1/summary/", json=self._prepare_params({"filters": filters})
)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get_summary", self, {"sync_type": "sync"}) capture_client_event("client.get_summary", self, {"sync_type": "sync"})
return response.json() return response.json()
@api_error_handler @api_error_handler
def get_project( def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
self, fields: Optional[List[str]] = None
) -> Dict[str, Any]:
"""Get instructions or categories for the current project. """Get instructions or categories for the current project.
Args: Args:
@@ -622,10 +575,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError( raise ValueError("org_id and project_id must be set to access instructions or categories")
"org_id and project_id must be set to access instructions or "
"categories"
)
params = self._prepare_params({"fields": fields}) params = self._prepare_params({"fields": fields})
response = self.client.get( response = self.client.get(
@@ -666,10 +616,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError( raise ValueError("org_id and project_id must be set to update instructions or categories")
"org_id and project_id must be set to update instructions or "
"categories"
)
if ( if (
custom_instructions is None custom_instructions is None
@@ -826,10 +773,7 @@ class MemoryClient:
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError( raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
"or None"
)
data = { data = {
"memory_id": memory_id, "memory_id": memory_id,
@@ -839,14 +783,10 @@ class MemoryClient:
response = self.client.post("/v1/feedback/", json=data) response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
"client.feedback", self, data, {"sync_type": "sync"}
)
return response.json() return response.json()
def _prepare_payload( def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepare the payload for API requests. """Prepare the payload for API requests.
Args: Args:
@@ -862,9 +802,7 @@ class MemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None}) payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload return payload
def _prepare_params( def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Prepare query parameters for API requests. """Prepare query parameters for API requests.
Args: Args:
@@ -929,9 +867,7 @@ class AsyncMemoryClient:
self.user_id = get_user_id() self.user_id = get_user_id()
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
"Mem0 API Key not provided. Please provide an API Key."
)
# Create MD5 hash of API key for user_id # Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -989,9 +925,7 @@ class AsyncMemoryClient:
error_message = str(e) error_message = str(e)
raise ValueError(f"Error: {error_message}") raise ValueError(f"Error: {error_message}")
def _prepare_payload( def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepare the payload for API requests. """Prepare the payload for API requests.
Args: Args:
@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None}) payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload return payload
def _prepare_params( def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Prepare query parameters for API requests. """Prepare query parameters for API requests.
Args: Args:
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
await self.async_client.aclose() await self.async_client.aclose()
@api_error_handler @api_error_handler
async def add( async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
self, messages: List[Dict[str, str]], **kwargs
) -> Dict[str, Any]:
kwargs = self._prepare_params(kwargs) kwargs = self._prepare_params(kwargs)
if kwargs.get("output_format") != "v1.1": if kwargs.get("output_format") != "v1.1":
kwargs["output_format"] = "v1.1" kwargs["output_format"] = "v1.1"
@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event( capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]: async def get(self, memory_id: str) -> Dict[str, Any]:
params = self._prepare_params() params = self._prepare_params()
response = await self.async_client.get( response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
f"/v1/memories/{memory_id}/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
"client.get", self, {"memory_id": memory_id, "sync_type": "async"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
async def get_all( async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
self, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
if version == "v1": if version == "v1":
response = await self.async_client.get( response = await self.async_client.get(f"/{version}/memories/", params=params)
f"/{version}/memories/", params=params
)
elif version == "v2": elif version == "v2":
if "page" in params and "page_size" in params: if "page" in params and "page_size" in params:
query_params = { query_params = {
"page": params.pop("page"), "page": params.pop("page"),
"page_size": params.pop("page_size"), "page_size": params.pop("page_size"),
} }
response = await self.async_client.post( response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
f"/{version}/memories/", json=params, params=query_params
)
else: else:
response = await self.async_client.post( response = await self.async_client.post(f"/{version}/memories/", json=params)
f"/{version}/memories/", json=params
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
async def search( async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
payload = {"query": query} payload = {"query": query}
payload.update(self._prepare_params(kwargs)) payload.update(self._prepare_params(kwargs))
response = await self.async_client.post( response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
f"/{version}/memories/search/", json=payload
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: async def update(
self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
""" """
Update a memory by ID. Update a memory by ID.
Args: Args:
@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
else: else:
entities = await self.users() entities = await self.users()
# Filter entities based on provided IDs using list comprehension # Filter entities based on provided IDs using list comprehension
to_delete = [ to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
params = self._prepare_params() params = self._prepare_params()
@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:
# Delete entities and check response immediately # Delete entities and check response immediately
for entity in to_delete: for entity in to_delete:
response = await self.async_client.delete( response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event(
@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
response = await self.async_client.put("/v1/batch/", json={"memories": memories}) response = await self.async_client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.batch_update", self, {"sync_type": "async"})
"client.batch_update", self, {"sync_type": "async"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
response = await self.async_client.request( response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
"DELETE", "/v1/batch/", json={"memories": memories}
)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.batch_delete", self, {"sync_type": "async"})
"client.batch_delete", self, {"sync_type": "async"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None') raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason} data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}

View File

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any, Dict from typing import Any, Dict
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator

View File

@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors") embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017") mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
@model_validator(mode='before') @model_validator(mode="before")
@classmethod @classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys()) allowed_fields = set(cls.model_fields.keys())

View File

@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields extra_fields = input_fields - allowed_fields
if extra_fields: if extra_fields:
raise ValueError( raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}" f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
) )
return values return values

View File

@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
if response["output"]["message"]["content"]: if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]: for item in response["output"]["message"]["content"]:
if "toolUse" in item: if "toolUse" in item:
processed_response["tool_calls"].append({ processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"], "name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"], "arguments": item["toolUse"]["input"],
}) }
)
return processed_response return processed_response

View File

@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
if system_instruction: if system_instruction:
config_params["system_instruction"] = system_instruction config_params["system_instruction"] = system_instruction
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
config_params["response_mime_type"] = "application/json" config_params["response_mime_type"] = "application/json"
if "schema" in response_format: if "schema" in response_format:
@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
formatted_tools = self._reformat_tools(tools) formatted_tools = self._reformat_tools(tools)
config_params["tools"] = formatted_tools config_params["tools"] = formatted_tools
if tool_choice: if tool_choice:
if tool_choice == "auto": if tool_choice == "auto":
mode = types.FunctionCallingConfigMode.AUTO mode = types.FunctionCallingConfigMode.AUTO

View File

@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError(
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config." "Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
) )
# Set base URL - use config value or environment or default # Set base URL - use config value or environment or default

View File

@@ -7,7 +7,6 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json from mem0.memory.utils import extract_json
from openai import OpenAI
class VllmLLM(LLMBase): class VllmLLM(LLMBase):
@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):
if response.choices[0].message.tool_calls: if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls: for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({ processed_response["tool_calls"].append(
{
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(extract_json(tool_call.function.arguments)), "arguments": json.loads(extract_json(tool_call.function.arguments)),
}) }
)
return processed_response return processed_response
else: else:

View File

@@ -136,7 +136,6 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]} params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params) self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100): def get_all(self, filters, limit=100):
""" """
Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -176,7 +175,6 @@ class MemoryGraph:
return final_results return final_results
def _retrieve_nodes_from_data(self, data, filters): def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query.""" """Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL] _tools = [EXTRACT_ENTITIES_TOOL]
@@ -221,9 +219,7 @@ class MemoryGraph:
if self.config.graph_store.custom_prompt: if self.config.graph_store.custom_prompt:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity) system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
# Add the custom prompt line if configured # Add the custom prompt line if configured
system_content = system_content.replace( system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
)
messages = [ messages = [
{"role": "system", "content": system_content}, {"role": "system", "content": system_content},
{"role": "user", "content": data}, {"role": "user", "content": data},
@@ -592,7 +588,6 @@ class MemoryGraph:
result = self.graph.query(cypher, params=params) result = self.graph.query(cypher, params=params)
return result return result
def _search_destination_node(self, destination_embedding, filters, threshold=0.9): def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
agent_filter = "" agent_filter = ""
if filters.get("agent_id"): if filters.get("agent_id"):

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import List, Optional, Dict, Any, Callable from typing import List, Optional, Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector" VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine" SIMILARITY_METRIC = "cosine"
def __init__( def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
self,
db_name: str,
collection_name: str,
embedding_model_dims: int,
mongo_uri: str
):
""" """
Initialize the MongoDB vector store with vector search capabilities. Initialize the MongoDB vector store with vector search capabilities.
@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
self.embedding_model_dims = embedding_model_dims self.embedding_model_dims = embedding_model_dims
self.db_name = db_name self.db_name = db_name
self.client = MongoClient( self.client = MongoClient(mongo_uri)
mongo_uri
)
self.db = self.client[db_name] self.db = self.client[db_name]
self.collection = self.create_col() self.collection = self.create_col()
@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error inserting data: {e}") logger.error(f"Error inserting data: {e}")
def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]: def search(
self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
) -> List[OutputData]:
""" """
Search for similar vectors using the vector search index. Search for similar vectors using the vector search index.

5797
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -30,11 +30,13 @@ vector_stores = [
"vecs>=0.4.0", "vecs>=0.4.0",
"chromadb>=0.4.24", "chromadb>=0.4.24",
"weaviate-client>=4.4.0", "weaviate-client>=4.4.0",
"pinecone<7.0.0", "pinecone<=7.3.0",
"pinecone-text>=0.1.1", "pinecone-text>=0.10.0",
"faiss-cpu>=1.7.4", "faiss-cpu>=1.7.4",
"upstash-vector>=0.1.0", "upstash-vector>=0.1.0",
"azure-search-documents>=11.4.0b8", "azure-search-documents>=11.4.0b8",
"pymongo>=4.13.2",
"pymochow>=2.2.9",
] ]
llms = [ llms = [
"groq>=0.3.0", "groq>=0.3.0",
@@ -44,12 +46,11 @@ llms = [
"vertexai>=0.1.0", "vertexai>=0.1.0",
"google-generativeai>=0.3.0", "google-generativeai>=0.3.0",
"google-genai>=1.0.0", "google-genai>=1.0.0",
] ]
extras = [ extras = [
"boto3>=1.34.0", "boto3>=1.34.0",
"langchain-community>=0.0.0", "langchain-community>=0.0.0",
"sentence-transformers>=2.2.2", "sentence-transformers>=5.0.0",
"elasticsearch>=8.0.0", "elasticsearch>=8.0.0",
"opensearch-py>=2.0.0", "opensearch-py>=2.0.0",
"langchain-memgraph>=0.1.0", "langchain-memgraph>=0.1.0",

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch from unittest.mock import patch, ANY
import pytest import pytest
@@ -8,8 +8,10 @@ from mem0.embeddings.gemini import GoogleGenAIEmbedding
@pytest.fixture @pytest.fixture
def mock_genai(): def mock_genai():
with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai: with patch("mem0.embeddings.gemini.genai.Client") as mock_client_class:
yield mock_genai mock_client = mock_client_class.return_value
mock_client.models.embed_content.return_value = None
yield mock_client.models.embed_content
@pytest.fixture @pytest.fixture
@@ -18,7 +20,9 @@ def config():
def test_embed_query(mock_genai, config): def test_embed_query(mock_genai, config):
mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]} mock_embedding_response = type('Response', (), {
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
})()
mock_genai.return_value = mock_embedding_response mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)
@@ -27,10 +31,11 @@ def test_embed_query(mock_genai, config):
embedding = embedder.embed(text) embedding = embedder.embed(text)
assert embedding == [0.1, 0.2, 0.3, 0.4] assert embedding == [0.1, 0.2, 0.3, 0.4]
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786) mock_genai.assert_called_once_with(model="test_model", contents="Hello, world!", config=ANY)
def test_embed_returns_empty_list_if_none(mock_genai, config): def test_embed_returns_empty_list_if_none(mock_genai, config):
mock_genai.return_value = None mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)
result = embedder.embed("test") result = embedder.embed("test")
@@ -47,10 +52,10 @@ def test_embed_raises_on_error(mock_genai, config):
with pytest.raises(RuntimeError, match="Embedding failed"): with pytest.raises(RuntimeError, match="Embedding failed"):
embedder.embed("some input") embedder.embed("some input")
def test_config_initialization(config): def test_config_initialization(config):
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)
assert embedder.config.api_key == "dummy_api_key" assert embedder.config.api_key == "dummy_api_key"
assert embedder.config.model == "test_model" assert embedder.config.model == "test_model"
assert embedder.config.embedding_dims == 786 assert embedder.config.embedding_dims == 786

View File

@@ -9,7 +9,7 @@ from mem0.llms.gemini import GeminiLLM
@pytest.fixture @pytest.fixture
def mock_gemini_client(): def mock_gemini_client():
with patch("mem0.llms.gemini.genai") as mock_client_class: with patch("mem0.llms.gemini.genai.Client") as mock_client_class:
mock_client = Mock() mock_client = Mock()
mock_client_class.return_value = mock_client mock_client_class.return_value = mock_client
yield mock_client yield mock_client
@@ -24,45 +24,32 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
] ]
mock_part = Mock(text="I'm doing well, thank you for asking!") mock_part = Mock(text="I'm doing well, thank you for asking!")
mock_embedding = Mock() mock_content = Mock(parts=[mock_part])
mock_embedding.values = [0.1, 0.2, 0.3] mock_candidate = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_candidate])
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock()]
mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
mock_gemini_client.models.generate_content.return_value = mock_response mock_gemini_client.models.generate_content.return_value = mock_response
mock_content = Mock(parts=[mock_part])
mock_message = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
response = llm.generate_response(messages) response = llm.generate_response(messages)
mock_gemini_client.generate_content.assert_called_once_with( # Check the actual call - system instruction is now in config
contents=[ mock_gemini_client.models.generate_content.assert_called_once()
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, call_args = mock_gemini_client.models.generate_content.call_args
{"parts": "Hello, how are you?", "role": "user"},
],
config=types.GenerateContentConfig(
temperature=0.7,
max_output_tokens=100,
top_p=1.0,
tools=None,
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
) # Verify model and contents
) assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
) ) assert len(call_args.kwargs['contents']) == 1 # Only user message
# Verify config has system instruction
config_arg = call_args.kwargs['config']
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
assert config_arg.top_p == 1.0
assert response == "I'm doing well, thank you for asking!" assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock): def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config) llm = GeminiLLM(config)
@@ -89,64 +76,42 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
mock_tool_call.name = "add_memory" mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."} mock_tool_call.args = {"data": "Today is a sunny day."}
mock_part = Mock() # Create mock parts with both text and function_call
mock_part.function_call = mock_tool_call mock_text_part = Mock()
mock_part.text = "I've added the memory for you." mock_text_part.text = "I've added the memory for you."
mock_text_part.function_call = None
mock_func_part = Mock()
mock_func_part.text = None
mock_func_part.function_call = mock_tool_call
mock_content = Mock() mock_content = Mock()
mock_content.parts = [mock_part] mock_content.parts = [mock_text_part, mock_func_part]
mock_message = Mock() mock_candidate = Mock()
mock_message.content = mock_content mock_candidate.content = mock_content
mock_response = Mock(candidates=[mock_message]) mock_response = Mock(candidates=[mock_candidate])
mock_gemini_client.generate_content.return_value = mock_response mock_gemini_client.models.generate_content.return_value = mock_response
response = llm.generate_response(messages, tools=tools) response = llm.generate_response(messages, tools=tools)
mock_gemini_client.generate_content.assert_called_once_with( # Check the actual call
contents=[ mock_gemini_client.models.generate_content.assert_called_once()
{ call_args = mock_gemini_client.models.generate_content.call_args
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user" # Verify model and contents
}, assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
{ assert len(call_args.kwargs['contents']) == 1 # Only user message
"parts": "Add a new memory: Today is a sunny day.",
"role": "user" # Verify config has system instruction and tools
}, config_arg = call_args.kwargs['config']
], assert config_arg.system_instruction == "You are a helpful assistant."
config=types.GenerateContentConfig( assert config_arg.temperature == 0.7
temperature=0.7, assert config_arg.max_output_tokens == 100
max_output_tokens=100, assert config_arg.top_p == 1.0
top_p=1.0, assert len(config_arg.tools) == 1
tools=[ assert config_arg.tool_config.function_calling_config.mode == types.FunctionCallingConfigMode.AUTO
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="add_memory",
description="Add a memory",
parameters={
"type": "object",
"properties": {
"data": {
"type": "string",
"description": "Data to add to memory"
}
},
"required": ["data"]
}
)
]
)
],
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
)
)
)
)
assert response["content"] == "I've added the memory for you." assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1 assert len(response["tool_calls"]) == 1

View File

@@ -43,6 +43,7 @@ def test_generate_response_without_tools(mock_lm_studio_client):
assert response == "I'm doing well, thank you for asking!" assert response == "I'm doing well, thank you for asking!"
def test_generate_response_specifying_response_format(mock_lm_studio_client): def test_generate_response_specifying_response_format(mock_lm_studio_client):
config = BaseLlmConfig( config = BaseLlmConfig(
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",

View File

@@ -71,7 +71,13 @@ def test_generate_response_with_tools(mock_vllm_client):
response = llm.generate_response(messages, tools=tools) response = llm.generate_response(messages, tools=tools)
mock_vllm_client.chat.completions.create.assert_called_once_with( mock_vllm_client.chat.completions.create.assert_called_once_with(
model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto" model="Qwen/Qwen2.5-32B-Instruct",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
) )
assert response["content"] == "I've added the memory for you." assert response["content"] == "I've added the memory for you."

View File

@@ -253,11 +253,11 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
def test_custom_prompts(memory_custom_instance): def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}] messages = [{"role": "user", "content": "Test message"}]
from mem0.embeddings.mock import MockEmbeddings from mem0.embeddings.mock import MockEmbeddings
memory_custom_instance.llm.generate_response = Mock() memory_custom_instance.llm.generate_response = Mock()
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}' memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
memory_custom_instance.embedding_model = MockEmbeddings() memory_custom_instance.embedding_model = MockEmbeddings()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages: with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch( with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt" "mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"

View File

@@ -2,10 +2,9 @@ from unittest.mock import Mock, patch, PropertyMock
import pytest import pytest
from mem0.vector_stores.baidu import BaiduDB, OutputData from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import MetricType, TableState, ServerErrCode from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError from pymochow.exception import ServerError

View File

@@ -1,7 +1,7 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoDB from mem0.vector_stores.mongodb import MongoDB
from pymongo.operations import SearchIndexModel
@pytest.fixture @pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient") @patch("mem0.vector_stores.mongodb.MongoClient")
@@ -19,10 +19,11 @@ def mongo_vector_fixture(mock_mongo_client):
db_name="test_db", db_name="test_db",
collection_name="test_collection", collection_name="test_collection",
embedding_model_dims=1536, embedding_model_dims=1536,
mongo_uri="mongodb://username:password@localhost:27017" mongo_uri="mongodb://username:password@localhost:27017",
) )
return mongo_vector, mock_collection, mock_db return mongo_vector, mock_collection, mock_db
def test_initalize_create_col(mongo_vector_fixture): def test_initalize_create_col(mongo_vector_fixture):
mongo_vector, mock_collection, mock_db = mongo_vector_fixture mongo_vector, mock_collection, mock_db = mongo_vector_fixture
assert mongo_vector.collection_name == "test_collection" assert mongo_vector.collection_name == "test_collection"
@@ -49,12 +50,13 @@ def test_initalize_create_col(mongo_vector_fixture):
"dimensions": 1536, "dimensions": 1536,
"similarity": "cosine", "similarity": "cosine",
} }
},
} }
} },
}
} }
assert mongo_vector.collection == mock_collection assert mongo_vector.collection == mock_collection
def test_insert(mongo_vector_fixture): def test_insert(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
vectors = [[0.1] * 1536, [0.2] * 1536] vectors = [[0.1] * 1536, [0.2] * 1536]
@@ -64,10 +66,11 @@ def test_insert(mongo_vector_fixture):
mongo_vector.insert(vectors, payloads, ids) mongo_vector.insert(vectors, payloads, ids)
expected_records = [ expected_records = [
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}), ({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}) ({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}),
] ]
mock_collection.insert_many.assert_called_once_with(expected_records) mock_collection.insert_many.assert_called_once_with(expected_records)
def test_search(mongo_vector_fixture): def test_search(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
query_vector = [0.1] * 1536 query_vector = [0.1] * 1536
@@ -79,7 +82,8 @@ def test_search(mongo_vector_fixture):
results = mongo_vector.search("query_str", query_vector, limit=2) results = mongo_vector.search("query_str", query_vector, limit=2)
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index") mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
mock_collection.aggregate.assert_called_once_with([ mock_collection.aggregate.assert_called_once_with(
[
{ {
"$vectorSearch": { "$vectorSearch": {
"index": "test_collection_vector_index", "index": "test_collection_vector_index",
@@ -91,13 +95,15 @@ def test_search(mongo_vector_fixture):
}, },
{"$set": {"score": {"$meta": "vectorSearchScore"}}}, {"$set": {"score": {"$meta": "vectorSearchScore"}}},
{"$project": {"embedding": 0}}, {"$project": {"embedding": 0}},
]) ]
)
assert len(results) == 2 assert len(results) == 2
assert results[0].id == "id1" assert results[0].id == "id1"
assert results[0].score == 0.9 assert results[0].score == 0.9
assert results[1].id == "id2" assert results[1].id == "id2"
assert results[1].score == 0.8 assert results[1].score == 0.8
def test_delete(mongo_vector_fixture): def test_delete(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_delete_result = MagicMock() mock_delete_result = MagicMock()
@@ -107,6 +113,7 @@ def test_delete(mongo_vector_fixture):
mongo_vector.delete("id1") mongo_vector.delete("id1")
mock_collection.delete_one.assert_called_with({"_id": "id1"}) mock_collection.delete_one.assert_called_with({"_id": "id1"})
def test_update(mongo_vector_fixture): def test_update(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_update_result = MagicMock() mock_update_result = MagicMock()
@@ -122,6 +129,7 @@ def test_update(mongo_vector_fixture):
{"$set": {"embedding": vectorValue, "payload": payloadValue}}, {"$set": {"embedding": vectorValue, "payload": payloadValue}},
) )
def test_get(mongo_vector_fixture): def test_get(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}} mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}
@@ -131,6 +139,7 @@ def test_get(mongo_vector_fixture):
assert result.id == "id1" assert result.id == "id1"
assert result.payload == {"key": "value1"} assert result.payload == {"key": "value1"}
def test_list_cols(mongo_vector_fixture): def test_list_cols(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.list_collection_names.return_value = ["col1", "col2"] mock_db.list_collection_names.return_value = ["col1", "col2"]
@@ -138,12 +147,14 @@ def test_list_cols(mongo_vector_fixture):
collections = mongo_vector.list_cols() collections = mongo_vector.list_cols()
assert collections == ["col1", "col2"] assert collections == ["col1", "col2"]
def test_delete_col(mongo_vector_fixture): def test_delete_col(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
mongo_vector.delete_col() mongo_vector.delete_col()
mock_collection.drop.assert_called_once() mock_collection.drop.assert_called_once()
def test_col_info(mongo_vector_fixture): def test_col_info(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.command.return_value = {"count": 10, "size": 1024} mock_db.command.return_value = {"count": 10, "size": 1024}
@@ -154,6 +165,7 @@ def test_col_info(mongo_vector_fixture):
assert info["count"] == 10 assert info["count"] == 10
assert info["size"] == 1024 assert info["size"] == 1024
def test_list(mongo_vector_fixture): def test_list(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_cursor = MagicMock() mock_cursor = MagicMock()

View File

@@ -88,7 +88,7 @@ def test_get_vector_found(pinecone_db):
# or a list of dictionaries, not a dictionary with an 'id' field # or a list of dictionaries, not a dictionary with an 'id' field
# Create a mock Vector object # Create a mock Vector object
from pinecone.data.dataclasses.vector import Vector from pinecone import Vector
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"}) mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})