Fix CI issues related to missing dependency (#3096)
This commit is contained in:
@@ -89,9 +89,7 @@ class MemoryClient:
|
||||
self.user_id = get_user_id()
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Mem0 API Key not provided. Please provide an API Key."
|
||||
)
|
||||
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
|
||||
|
||||
# Create MD5 hash of API key for user_id
|
||||
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
||||
@@ -174,9 +172,7 @@ class MemoryClient:
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
capture_client_event(
|
||||
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -195,9 +191,7 @@ class MemoryClient:
|
||||
params = self._prepare_params()
|
||||
response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -224,13 +218,9 @@ class MemoryClient:
|
||||
"page": params.pop("page"),
|
||||
"page_size": params.pop("page_size"),
|
||||
}
|
||||
response = self.client.post(
|
||||
f"/{version}/memories/", json=params, params=query_params
|
||||
)
|
||||
response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
|
||||
else:
|
||||
response = self.client.post(
|
||||
f"/{version}/memories/", json=params
|
||||
)
|
||||
response = self.client.post(f"/{version}/memories/", json=params)
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
@@ -246,9 +236,7 @@ class MemoryClient:
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
def search(
|
||||
self, query: str, version: str = "v1", **kwargs
|
||||
) -> List[Dict[str, Any]]:
|
||||
def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||
"""Search memories based on a query.
|
||||
|
||||
Args:
|
||||
@@ -266,9 +254,7 @@ class MemoryClient:
|
||||
payload = {"query": query}
|
||||
params = self._prepare_params(kwargs)
|
||||
payload.update(params)
|
||||
response = self.client.post(
|
||||
f"/{version}/memories/search/", json=payload
|
||||
)
|
||||
response = self.client.post(f"/{version}/memories/search/", json=payload)
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
@@ -308,13 +294,9 @@ class MemoryClient:
|
||||
if metadata is not None:
|
||||
payload["metadata"] = metadata
|
||||
|
||||
capture_client_event(
|
||||
"client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||
params = self._prepare_params()
|
||||
response = self.client.put(
|
||||
f"/v1/memories/{memory_id}/", json=payload, params=params
|
||||
)
|
||||
response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -332,13 +314,9 @@ class MemoryClient:
|
||||
APIError: If the API request fails.
|
||||
"""
|
||||
params = self._prepare_params()
|
||||
response = self.client.delete(
|
||||
f"/v1/memories/{memory_id}/", params=params
|
||||
)
|
||||
response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -379,13 +357,9 @@ class MemoryClient:
|
||||
APIError: If the API request fails.
|
||||
"""
|
||||
params = self._prepare_params()
|
||||
response = self.client.get(
|
||||
f"/v1/memories/{memory_id}/history/", params=params
|
||||
)
|
||||
response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -432,10 +406,7 @@ class MemoryClient:
|
||||
else:
|
||||
entities = self.users()
|
||||
# Filter entities based on provided IDs using list comprehension
|
||||
to_delete = [
|
||||
{"type": entity["type"], "name": entity["name"]}
|
||||
for entity in entities["results"]
|
||||
]
|
||||
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||
|
||||
params = self._prepare_params()
|
||||
|
||||
@@ -444,9 +415,7 @@ class MemoryClient:
|
||||
|
||||
# Delete entities and check response immediately
|
||||
for entity in to_delete:
|
||||
response = self.client.delete(
|
||||
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
|
||||
)
|
||||
response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
@@ -484,9 +453,7 @@ class MemoryClient:
|
||||
self.delete_users()
|
||||
|
||||
capture_client_event("client.reset", self, {"sync_type": "sync"})
|
||||
return {
|
||||
"message": "Client reset successful. All users and memories deleted."
|
||||
}
|
||||
return {"message": "Client reset successful. All users and memories deleted."}
|
||||
|
||||
@api_error_handler
|
||||
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@@ -507,9 +474,7 @@ class MemoryClient:
|
||||
response = self.client.put("/v1/batch/", json={"memories": memories})
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
"client.batch_update", self, {"sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.batch_update", self, {"sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -527,14 +492,10 @@ class MemoryClient:
|
||||
Raises:
|
||||
APIError: If the API request fails.
|
||||
"""
|
||||
response = self.client.request(
|
||||
"DELETE", "/v1/batch/", json={"memories": memories}
|
||||
)
|
||||
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
"client.batch_delete", self, {"sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -574,9 +535,7 @@ class MemoryClient:
|
||||
Returns:
|
||||
Dict containing the exported data
|
||||
"""
|
||||
response = self.client.post(
|
||||
"/v1/exports/get/", json=self._prepare_params(kwargs)
|
||||
)
|
||||
response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.get_memory_export",
|
||||
@@ -586,9 +545,7 @@ class MemoryClient:
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
def get_summary(
|
||||
self, filters: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Get the summary of a memory export.
|
||||
|
||||
Args:
|
||||
@@ -598,17 +555,13 @@ class MemoryClient:
|
||||
Dict containing the export status and summary data
|
||||
"""
|
||||
|
||||
response = self.client.post(
|
||||
"/v1/summary/", json=self._prepare_params({"filters": filters})
|
||||
)
|
||||
response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
|
||||
response.raise_for_status()
|
||||
capture_client_event("client.get_summary", self, {"sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
def get_project(
|
||||
self, fields: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Get instructions or categories for the current project.
|
||||
|
||||
Args:
|
||||
@@ -622,10 +575,7 @@ class MemoryClient:
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError(
|
||||
"org_id and project_id must be set to access instructions or "
|
||||
"categories"
|
||||
)
|
||||
raise ValueError("org_id and project_id must be set to access instructions or categories")
|
||||
|
||||
params = self._prepare_params({"fields": fields})
|
||||
response = self.client.get(
|
||||
@@ -666,10 +616,7 @@ class MemoryClient:
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError(
|
||||
"org_id and project_id must be set to update instructions or "
|
||||
"categories"
|
||||
)
|
||||
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
||||
|
||||
if (
|
||||
custom_instructions is None
|
||||
@@ -826,10 +773,7 @@ class MemoryClient:
|
||||
|
||||
feedback = feedback.upper() if feedback else None
|
||||
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
||||
raise ValueError(
|
||||
f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
|
||||
"or None"
|
||||
)
|
||||
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
|
||||
|
||||
data = {
|
||||
"memory_id": memory_id,
|
||||
@@ -839,14 +783,10 @@ class MemoryClient:
|
||||
|
||||
response = self.client.post("/v1/feedback/", json=data)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.feedback", self, data, {"sync_type": "sync"}
|
||||
)
|
||||
capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
|
||||
return response.json()
|
||||
|
||||
def _prepare_payload(
|
||||
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Prepare the payload for API requests.
|
||||
|
||||
Args:
|
||||
@@ -862,9 +802,7 @@ class MemoryClient:
|
||||
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
||||
return payload
|
||||
|
||||
def _prepare_params(
|
||||
self, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Prepare query parameters for API requests.
|
||||
|
||||
Args:
|
||||
@@ -929,9 +867,7 @@ class AsyncMemoryClient:
|
||||
self.user_id = get_user_id()
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Mem0 API Key not provided. Please provide an API Key."
|
||||
)
|
||||
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
|
||||
|
||||
# Create MD5 hash of API key for user_id
|
||||
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
||||
@@ -989,9 +925,7 @@ class AsyncMemoryClient:
|
||||
error_message = str(e)
|
||||
raise ValueError(f"Error: {error_message}")
|
||||
|
||||
def _prepare_payload(
|
||||
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Prepare the payload for API requests.
|
||||
|
||||
Args:
|
||||
@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
|
||||
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
||||
return payload
|
||||
|
||||
def _prepare_params(
|
||||
self, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Prepare query parameters for API requests.
|
||||
|
||||
Args:
|
||||
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
|
||||
await self.async_client.aclose()
|
||||
|
||||
@api_error_handler
|
||||
async def add(
|
||||
self, messages: List[Dict[str, str]], **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
|
||||
kwargs = self._prepare_params(kwargs)
|
||||
if kwargs.get("output_format") != "v1.1":
|
||||
kwargs["output_format"] = "v1.1"
|
||||
@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
capture_client_event(
|
||||
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
|
||||
)
|
||||
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def get(self, memory_id: str) -> Dict[str, Any]:
|
||||
params = self._prepare_params()
|
||||
response = await self.async_client.get(
|
||||
f"/v1/memories/{memory_id}/", params=params
|
||||
)
|
||||
response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"client.get", self, {"memory_id": memory_id, "sync_type": "async"}
|
||||
)
|
||||
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def get_all(
|
||||
self, version: str = "v1", **kwargs
|
||||
) -> List[Dict[str, Any]]:
|
||||
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||
params = self._prepare_params(kwargs)
|
||||
if version == "v1":
|
||||
response = await self.async_client.get(
|
||||
f"/{version}/memories/", params=params
|
||||
)
|
||||
response = await self.async_client.get(f"/{version}/memories/", params=params)
|
||||
elif version == "v2":
|
||||
if "page" in params and "page_size" in params:
|
||||
query_params = {
|
||||
"page": params.pop("page"),
|
||||
"page_size": params.pop("page_size"),
|
||||
}
|
||||
response = await self.async_client.post(
|
||||
f"/{version}/memories/", json=params, params=query_params
|
||||
)
|
||||
response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
|
||||
else:
|
||||
response = await self.async_client.post(
|
||||
f"/{version}/memories/", json=params
|
||||
)
|
||||
response = await self.async_client.post(f"/{version}/memories/", json=params)
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def search(
|
||||
self, query: str, version: str = "v1", **kwargs
|
||||
) -> List[Dict[str, Any]]:
|
||||
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||
payload = {"query": query}
|
||||
payload.update(self._prepare_params(kwargs))
|
||||
response = await self.async_client.post(
|
||||
f"/{version}/memories/search/", json=payload
|
||||
)
|
||||
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
|
||||
response.raise_for_status()
|
||||
if "metadata" in kwargs:
|
||||
del kwargs["metadata"]
|
||||
@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
async def update(
|
||||
self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update a memory by ID.
|
||||
Args:
|
||||
@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
|
||||
else:
|
||||
entities = await self.users()
|
||||
# Filter entities based on provided IDs using list comprehension
|
||||
to_delete = [
|
||||
{"type": entity["type"], "name": entity["name"]}
|
||||
for entity in entities["results"]
|
||||
]
|
||||
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||
|
||||
params = self._prepare_params()
|
||||
|
||||
@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:
|
||||
|
||||
# Delete entities and check response immediately
|
||||
for entity in to_delete:
|
||||
response = await self.async_client.delete(
|
||||
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
|
||||
)
|
||||
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
|
||||
response = await self.async_client.put("/v1/batch/", json={"memories": memories})
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
"client.batch_update", self, {"sync_type": "async"}
|
||||
)
|
||||
capture_client_event("client.batch_update", self, {"sync_type": "async"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
|
||||
Raises:
|
||||
APIError: If the API request fails.
|
||||
"""
|
||||
response = await self.async_client.request(
|
||||
"DELETE", "/v1/batch/", json={"memories": memories}
|
||||
)
|
||||
response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event(
|
||||
"client.batch_delete", self, {"sync_type": "async"}
|
||||
)
|
||||
capture_client_event("client.batch_delete", self, {"sync_type": "async"})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:
|
||||
|
||||
feedback = feedback.upper() if feedback else None
|
||||
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
||||
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
|
||||
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
|
||||
|
||||
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
|
||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
||||
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
allowed_fields = set(cls.model_fields.keys())
|
||||
|
||||
@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
|
||||
extra_fields = input_fields - allowed_fields
|
||||
if extra_fields:
|
||||
raise ValueError(
|
||||
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
|
||||
f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
|
||||
)
|
||||
return values
|
||||
|
||||
@@ -36,4 +36,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
|
||||
# Call the embed_content method with the correct parameters
|
||||
response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)
|
||||
|
||||
return response.embeddings[0].values
|
||||
return response.embeddings[0].values
|
||||
|
||||
@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
|
||||
if response["output"]["message"]["content"]:
|
||||
for item in response["output"]["message"]["content"]:
|
||||
if "toolUse" in item:
|
||||
processed_response["tool_calls"].append({
|
||||
"name": item["toolUse"]["name"],
|
||||
"arguments": item["toolUse"]["input"],
|
||||
})
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": item["toolUse"]["name"],
|
||||
"arguments": item["toolUse"]["input"],
|
||||
}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
|
||||
|
||||
@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
|
||||
if system_instruction:
|
||||
config_params["system_instruction"] = system_instruction
|
||||
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
config_params["response_mime_type"] = "application/json"
|
||||
if "schema" in response_format:
|
||||
@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
|
||||
formatted_tools = self._reformat_tools(tools)
|
||||
config_params["tools"] = formatted_tools
|
||||
|
||||
|
||||
if tool_choice:
|
||||
if tool_choice == "auto":
|
||||
mode = types.FunctionCallingConfigMode.AUTO
|
||||
|
||||
@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
|
||||
"Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
|
||||
)
|
||||
|
||||
# Set base URL - use config value or environment or default
|
||||
|
||||
@@ -7,7 +7,6 @@ from openai import OpenAI
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class VllmLLM(LLMBase):
|
||||
@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
processed_response["tool_calls"].append({
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
})
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
else:
|
||||
|
||||
@@ -136,7 +136,6 @@ class MemoryGraph:
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def get_all(self, filters, limit=100):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
@@ -176,7 +175,6 @@ class MemoryGraph:
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
def _retrieve_nodes_from_data(self, data, filters):
|
||||
"""Extracts all the entities mentioned in the query."""
|
||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||
@@ -213,7 +211,7 @@ class MemoryGraph:
|
||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||
"""Establish relations among the extracted nodes."""
|
||||
|
||||
# Compose user identification string for prompt
|
||||
# Compose user identification string for prompt
|
||||
user_identity = f"user_id: {filters['user_id']}"
|
||||
if filters.get("agent_id"):
|
||||
user_identity += f", agent_id: {filters['agent_id']}"
|
||||
@@ -221,9 +219,7 @@ class MemoryGraph:
|
||||
if self.config.graph_store.custom_prompt:
|
||||
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||
# Add the custom prompt line if configured
|
||||
system_content = system_content.replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
)
|
||||
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
|
||||
messages = [
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": data},
|
||||
@@ -336,7 +332,7 @@ class MemoryGraph:
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
|
||||
for item in to_be_deleted:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
@@ -349,7 +345,7 @@ class MemoryGraph:
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
|
||||
if agent_id:
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
params["agent_id"] = agent_id
|
||||
@@ -366,10 +362,10 @@ class MemoryGraph:
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||
@@ -430,7 +426,7 @@ class MemoryGraph:
|
||||
r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_name": destination,
|
||||
@@ -592,7 +588,6 @@ class MemoryGraph:
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
|
||||
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||
agent_filter = ""
|
||||
if filters.get("agent_id"):
|
||||
|
||||
@@ -338,7 +338,7 @@ class Memory(MemoryBase):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in new_retrieved_facts: {e}")
|
||||
new_retrieved_facts = []
|
||||
|
||||
|
||||
if not new_retrieved_facts:
|
||||
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
||||
|
||||
@@ -1166,7 +1166,7 @@ class AsyncMemory(MemoryBase):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in new_retrieved_facts: {e}")
|
||||
new_retrieved_facts = []
|
||||
|
||||
|
||||
if not new_retrieved_facts:
|
||||
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ class MemoryGraph:
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "limit": limit}
|
||||
|
||||
|
||||
results = self.graph.query(query, params=params)
|
||||
|
||||
final_results = []
|
||||
@@ -318,7 +318,7 @@ class MemoryGraph:
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
@@ -356,7 +356,7 @@ class MemoryGraph:
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
|
||||
for item in to_be_deleted:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
@@ -369,7 +369,7 @@ class MemoryGraph:
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
|
||||
if agent_id:
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
params["agent_id"] = agent_id
|
||||
@@ -386,10 +386,10 @@ class MemoryGraph:
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
# added Entity label to all nodes for vector search to work
|
||||
@@ -398,7 +398,7 @@ class MemoryGraph:
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
|
||||
for item in to_be_added:
|
||||
# entities
|
||||
source = item["source"]
|
||||
@@ -421,7 +421,7 @@ class MemoryGraph:
|
||||
agent_id_clause = ""
|
||||
if agent_id:
|
||||
agent_id_clause = ", agent_id: $agent_id"
|
||||
|
||||
|
||||
# TODO: Create a cypher query and common params for all the cases
|
||||
if not destination_node_search_result and source_node_search_result:
|
||||
cypher = f"""
|
||||
@@ -446,7 +446,7 @@ class MemoryGraph:
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
|
||||
elif destination_node_search_result and not source_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (destination:Entity)
|
||||
@@ -470,7 +470,7 @@ class MemoryGraph:
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
|
||||
elif source_node_search_result and destination_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source:Entity)
|
||||
@@ -490,7 +490,7 @@ class MemoryGraph:
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
|
||||
else:
|
||||
cypher = f"""
|
||||
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||
@@ -512,7 +512,7 @@ class MemoryGraph:
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
@@ -528,7 +528,7 @@ class MemoryGraph:
|
||||
"""Search for source nodes with similar embeddings."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
|
||||
|
||||
if agent_id:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||
@@ -567,7 +567,7 @@ class MemoryGraph:
|
||||
"""Search for destination nodes with similar embeddings."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
|
||||
|
||||
if agent_id:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any, Callable
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
|
||||
VECTOR_TYPE = "knnVector"
|
||||
SIMILARITY_METRIC = "cosine"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_name: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
mongo_uri: str
|
||||
):
|
||||
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
|
||||
"""
|
||||
Initialize the MongoDB vector store with vector search capabilities.
|
||||
|
||||
@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.db_name = db_name
|
||||
|
||||
self.client = MongoClient(
|
||||
mongo_uri
|
||||
)
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[db_name]
|
||||
self.collection = self.create_col()
|
||||
|
||||
@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error inserting data: {e}")
|
||||
|
||||
def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
|
||||
def search(
|
||||
self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors using the vector search index.
|
||||
|
||||
@@ -285,7 +279,7 @@ class MongoDB(VectorStoreBase):
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error listing documents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
|
||||
@@ -88,7 +88,7 @@ class OpenSearchDB(VectorStoreBase):
|
||||
self.client.indices.create(index=name, body=index_settings)
|
||||
|
||||
# Wait for index to be ready
|
||||
max_retries = 180 # 3 minutes timeout
|
||||
max_retries = 180 # 3 minutes timeout
|
||||
retry_count = 0
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user