Fix CI issues related to missing dependency (#3096)

Deshraj Yadav
2025-07-03 18:52:50 -07:00
committed by GitHub
parent 2c496e6376
commit 7484eed4b2
32 changed files with 6150 additions and 828 deletions
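The missing dependency named in the title is pymongo, which the MongoDB vector store imported by the test suite requires; the diffs below add it to the Makefile install_all target and to the vector_stores extra in pyproject.toml. As a minimal sketch (not part of this commit), the store that pulls in that dependency is constructed roughly as follows, using the signature shown later in this diff; the database and collection names are hypothetical and a reachable MongoDB server is assumed:

# Hedged sketch: instantiating the MongoDB vector store that needs pymongo installed.
from mem0.vector_stores.mongodb import MongoDB  # importing this module requires pymongo

store = MongoDB(
    db_name="mem0",                         # hypothetical database name
    collection_name="memories",             # hypothetical collection name
    embedding_model_dims=1536,              # default dimension from MongoDBConfig in this diff
    mongo_uri="mongodb://localhost:27017",  # assumes a local MongoDB instance
)

Without pymongo installed, importing this module fails during test collection, which is the CI breakage this commit addresses.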

View File

@@ -58,11 +58,10 @@ jobs:
run: |
pip install --upgrade pip
pip install -e ".[test,graph,vector_stores,llms,extras]"
pip install ruff
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
- name: Run Formatting
run: |
mkdir -p .ruff_cache && chmod -R 777 .ruff_cache
hatch run format
- name: Run Linting
run: make lint
- name: Run tests and generate coverage report
run: make test

View File

@@ -13,7 +13,7 @@ install:
install_all:
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo
# Format code with ruff
format:

View File

@@ -110,7 +110,7 @@ def main():
print("All categories accuracy:")
for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
print("------------------------------------------")
index += 1

View File

@@ -68,7 +68,7 @@ class RAGManager:
def clean_chat_history(self, chat_history):
cleaned_chat_history = ""
for c in chat_history:
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"
return cleaned_chat_history

View File

@@ -68,9 +68,7 @@
"\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n",
" \"\"\n",
")"
"os.environ[\"OPENAI_API_KEY\"] = \"\""
]
},
{
@@ -164,7 +162,7 @@
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]\n"
"]"
]
},
{
@@ -185,9 +183,7 @@
"outputs": [],
"source": [
"# Store inferred memories (default behavior)\n",
"result = m.add(\n",
" messages, user_id=\"alice\"\n",
")"
"result = m.add(messages, user_id=\"alice\")"
]
},
{

View File

@@ -45,11 +45,7 @@ def get_food_recommendation(user_query: str, user_id):
"""Get food recommendation with memory context"""
# Search memory for relevant food preferences
memories_result = memory_client.search(
query=user_query,
user_id=user_id,
limit=5
)
memories_result = memory_client.search(query=user_query, user_id=user_id, limit=5)
# Add memory context to the message
memories = [f"- {result['memory']}" for result in memories_result]
@@ -71,6 +67,7 @@ def get_food_recommendation(user_query: str, user_id):
# Save audio file
if response.audio:
import time
timestamp = int(time.time())
filename = f"food_recommendation_{timestamp}.mp3"
write_audio_to_file(
@@ -118,7 +115,11 @@ def initialize_food_memory(user_id):
# Initialize the memory for the user once in order for the agent to learn the user preference
initialize_food_memory(user_id=USER_ID)
print(get_food_recommendation("Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID))
print(
get_food_recommendation(
"Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID
)
)
# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
# For dinner tonight, considering your love for healthy spicy options, you could try a nice Thai, Indian, or Mexican restaurant.
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!

View File

@@ -1,4 +1,4 @@
from agents import Agent, Runner, function_tool, handoffs, enable_verbose_stdout_logging
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from dotenv import load_dotenv
from mem0 import MemoryClient
@@ -35,7 +35,7 @@ travel_agent = Agent(
understand the user's travel preferences and history before making recommendations.
After providing your response, use store_conversation to save important details.""",
tools=[search_memory, save_memory],
model="gpt-4o"
model="gpt-4o",
)
health_agent = Agent(
@@ -44,7 +44,7 @@ health_agent = Agent(
understand the user's health goals and dietary preferences.
After providing advice, use store_conversation to save relevant information.""",
tools=[search_memory, save_memory],
model="gpt-4o"
model="gpt-4o",
)
# Triage agent with handoffs
@@ -55,7 +55,7 @@ triage_agent = Agent(
For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
For general questions, you can handle them directly using available tools.""",
handoffs=[travel_agent, health_agent],
model="gpt-4o"
model="gpt-4o",
)
@@ -74,10 +74,7 @@ def chat_with_handoffs(user_input: str, user_id: str) -> str:
result = Runner.run_sync(triage_agent, user_input)
# Store the original conversation in memory
conversation = [
{"role": "user", "content": user_input},
{"role": "assistant", "content": result.final_output}
]
conversation = [{"role": "user", "content": user_input}, {"role": "assistant", "content": result.final_output}]
mem0.add(conversation, user_id=user_id)
return result.final_output

View File

@@ -34,24 +34,16 @@ config = {
"api_key": "vllm-api-key",
"temperature": 0.7,
"max_tokens": 100,
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small"
}
},
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": "vllm_memories",
"host": "localhost",
"port": 6333
}
}
"config": {"collection_name": "vllm_memories", "host": "localhost", "port": 6333},
},
}
def main():
"""
Demonstrate vLLM integration with mem0
@@ -68,34 +60,40 @@ def main():
{
"messages": [
{"role": "user", "content": "I love playing chess on weekends"},
{"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."}
{
"role": "assistant",
"content": "That's great! Chess is an excellent strategic game that helps improve critical thinking.",
},
],
"user_id": "user_123"
"user_id": "user_123",
},
{
"messages": [
{"role": "user", "content": "I'm learning Python programming"},
{"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"}
{
"role": "assistant",
"content": "Python is a fantastic language for beginners! What specific areas are you focusing on?",
},
],
"user_id": "user_123"
"user_id": "user_123",
},
{
"messages": [
{"role": "user", "content": "I prefer working late at night, I'm more productive then"},
{"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."}
{
"role": "assistant",
"content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you.",
},
],
"user_id": "user_123"
}
"user_id": "user_123",
},
]
print("\n--> Adding memories using vLLM...")
# Add memories - now powered by vLLM's high-performance inference
for i, conversation in enumerate(conversations, 1):
result = memory.add(
messages=conversation["messages"],
user_id=conversation["user_id"]
)
result = memory.add(messages=conversation["messages"], user_id=conversation["user_id"])
print(f"Memory {i} added: {result}")
print("\n🔍 Searching memories...")
@@ -104,15 +102,12 @@ def main():
search_queries = [
"What does the user like to do on weekends?",
"What is the user learning?",
"When is the user most productive?"
"When is the user most productive?",
]
for query in search_queries:
print(f"\nQuery: {query}")
memories = memory.search(
query=query,
user_id="user_123"
)
memories = memory.search(query=query, user_id="user_123")
for memory_item in memories:
print(f" - {memory_item['memory']}")

View File

@@ -89,9 +89,7 @@ class MemoryClient:
self.user_id = get_user_id()
if not self.api_key:
raise ValueError(
"Mem0 API Key not provided. Please provide an API Key."
)
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -174,9 +172,7 @@ class MemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
)
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
return response.json()
@api_error_handler
@@ -195,9 +191,7 @@ class MemoryClient:
params = self._prepare_params()
response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()
@api_error_handler
@@ -224,13 +218,9 @@ class MemoryClient:
"page": params.pop("page"),
"page_size": params.pop("page_size"),
}
response = self.client.post(
f"/{version}/memories/", json=params, params=query_params
)
response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
else:
response = self.client.post(
f"/{version}/memories/", json=params
)
response = self.client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
@@ -246,9 +236,7 @@ class MemoryClient:
return response.json()
@api_error_handler
def search(
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
"""Search memories based on a query.
Args:
@@ -266,9 +254,7 @@ class MemoryClient:
payload = {"query": query}
params = self._prepare_params(kwargs)
payload.update(params)
response = self.client.post(
f"/{version}/memories/search/", json=payload
)
response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
@@ -308,13 +294,9 @@ class MemoryClient:
if metadata is not None:
payload["metadata"] = metadata
capture_client_event(
"client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
params = self._prepare_params()
response = self.client.put(
f"/v1/memories/{memory_id}/", json=payload, params=params
)
response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
response.raise_for_status()
return response.json()
@@ -332,13 +314,9 @@ class MemoryClient:
APIError: If the API request fails.
"""
params = self._prepare_params()
response = self.client.delete(
f"/v1/memories/{memory_id}/", params=params
)
response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()
@api_error_handler
@@ -379,13 +357,9 @@ class MemoryClient:
APIError: If the API request fails.
"""
params = self._prepare_params()
response = self.client.get(
f"/v1/memories/{memory_id}/history/", params=params
)
response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
response.raise_for_status()
capture_client_event(
"client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()
@api_error_handler
@@ -432,10 +406,7 @@ class MemoryClient:
else:
entities = self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params()
@@ -444,9 +415,7 @@ class MemoryClient:
# Delete entities and check response immediately
for entity in to_delete:
response = self.client.delete(
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()
capture_client_event(
@@ -484,9 +453,7 @@ class MemoryClient:
self.delete_users()
capture_client_event("client.reset", self, {"sync_type": "sync"})
return {
"message": "Client reset successful. All users and memories deleted."
}
return {"message": "Client reset successful. All users and memories deleted."}
@api_error_handler
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -507,9 +474,7 @@ class MemoryClient:
response = self.client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status()
capture_client_event(
"client.batch_update", self, {"sync_type": "sync"}
)
capture_client_event("client.batch_update", self, {"sync_type": "sync"})
return response.json()
@api_error_handler
@@ -527,14 +492,10 @@ class MemoryClient:
Raises:
APIError: If the API request fails.
"""
response = self.client.request(
"DELETE", "/v1/batch/", json={"memories": memories}
)
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
response.raise_for_status()
capture_client_event(
"client.batch_delete", self, {"sync_type": "sync"}
)
capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
return response.json()
@api_error_handler
@@ -574,9 +535,7 @@ class MemoryClient:
Returns:
Dict containing the exported data
"""
response = self.client.post(
"/v1/exports/get/", json=self._prepare_params(kwargs)
)
response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
response.raise_for_status()
capture_client_event(
"client.get_memory_export",
@@ -586,9 +545,7 @@ class MemoryClient:
return response.json()
@api_error_handler
def get_summary(
self, filters: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Get the summary of a memory export.
Args:
@@ -598,17 +555,13 @@ class MemoryClient:
Dict containing the export status and summary data
"""
response = self.client.post(
"/v1/summary/", json=self._prepare_params({"filters": filters})
)
response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
response.raise_for_status()
capture_client_event("client.get_summary", self, {"sync_type": "sync"})
return response.json()
@api_error_handler
def get_project(
self, fields: Optional[List[str]] = None
) -> Dict[str, Any]:
def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
"""Get instructions or categories for the current project.
Args:
@@ -622,10 +575,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set.
"""
if not (self.org_id and self.project_id):
raise ValueError(
"org_id and project_id must be set to access instructions or "
"categories"
)
raise ValueError("org_id and project_id must be set to access instructions or categories")
params = self._prepare_params({"fields": fields})
response = self.client.get(
@@ -666,10 +616,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set.
"""
if not (self.org_id and self.project_id):
raise ValueError(
"org_id and project_id must be set to update instructions or "
"categories"
)
raise ValueError("org_id and project_id must be set to update instructions or categories")
if (
custom_instructions is None
@@ -826,10 +773,7 @@ class MemoryClient:
feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(
f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
"or None"
)
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
data = {
"memory_id": memory_id,
@@ -839,14 +783,10 @@ class MemoryClient:
response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status()
capture_client_event(
"client.feedback", self, data, {"sync_type": "sync"}
)
capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
return response.json()
def _prepare_payload(
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the payload for API requests.
Args:
@@ -862,9 +802,7 @@ class MemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload
def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
@@ -929,9 +867,7 @@ class AsyncMemoryClient:
self.user_id = get_user_id()
if not self.api_key:
raise ValueError(
"Mem0 API Key not provided. Please provide an API Key."
)
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -989,9 +925,7 @@ class AsyncMemoryClient:
error_message = str(e)
raise ValueError(f"Error: {error_message}")
def _prepare_payload(
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the payload for API requests.
Args:
@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload
def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
await self.async_client.aclose()
@api_error_handler
async def add(
self, messages: List[Dict[str, str]], **kwargs
) -> Dict[str, Any]:
async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
kwargs = self._prepare_params(kwargs)
if kwargs.get("output_format") != "v1.1":
kwargs["output_format"] = "v1.1"
@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
)
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
return response.json()
@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
params = self._prepare_params()
response = await self.async_client.get(
f"/v1/memories/{memory_id}/", params=params
)
response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.get", self, {"memory_id": memory_id, "sync_type": "async"}
)
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json()
@api_error_handler
async def get_all(
self, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self._prepare_params(kwargs)
if version == "v1":
response = await self.async_client.get(
f"/{version}/memories/", params=params
)
response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2":
if "page" in params and "page_size" in params:
query_params = {
"page": params.pop("page"),
"page_size": params.pop("page_size"),
}
response = await self.async_client.post(
f"/{version}/memories/", json=params, params=query_params
)
response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
else:
response = await self.async_client.post(
f"/{version}/memories/", json=params
)
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
return response.json()
@api_error_handler
async def search(
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query}
payload.update(self._prepare_params(kwargs))
response = await self.async_client.post(
f"/{version}/memories/search/", json=payload
)
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
return response.json()
@api_error_handler
async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def update(
self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Update a memory by ID.
Args:
@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
else:
entities = await self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params()
@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:
# Delete entities and check response immediately
for entity in to_delete:
response = await self.async_client.delete(
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()
capture_client_event(
@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
response = await self.async_client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status()
capture_client_event(
"client.batch_update", self, {"sync_type": "async"}
)
capture_client_event("client.batch_update", self, {"sync_type": "async"})
return response.json()
@api_error_handler
@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
Raises:
APIError: If the API request fails.
"""
response = await self.async_client.request(
"DELETE", "/v1/batch/", json={"memories": memories}
)
response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
response.raise_for_status()
capture_client_event(
"client.batch_delete", self, {"sync_type": "async"}
)
capture_client_event("client.batch_delete", self, {"sync_type": "async"})
return response.json()
@api_error_handler
@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:
feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}

View File

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any, Dict
from pydantic import BaseModel, Field, model_validator

View File

@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())

View File

@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
)
return values

View File

@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]:
if "toolUse" in item:
processed_response["tool_calls"].append({
processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"],
})
}
)
return processed_response

View File

@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
if system_instruction:
config_params["system_instruction"] = system_instruction
if response_format is not None and response_format["type"] == "json_object":
config_params["response_mime_type"] = "application/json"
if "schema" in response_format:
@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
formatted_tools = self._reformat_tools(tools)
config_params["tools"] = formatted_tools
if tool_choice:
if tool_choice == "auto":
mode = types.FunctionCallingConfigMode.AUTO

View File

@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):
if not self.api_key:
raise ValueError(
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
"Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
)
# Set base URL - use config value or environment or default

View File

@@ -7,7 +7,6 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
from openai import OpenAI
class VllmLLM(LLMBase):
@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(extract_json(tool_call.function.arguments)),
})
}
)
return processed_response
else:

View File

@@ -136,7 +136,6 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -176,7 +175,6 @@ class MemoryGraph:
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
@@ -221,9 +219,7 @@ class MemoryGraph:
if self.config.graph_store.custom_prompt:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
# Add the custom prompt line if configured
system_content = system_content.replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
)
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": data},
@@ -592,7 +588,6 @@ class MemoryGraph:
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
agent_filter = ""
if filters.get("agent_id"):

View File

@@ -1,5 +1,5 @@
import logging
from typing import List, Optional, Dict, Any, Callable
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine"
def __init__(
self,
db_name: str,
collection_name: str,
embedding_model_dims: int,
mongo_uri: str
):
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
"""
Initialize the MongoDB vector store with vector search capabilities.
@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
self.embedding_model_dims = embedding_model_dims
self.db_name = db_name
self.client = MongoClient(
mongo_uri
)
self.client = MongoClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.create_col()
@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
except PyMongoError as e:
logger.error(f"Error inserting data: {e}")
def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
def search(
self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors using the vector search index.

poetry.lock (generated, 5797 lines changed)

File diff suppressed because it is too large

View File

@@ -30,11 +30,13 @@ vector_stores = [
"vecs>=0.4.0",
"chromadb>=0.4.24",
"weaviate-client>=4.4.0",
"pinecone<7.0.0",
"pinecone-text>=0.1.1",
"pinecone<=7.3.0",
"pinecone-text>=0.10.0",
"faiss-cpu>=1.7.4",
"upstash-vector>=0.1.0",
"azure-search-documents>=11.4.0b8",
"pymongo>=4.13.2",
"pymochow>=2.2.9",
]
llms = [
"groq>=0.3.0",
@@ -44,12 +46,11 @@ llms = [
"vertexai>=0.1.0",
"google-generativeai>=0.3.0",
"google-genai>=1.0.0",
]
extras = [
"boto3>=1.34.0",
"langchain-community>=0.0.0",
"sentence-transformers>=2.2.2",
"sentence-transformers>=5.0.0",
"elasticsearch>=8.0.0",
"opensearch-py>=2.0.0",
"langchain-memgraph>=0.1.0",

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import patch, ANY
import pytest
@@ -8,8 +8,10 @@ from mem0.embeddings.gemini import GoogleGenAIEmbedding
@pytest.fixture
def mock_genai():
with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
yield mock_genai
with patch("mem0.embeddings.gemini.genai.Client") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.models.embed_content.return_value = None
yield mock_client.models.embed_content
@pytest.fixture
@@ -18,7 +20,9 @@ def config():
def test_embed_query(mock_genai, config):
mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
mock_embedding_response = type('Response', (), {
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
})()
mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config)
@@ -27,10 +31,11 @@ def test_embed_query(mock_genai, config):
embedding = embedder.embed(text)
assert embedding == [0.1, 0.2, 0.3, 0.4]
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
mock_genai.assert_called_once_with(model="test_model", contents="Hello, world!", config=ANY)
def test_embed_returns_empty_list_if_none(mock_genai, config):
mock_genai.return_value = None
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
embedder = GoogleGenAIEmbedding(config)
result = embedder.embed("test")
@@ -47,10 +52,10 @@ def test_embed_raises_on_error(mock_genai, config):
with pytest.raises(RuntimeError, match="Embedding failed"):
embedder.embed("some input")
def test_config_initialization(config):
embedder = GoogleGenAIEmbedding(config)
assert embedder.config.api_key == "dummy_api_key"
assert embedder.config.model == "test_model"
assert embedder.config.embedding_dims == 786

View File

@@ -9,7 +9,7 @@ from mem0.llms.gemini import GeminiLLM
@pytest.fixture
def mock_gemini_client():
with patch("mem0.llms.gemini.genai") as mock_client_class:
with patch("mem0.llms.gemini.genai.Client") as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
yield mock_client
@@ -24,45 +24,32 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
]
mock_part = Mock(text="I'm doing well, thank you for asking!")
mock_embedding = Mock()
mock_embedding.values = [0.1, 0.2, 0.3]
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock()]
mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
mock_content = Mock(parts=[mock_part])
mock_candidate = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_candidate])
mock_gemini_client.models.generate_content.return_value = mock_response
mock_content = Mock(parts=[mock_part])
mock_message = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
response = llm.generate_response(messages)
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Hello, how are you?", "role": "user"},
],
config=types.GenerateContentConfig(
temperature=0.7,
max_output_tokens=100,
top_p=1.0,
tools=None,
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
# Check the actual call - system instruction is now in config
mock_gemini_client.models.generate_content.assert_called_once()
call_args = mock_gemini_client.models.generate_content.call_args
)
)
)
)
# Verify model and contents
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message
# Verify config has system instruction
config_arg = call_args.kwargs['config']
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
assert config_arg.top_p == 1.0
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)
@@ -89,64 +76,42 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}
mock_part = Mock()
mock_part.function_call = mock_tool_call
mock_part.text = "I've added the memory for you."
# Create mock parts with both text and function_call
mock_text_part = Mock()
mock_text_part.text = "I've added the memory for you."
mock_text_part.function_call = None
mock_func_part = Mock()
mock_func_part.text = None
mock_func_part.function_call = mock_tool_call
mock_content = Mock()
mock_content.parts = [mock_part]
mock_content.parts = [mock_text_part, mock_func_part]
mock_message = Mock()
mock_message.content = mock_content
mock_candidate = Mock()
mock_candidate.content = mock_content
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
mock_response = Mock(candidates=[mock_candidate])
mock_gemini_client.models.generate_content.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user"
},
{
"parts": "Add a new memory: Today is a sunny day.",
"role": "user"
},
],
config=types.GenerateContentConfig(
temperature=0.7,
max_output_tokens=100,
top_p=1.0,
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="add_memory",
description="Add a memory",
parameters={
"type": "object",
"properties": {
"data": {
"type": "string",
"description": "Data to add to memory"
}
},
"required": ["data"]
}
)
]
)
],
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
)
)
)
)
# Check the actual call
mock_gemini_client.models.generate_content.assert_called_once()
call_args = mock_gemini_client.models.generate_content.call_args
# Verify model and contents
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message
# Verify config has system instruction and tools
config_arg = call_args.kwargs['config']
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
assert config_arg.top_p == 1.0
assert len(config_arg.tools) == 1
assert config_arg.tool_config.function_calling_config.mode == types.FunctionCallingConfigMode.AUTO
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1

View File

@@ -43,6 +43,7 @@ def test_generate_response_without_tools(mock_lm_studio_client):
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_specifying_response_format(mock_lm_studio_client):
config = BaseLlmConfig(
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",

View File

@@ -71,7 +71,13 @@ def test_generate_response_with_tools(mock_vllm_client):
response = llm.generate_response(messages, tools=tools)
mock_vllm_client.chat.completions.create.assert_called_once_with(
model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
model="Qwen/Qwen2.5-32B-Instruct",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."

View File

@@ -253,11 +253,11 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
from mem0.embeddings.mock import MockEmbeddings
memory_custom_instance.llm.generate_response = Mock()
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
memory_custom_instance.embedding_model = MockEmbeddings()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"

View File

@@ -2,10 +2,9 @@ from unittest.mock import Mock, patch, PropertyMock
import pytest
from mem0.vector_stores.baidu import BaiduDB, OutputData
from pymochow.model.enum import MetricType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError

View File

@@ -1,7 +1,7 @@
import pytest
from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoDB
from pymongo.operations import SearchIndexModel
@pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient")
@@ -19,10 +19,11 @@ def mongo_vector_fixture(mock_mongo_client):
db_name="test_db",
collection_name="test_collection",
embedding_model_dims=1536,
mongo_uri="mongodb://username:password@localhost:27017"
mongo_uri="mongodb://username:password@localhost:27017",
)
return mongo_vector, mock_collection, mock_db
def test_initalize_create_col(mongo_vector_fixture):
mongo_vector, mock_collection, mock_db = mongo_vector_fixture
assert mongo_vector.collection_name == "test_collection"
@@ -49,12 +50,13 @@ def test_initalize_create_col(mongo_vector_fixture):
"dimensions": 1536,
"similarity": "cosine",
}
},
}
}
}
},
}
assert mongo_vector.collection == mock_collection
def test_insert(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
vectors = [[0.1] * 1536, [0.2] * 1536]
@@ -62,12 +64,13 @@ def test_insert(mongo_vector_fixture):
ids = ["id1", "id2"]
mongo_vector.insert(vectors, payloads, ids)
expected_records=[
expected_records = [
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]})
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}),
]
mock_collection.insert_many.assert_called_once_with(expected_records)
def test_search(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
query_vector = [0.1] * 1536
@@ -79,7 +82,8 @@ def test_search(mongo_vector_fixture):
results = mongo_vector.search("query_str", query_vector, limit=2)
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
mock_collection.aggregate.assert_called_once_with([
mock_collection.aggregate.assert_called_once_with(
[
{
"$vectorSearch": {
"index": "test_collection_vector_index",
@@ -91,13 +95,15 @@ def test_search(mongo_vector_fixture):
},
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
{"$project": {"embedding": 0}},
])
]
)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].score == 0.9
assert results[1].id == "id2"
assert results[1].score == 0.8
def test_delete(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_delete_result = MagicMock()
@@ -107,6 +113,7 @@ def test_delete(mongo_vector_fixture):
mongo_vector.delete("id1")
mock_collection.delete_one.assert_called_with({"_id": "id1"})
def test_update(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_update_result = MagicMock()
@@ -122,6 +129,7 @@ def test_update(mongo_vector_fixture):
{"$set": {"embedding": vectorValue, "payload": payloadValue}},
)
def test_get(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}
@@ -131,6 +139,7 @@ def test_get(mongo_vector_fixture):
assert result.id == "id1"
assert result.payload == {"key": "value1"}
def test_list_cols(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.list_collection_names.return_value = ["col1", "col2"]
@@ -138,12 +147,14 @@ def test_list_cols(mongo_vector_fixture):
collections = mongo_vector.list_cols()
assert collections == ["col1", "col2"]
def test_delete_col(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mongo_vector.delete_col()
mock_collection.drop.assert_called_once()
def test_col_info(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.command.return_value = {"count": 10, "size": 1024}
@@ -154,6 +165,7 @@ def test_col_info(mongo_vector_fixture):
assert info["count"] == 10
assert info["size"] == 1024
def test_list(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_cursor = MagicMock()

View File

@@ -88,7 +88,7 @@ def test_get_vector_found(pinecone_db):
# or a list of dictionaries, not a dictionary with an 'id' field
# Create a mock Vector object
from pinecone.data.dataclasses.vector import Vector
from pinecone import Vector
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})