Fix CI issues related to missing dependency (#3096)
.github/workflows/ci.yml (vendored): 7 lines changed
@@ -58,11 +58,10 @@ jobs:
run: |
pip install --upgrade pip
pip install -e ".[test,graph,vector_stores,llms,extras]"
pip install ruff
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
- name: Run Formatting
run: |
mkdir -p .ruff_cache && chmod -R 777 .ruff_cache
hatch run format
- name: Run Linting
run: make lint
- name: Run tests and generate coverage report
run: make test
Makefile: 2 lines changed
@@ -13,7 +13,7 @@ install:
install_all:
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo

# Format code with ruff
format:
@@ -110,7 +110,7 @@ def main():
print("All categories accuracy:")
for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
print("------------------------------------------")
index += 1
@@ -68,7 +68,7 @@ class RAGManager:
def clean_chat_history(self, chat_history):
cleaned_chat_history = ""
for c in chat_history:
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"

return cleaned_chat_history
@@ -68,9 +68,7 @@
"\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n",
" \"\"\n",
")"
"os.environ[\"OPENAI_API_KEY\"] = \"\""
]
},
{

@@ -164,7 +162,7 @@
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]\n"
"]"
]
},
{

@@ -185,9 +183,7 @@
"outputs": [],
"source": [
"# Store inferred memories (default behavior)\n",
"result = m.add(\n",
" messages, user_id=\"alice\"\n",
")"
"result = m.add(messages, user_id=\"alice\")"
]
},
{
@@ -45,11 +45,7 @@ def get_food_recommendation(user_query: str, user_id):
"""Get food recommendation with memory context"""

# Search memory for relevant food preferences
memories_result = memory_client.search(
query=user_query,
user_id=user_id,
limit=5
)
memories_result = memory_client.search(query=user_query, user_id=user_id, limit=5)

# Add memory context to the message
memories = [f"- {result['memory']}" for result in memories_result]

@@ -71,6 +67,7 @@ def get_food_recommendation(user_query: str, user_id):
# Save audio file
if response.audio:
import time

timestamp = int(time.time())
filename = f"food_recommendation_{timestamp}.mp3"
write_audio_to_file(

@@ -118,7 +115,11 @@ def initialize_food_memory(user_id):
# Initialize the memory for the user once in order for the agent to learn the user preference
initialize_food_memory(user_id=USER_ID)

print(get_food_recommendation("Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID))
print(
    get_food_recommendation(
        "Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID
    )
)
# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
# For dinner tonight, considering your love for healthy spicy options, you could try a nice Thai, Indian, or Mexican restaurant.
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!
@@ -1,4 +1,4 @@
from agents import Agent, Runner, function_tool, handoffs, enable_verbose_stdout_logging
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from dotenv import load_dotenv

from mem0 import MemoryClient

@@ -35,7 +35,7 @@ travel_agent = Agent(
understand the user's travel preferences and history before making recommendations.
After providing your response, use store_conversation to save important details.""",
tools=[search_memory, save_memory],
model="gpt-4o"
model="gpt-4o",
)

health_agent = Agent(

@@ -44,7 +44,7 @@ health_agent = Agent(
understand the user's health goals and dietary preferences.
After providing advice, use store_conversation to save relevant information.""",
tools=[search_memory, save_memory],
model="gpt-4o"
model="gpt-4o",
)

# Triage agent with handoffs

@@ -55,7 +55,7 @@ triage_agent = Agent(
For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
For general questions, you can handle them directly using available tools.""",
handoffs=[travel_agent, health_agent],
model="gpt-4o"
model="gpt-4o",
)

@@ -74,10 +74,7 @@ def chat_with_handoffs(user_input: str, user_id: str) -> str:
result = Runner.run_sync(triage_agent, user_input)

# Store the original conversation in memory
conversation = [
{"role": "user", "content": user_input},
{"role": "assistant", "content": result.final_output}
]
conversation = [{"role": "user", "content": user_input}, {"role": "assistant", "content": result.final_output}]
mem0.add(conversation, user_id=user_id)

return result.final_output
@@ -34,24 +34,16 @@ config = {
"api_key": "vllm-api-key",
"temperature": 0.7,
"max_tokens": 100,
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small"
}
},
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": "vllm_memories",
"host": "localhost",
"port": 6333
}
}
"config": {"collection_name": "vllm_memories", "host": "localhost", "port": 6333},
},
}

def main():
"""
Demonstrate vLLM integration with mem0

@@ -68,34 +60,40 @@ def main():
{
"messages": [
{"role": "user", "content": "I love playing chess on weekends"},
{"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."}
{
"role": "assistant",
"content": "That's great! Chess is an excellent strategic game that helps improve critical thinking.",
},
],
"user_id": "user_123"
"user_id": "user_123",
},
{
"messages": [
{"role": "user", "content": "I'm learning Python programming"},
{"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"}
{
"role": "assistant",
"content": "Python is a fantastic language for beginners! What specific areas are you focusing on?",
},
],
"user_id": "user_123"
"user_id": "user_123",
},
{
"messages": [
{"role": "user", "content": "I prefer working late at night, I'm more productive then"},
{"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."}
{
"role": "assistant",
"content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you.",
},
],
"user_id": "user_123"
}
"user_id": "user_123",
},
]

print("\n--> Adding memories using vLLM...")

# Add memories - now powered by vLLM's high-performance inference
for i, conversation in enumerate(conversations, 1):
result = memory.add(
messages=conversation["messages"],
user_id=conversation["user_id"]
)
result = memory.add(messages=conversation["messages"], user_id=conversation["user_id"])
print(f"Memory {i} added: {result}")

print("\n🔍 Searching memories...")

@@ -104,15 +102,12 @@ def main():
search_queries = [
"What does the user like to do on weekends?",
"What is the user learning?",
"When is the user most productive?"
"When is the user most productive?",
]

for query in search_queries:
print(f"\nQuery: {query}")
memories = memory.search(
query=query,
user_id="user_123"
)
memories = memory.search(query=query, user_id="user_123")

for memory_item in memories:
print(f" - {memory_item['memory']}")
@@ -89,9 +89,7 @@ class MemoryClient:
self.user_id = get_user_id()

if not self.api_key:
raise ValueError(
"Mem0 API Key not provided. Please provide an API Key."
)
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")

# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()

@@ -174,9 +172,7 @@ class MemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
)
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
return response.json()

@api_error_handler

@@ -195,9 +191,7 @@ class MemoryClient:
params = self._prepare_params()
response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()

@api_error_handler

@@ -224,13 +218,9 @@ class MemoryClient:
"page": params.pop("page"),
"page_size": params.pop("page_size"),
}
response = self.client.post(
f"/{version}/memories/", json=params, params=query_params
)
response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
else:
response = self.client.post(
f"/{version}/memories/", json=params
)
response = self.client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]

@@ -246,9 +236,7 @@ class MemoryClient:
return response.json()

@api_error_handler
def search(
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
"""Search memories based on a query.

Args:

@@ -266,9 +254,7 @@ class MemoryClient:
payload = {"query": query}
params = self._prepare_params(kwargs)
payload.update(params)
response = self.client.post(
f"/{version}/memories/search/", json=payload
)
response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]

@@ -308,13 +294,9 @@ class MemoryClient:
if metadata is not None:
payload["metadata"] = metadata

capture_client_event(
"client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
params = self._prepare_params()
response = self.client.put(
f"/v1/memories/{memory_id}/", json=payload, params=params
)
response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
response.raise_for_status()
return response.json()

@@ -332,13 +314,9 @@ class MemoryClient:
APIError: If the API request fails.
"""
params = self._prepare_params()
response = self.client.delete(
f"/v1/memories/{memory_id}/", params=params
)
response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()

@api_error_handler

@@ -379,13 +357,9 @@ class MemoryClient:
APIError: If the API request fails.
"""
params = self._prepare_params()
response = self.client.get(
f"/v1/memories/{memory_id}/history/", params=params
)
response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
response.raise_for_status()
capture_client_event(
"client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
)
capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
return response.json()

@api_error_handler
@@ -432,10 +406,7 @@ class MemoryClient:
else:
entities = self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

params = self._prepare_params()

@@ -444,9 +415,7 @@ class MemoryClient:

# Delete entities and check response immediately
for entity in to_delete:
response = self.client.delete(
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()

capture_client_event(

@@ -484,9 +453,7 @@ class MemoryClient:
self.delete_users()

capture_client_event("client.reset", self, {"sync_type": "sync"})
return {
"message": "Client reset successful. All users and memories deleted."
}
return {"message": "Client reset successful. All users and memories deleted."}

@api_error_handler
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:

@@ -507,9 +474,7 @@ class MemoryClient:
response = self.client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status()

capture_client_event(
"client.batch_update", self, {"sync_type": "sync"}
)
capture_client_event("client.batch_update", self, {"sync_type": "sync"})
return response.json()

@api_error_handler

@@ -527,14 +492,10 @@ class MemoryClient:
Raises:
APIError: If the API request fails.
"""
response = self.client.request(
"DELETE", "/v1/batch/", json={"memories": memories}
)
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
response.raise_for_status()

capture_client_event(
"client.batch_delete", self, {"sync_type": "sync"}
)
capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
return response.json()

@api_error_handler

@@ -574,9 +535,7 @@ class MemoryClient:
Returns:
Dict containing the exported data
"""
response = self.client.post(
"/v1/exports/get/", json=self._prepare_params(kwargs)
)
response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
response.raise_for_status()
capture_client_event(
"client.get_memory_export",

@@ -586,9 +545,7 @@ class MemoryClient:
return response.json()

@api_error_handler
def get_summary(
self, filters: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Get the summary of a memory export.

Args:

@@ -598,17 +555,13 @@ class MemoryClient:
Dict containing the export status and summary data
"""

response = self.client.post(
"/v1/summary/", json=self._prepare_params({"filters": filters})
)
response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
response.raise_for_status()
capture_client_event("client.get_summary", self, {"sync_type": "sync"})
return response.json()

@api_error_handler
def get_project(
self, fields: Optional[List[str]] = None
) -> Dict[str, Any]:
def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
"""Get instructions or categories for the current project.

Args:

@@ -622,10 +575,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set.
"""
if not (self.org_id and self.project_id):
raise ValueError(
"org_id and project_id must be set to access instructions or "
"categories"
)
raise ValueError("org_id and project_id must be set to access instructions or categories")

params = self._prepare_params({"fields": fields})
response = self.client.get(

@@ -666,10 +616,7 @@ class MemoryClient:
ValueError: If org_id or project_id are not set.
"""
if not (self.org_id and self.project_id):
raise ValueError(
"org_id and project_id must be set to update instructions or "
"categories"
)
raise ValueError("org_id and project_id must be set to update instructions or categories")

if (
custom_instructions is None
@@ -826,10 +773,7 @@ class MemoryClient:

feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(
f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
"or None"
)
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")

data = {
"memory_id": memory_id,

@@ -839,14 +783,10 @@ class MemoryClient:

response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status()
capture_client_event(
"client.feedback", self, data, {"sync_type": "sync"}
)
capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
return response.json()

def _prepare_payload(
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the payload for API requests.

Args:

@@ -862,9 +802,7 @@ class MemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload

def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.

Args:

@@ -929,9 +867,7 @@ class AsyncMemoryClient:
self.user_id = get_user_id()

if not self.api_key:
raise ValueError(
"Mem0 API Key not provided. Please provide an API Key."
)
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")

# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()

@@ -989,9 +925,7 @@ class AsyncMemoryClient:
error_message = str(e)
raise ValueError(f"Error: {error_message}")

def _prepare_payload(
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the payload for API requests.

Args:

@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload

def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.

Args:
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
await self.async_client.aclose()

@api_error_handler
async def add(
self, messages: List[Dict[str, str]], **kwargs
) -> Dict[str, Any]:
async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
kwargs = self._prepare_params(kwargs)
if kwargs.get("output_format") != "v1.1":
kwargs["output_format"] = "v1.1"

@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
)
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
return response.json()

@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
params = self._prepare_params()
response = await self.async_client.get(
f"/v1/memories/{memory_id}/", params=params
)
response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status()
capture_client_event(
"client.get", self, {"memory_id": memory_id, "sync_type": "async"}
)
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json()

@api_error_handler
async def get_all(
self, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self._prepare_params(kwargs)
if version == "v1":
response = await self.async_client.get(
f"/{version}/memories/", params=params
)
response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2":
if "page" in params and "page_size" in params:
query_params = {
"page": params.pop("page"),
"page_size": params.pop("page_size"),
}
response = await self.async_client.post(
f"/{version}/memories/", json=params, params=query_params
)
response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
else:
response = await self.async_client.post(
f"/{version}/memories/", json=params
)
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]

@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
return response.json()

@api_error_handler
async def search(
self, query: str, version: str = "v1", **kwargs
) -> List[Dict[str, Any]]:
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query}
payload.update(self._prepare_params(kwargs))
response = await self.async_client.post(
f"/{version}/memories/search/", json=payload
)
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]

@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
return response.json()

@api_error_handler
async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def update(
    self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Update a memory by ID.
Args:

@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
else:
entities = await self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

params = self._prepare_params()

@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:

# Delete entities and check response immediately
for entity in to_delete:
response = await self.async_client.delete(
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
)
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()

capture_client_event(

@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
response = await self.async_client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status()

capture_client_event(
"client.batch_update", self, {"sync_type": "async"}
)
capture_client_event("client.batch_update", self, {"sync_type": "async"})
return response.json()

@api_error_handler

@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
Raises:
APIError: If the API request fails.
"""
response = await self.async_client.request(
"DELETE", "/v1/batch/", json={"memories": memories}
)
response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
response.raise_for_status()

capture_client_event(
"client.batch_delete", self, {"sync_type": "async"}
)
capture_client_event("client.batch_delete", self, {"sync_type": "async"})
return response.json()

@api_error_handler

@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:

feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")

data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator

@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")

@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())

@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
)
return values
@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]:
if "toolUse" in item:
processed_response["tool_calls"].append({
processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"],
})
}
)

return processed_response
@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
if system_instruction:
config_params["system_instruction"] = system_instruction

if response_format is not None and response_format["type"] == "json_object":
config_params["response_mime_type"] = "application/json"
if "schema" in response_format:

@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
formatted_tools = self._reformat_tools(tools)
config_params["tools"] = formatted_tools

if tool_choice:
if tool_choice == "auto":
mode = types.FunctionCallingConfigMode.AUTO

@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):

if not self.api_key:
raise ValueError(
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
"Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
)

# Set base URL - use config value or environment or default
@@ -7,7 +7,6 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
from openai import OpenAI


class VllmLLM(LLMBase):

@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(extract_json(tool_call.function.arguments)),
})
}
)

return processed_response
else:
@@ -136,7 +136,6 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)

def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.

@@ -176,7 +175,6 @@ class MemoryGraph:

return final_results

def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]

@@ -221,9 +219,7 @@ class MemoryGraph:
if self.config.graph_store.custom_prompt:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
# Add the custom prompt line if configured
system_content = system_content.replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
)
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": data},

@@ -592,7 +588,6 @@ class MemoryGraph:
result = self.graph.query(cypher, params=params)
return result

def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
agent_filter = ""
if filters.get("agent_id"):
@@ -1,5 +1,5 @@
import logging
from typing import List, Optional, Dict, Any, Callable
from typing import List, Optional, Dict, Any

from pydantic import BaseModel

@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine"

def __init__(
self,
db_name: str,
collection_name: str,
embedding_model_dims: int,
mongo_uri: str
):
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
"""
Initialize the MongoDB vector store with vector search capabilities.

@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
self.embedding_model_dims = embedding_model_dims
self.db_name = db_name

self.client = MongoClient(
mongo_uri
)
self.client = MongoClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.create_col()

@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
except PyMongoError as e:
logger.error(f"Error inserting data: {e}")

def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
def search(
    self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors using the vector search index.
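For context, the MongoDB vector store above is what pulls pymongo into the dependency set. A minimal usage sketch follows; it assumes the provider key is "mongodb" and that Memory.from_config accepts this config shape, and the db_name and collection_name values are placeholders, not taken from this commit:

from mem0 import Memory

config = {
    "vector_store": {
        "provider": "mongodb",  # assumed provider key for the MongoDB store shown above
        "config": {
            "db_name": "mem0_db",  # placeholder database name
            "collection_name": "memories",  # placeholder collection name
            "embedding_model_dims": 1536,
            "mongo_uri": "mongodb://localhost:27017",
        },
    },
}

# MongoDB(...) calls MongoClient(mongo_uri) internally, which is why pymongo
# is added to the Makefile install_all target and to the vector_stores extra below.
m = Memory.from_config(config)
m.add([{"role": "user", "content": "I prefer MongoDB for vector search"}], user_id="alice")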
poetry.lock (generated): 5797 lines changed; diff suppressed because it is too large.
@@ -30,11 +30,13 @@ vector_stores = [
"vecs>=0.4.0",
"chromadb>=0.4.24",
"weaviate-client>=4.4.0",
"pinecone<7.0.0",
"pinecone-text>=0.1.1",
"pinecone<=7.3.0",
"pinecone-text>=0.10.0",
"faiss-cpu>=1.7.4",
"upstash-vector>=0.1.0",
"azure-search-documents>=11.4.0b8",
"pymongo>=4.13.2",
"pymochow>=2.2.9",
]
llms = [
"groq>=0.3.0",

@@ -44,12 +46,11 @@ llms = [
"vertexai>=0.1.0",
"google-generativeai>=0.3.0",
"google-genai>=1.0.0",
]
extras = [
"boto3>=1.34.0",
"langchain-community>=0.0.0",
"sentence-transformers>=2.2.2",
"sentence-transformers>=5.0.0",
"elasticsearch>=8.0.0",
"opensearch-py>=2.0.0",
"langchain-memgraph>=0.1.0",
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import patch, ANY

import pytest

@@ -8,8 +8,10 @@ from mem0.embeddings.gemini import GoogleGenAIEmbedding

@pytest.fixture
def mock_genai():
with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
yield mock_genai
with patch("mem0.embeddings.gemini.genai.Client") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.models.embed_content.return_value = None
yield mock_client.models.embed_content

@pytest.fixture

@@ -18,7 +20,9 @@ def config():

def test_embed_query(mock_genai, config):
mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
mock_embedding_response = type('Response', (), {
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
})()
mock_genai.return_value = mock_embedding_response

embedder = GoogleGenAIEmbedding(config)

@@ -27,10 +31,11 @@ def test_embed_query(mock_genai, config):
embedding = embedder.embed(text)

assert embedding == [0.1, 0.2, 0.3, 0.4]
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
mock_genai.assert_called_once_with(model="test_model", contents="Hello, world!", config=ANY)

def test_embed_returns_empty_list_if_none(mock_genai, config):
mock_genai.return_value = None
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()

embedder = GoogleGenAIEmbedding(config)
result = embedder.embed("test")

@@ -47,10 +52,10 @@ def test_embed_raises_on_error(mock_genai, config):
with pytest.raises(RuntimeError, match="Embedding failed"):
embedder.embed("some input")

def test_config_initialization(config):
embedder = GoogleGenAIEmbedding(config)

assert embedder.config.api_key == "dummy_api_key"
assert embedder.config.model == "test_model"
assert embedder.config.embedding_dims == 786
@@ -9,7 +9,7 @@ from mem0.llms.gemini import GeminiLLM

@pytest.fixture
def mock_gemini_client():
with patch("mem0.llms.gemini.genai") as mock_client_class:
with patch("mem0.llms.gemini.genai.Client") as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
yield mock_client

@@ -24,45 +24,32 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
]

mock_part = Mock(text="I'm doing well, thank you for asking!")
mock_embedding = Mock()
mock_embedding.values = [0.1, 0.2, 0.3]

mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock()]
mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
mock_content = Mock(parts=[mock_part])
mock_candidate = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_candidate])

mock_gemini_client.models.generate_content.return_value = mock_response
mock_content = Mock(parts=[mock_part])
mock_message = Mock(content=mock_content)
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response

response = llm.generate_response(messages)

mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Hello, how are you?", "role": "user"},
],
config=types.GenerateContentConfig(
temperature=0.7,
max_output_tokens=100,
top_p=1.0,
tools=None,
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
)
)
)
)
# Check the actual call - system instruction is now in config
mock_gemini_client.models.generate_content.assert_called_once()
call_args = mock_gemini_client.models.generate_content.call_args

# Verify model and contents
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message

# Verify config has system instruction
config_arg = call_args.kwargs['config']
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
assert config_arg.top_p == 1.0

assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)

@@ -89,64 +76,42 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}

mock_part = Mock()
mock_part.function_call = mock_tool_call
mock_part.text = "I've added the memory for you."
# Create mock parts with both text and function_call
mock_text_part = Mock()
mock_text_part.text = "I've added the memory for you."
mock_text_part.function_call = None

mock_func_part = Mock()
mock_func_part.text = None
mock_func_part.function_call = mock_tool_call

mock_content = Mock()
mock_content.parts = [mock_part]
mock_content.parts = [mock_text_part, mock_func_part]

mock_message = Mock()
mock_message.content = mock_content
mock_candidate = Mock()
mock_candidate.content = mock_content

mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
mock_response = Mock(candidates=[mock_candidate])
mock_gemini_client.models.generate_content.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user"
},
{
"parts": "Add a new memory: Today is a sunny day.",
"role": "user"
},
],
config=types.GenerateContentConfig(
temperature=0.7,
max_output_tokens=100,
top_p=1.0,
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="add_memory",
description="Add a memory",
parameters={
"type": "object",
"properties": {
"data": {
"type": "string",
"description": "Data to add to memory"
}
},
"required": ["data"]
}
)
]
)
],
tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
)
)
)
)
# Check the actual call
mock_gemini_client.models.generate_content.assert_called_once()
call_args = mock_gemini_client.models.generate_content.call_args

# Verify model and contents
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message

# Verify config has system instruction and tools
config_arg = call_args.kwargs['config']
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
assert config_arg.top_p == 1.0
assert len(config_arg.tools) == 1
assert config_arg.tool_config.function_calling_config.mode == types.FunctionCallingConfigMode.AUTO

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
@@ -43,6 +43,7 @@ def test_generate_response_without_tools(mock_lm_studio_client):

assert response == "I'm doing well, thank you for asking!"

def test_generate_response_specifying_response_format(mock_lm_studio_client):
config = BaseLlmConfig(
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
@@ -71,7 +71,13 @@ def test_generate_response_with_tools(mock_vllm_client):
response = llm.generate_response(messages, tools=tools)

mock_vllm_client.chat.completions.create.assert_called_once_with(
model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
model="Qwen/Qwen2.5-32B-Instruct",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)

assert response["content"] == "I've added the memory for you."
@@ -253,11 +253,11 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
from mem0.embeddings.mock import MockEmbeddings

memory_custom_instance.llm.generate_response = Mock()
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
memory_custom_instance.embedding_model = MockEmbeddings()

with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
@@ -2,10 +2,9 @@ from unittest.mock import Mock, patch, PropertyMock

import pytest

from mem0.vector_stores.baidu import BaiduDB, OutputData
from pymochow.model.enum import MetricType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError
@@ -1,7 +1,7 @@
import pytest
from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoDB
from pymongo.operations import SearchIndexModel

@pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient")

@@ -19,10 +19,11 @@ def mongo_vector_fixture(mock_mongo_client):
db_name="test_db",
collection_name="test_collection",
embedding_model_dims=1536,
mongo_uri="mongodb://username:password@localhost:27017"
mongo_uri="mongodb://username:password@localhost:27017",
)
return mongo_vector, mock_collection, mock_db

def test_initalize_create_col(mongo_vector_fixture):
mongo_vector, mock_collection, mock_db = mongo_vector_fixture
assert mongo_vector.collection_name == "test_collection"

@@ -49,12 +50,13 @@ def test_initalize_create_col(mongo_vector_fixture):
"dimensions": 1536,
"similarity": "cosine",
}
},
}
}
}
},
}
assert mongo_vector.collection == mock_collection

def test_insert(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
vectors = [[0.1] * 1536, [0.2] * 1536]

@@ -62,12 +64,13 @@ def test_insert(mongo_vector_fixture):
ids = ["id1", "id2"]

mongo_vector.insert(vectors, payloads, ids)
expected_records=[
expected_records = [
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]})
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}),
]
mock_collection.insert_many.assert_called_once_with(expected_records)
def test_search(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
query_vector = [0.1] * 1536

@@ -79,7 +82,8 @@ def test_search(mongo_vector_fixture):

results = mongo_vector.search("query_str", query_vector, limit=2)
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
mock_collection.aggregate.assert_called_once_with([
mock_collection.aggregate.assert_called_once_with(
[
{
"$vectorSearch": {
"index": "test_collection_vector_index",

@@ -91,13 +95,15 @@ def test_search(mongo_vector_fixture):
},
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
{"$project": {"embedding": 0}},
])
]
)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].score == 0.9
assert results[1].id == "id2"
assert results[1].score == 0.8

def test_delete(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_delete_result = MagicMock()

@@ -107,6 +113,7 @@ def test_delete(mongo_vector_fixture):
mongo_vector.delete("id1")
mock_collection.delete_one.assert_called_with({"_id": "id1"})

def test_update(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_update_result = MagicMock()

@@ -122,6 +129,7 @@ def test_update(mongo_vector_fixture):
{"$set": {"embedding": vectorValue, "payload": payloadValue}},
)

def test_get(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}

@@ -131,6 +139,7 @@ def test_get(mongo_vector_fixture):
assert result.id == "id1"
assert result.payload == {"key": "value1"}

def test_list_cols(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.list_collection_names.return_value = ["col1", "col2"]

@@ -138,12 +147,14 @@ def test_list_cols(mongo_vector_fixture):
collections = mongo_vector.list_cols()
assert collections == ["col1", "col2"]

def test_delete_col(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture

mongo_vector.delete_col()
mock_collection.drop.assert_called_once()

def test_col_info(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.command.return_value = {"count": 10, "size": 1024}

@@ -154,6 +165,7 @@ def test_col_info(mongo_vector_fixture):
assert info["count"] == 10
assert info["size"] == 1024

def test_list(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_cursor = MagicMock()
@@ -88,7 +88,7 @@ def test_get_vector_found(pinecone_db):
# or a list of dictionaries, not a dictionary with an 'id' field

# Create a mock Vector object
from pinecone.data.dataclasses.vector import Vector
from pinecone import Vector

mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})