Formatting (#2750)
@@ -95,10 +95,7 @@ class MemoryClient:
self.client = client
# Ensure the client has the correct base_url and headers
self.client.base_url = httpx.URL(self.host)
self.client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
self.client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else:
self.client = httpx.Client(
base_url=self.host,
@@ -237,7 +234,9 @@ class MemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"})
capture_client_event(
"client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json()

@api_error_handler
@@ -357,10 +356,7 @@ class MemoryClient:
else:
entities = self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

params = self._prepare_params()

@@ -373,7 +369,9 @@ class MemoryClient:
response.raise_for_status()

capture_client_event(
"client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"}
"client.delete_users",
self,
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"},
)
return {
"message": "Entity deleted successfully."
@@ -454,7 +452,9 @@ class MemoryClient:
"""
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
response.raise_for_status()
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"})
capture_client_event(
"client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json()

@api_error_handler
@@ -527,7 +527,11 @@ class MemoryClient:
)

payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
}
)
response = self.client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -537,7 +541,12 @@ class MemoryClient:
capture_client_event(
"client.update_project",
self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "sync"},
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"sync_type": "sync",
},
)
return response.json()

@@ -750,10 +759,7 @@ class AsyncMemoryClient:
self.async_client = client
# Ensure the client has the correct base_url and headers
self.async_client.base_url = httpx.URL(self.host)
self.async_client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
self.async_client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else:
self.async_client = httpx.AsyncClient(
base_url=self.host,
@@ -768,7 +774,11 @@ class AsyncMemoryClient:
"""Validate the API key by making a test request."""
try:
params = self._prepare_params()
response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params)
response = requests.get(
f"{self.host}/v1/ping/",
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
params=params,
)
data = response.json()

response.raise_for_status()
@@ -973,10 +983,7 @@ class AsyncMemoryClient:
else:
entities = await self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

params = self._prepare_params()

@@ -988,7 +995,11 @@ class AsyncMemoryClient:
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()

capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"})
capture_client_event(
"client.delete_users",
self,
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"},
)
return {
"message": "Entity deleted successfully."
if (user_id or agent_id or app_id or run_id)
@@ -1091,8 +1102,10 @@ class AsyncMemoryClient:

@api_error_handler
async def update_project(
self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None,
retrieval_criteria: Optional[List[Dict[str, Any]]] = None
self,
custom_instructions: Optional[str] = None,
custom_categories: Optional[List[str]] = None,
retrieval_criteria: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1103,7 +1116,11 @@ class AsyncMemoryClient:
)

payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
}
)
response = await self.async_client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -1113,7 +1130,12 @@ class AsyncMemoryClient:
capture_client_event(
"client.update_project",
self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"},
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"sync_type": "async",
},
)
return response.json()

@@ -1174,4 +1196,3 @@ class AsyncMemoryClient:
response.raise_for_status()
capture_client_event("client.feedback", self, data, {"sync_type": "async"})
return response.json()


@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union, Type
from typing import Any, Dict, Optional, Type, Union

from pydantic import BaseModel, Field, model_validator

@@ -7,33 +7,17 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port")
user: Optional[str] = Field(
None, description="Username for authentication"
)
password: Optional[str] = Field(
None, description="Password for authentication"
)
api_key: Optional[str] = Field(
None, description="API key for authentication (if applicable)"
)
embedding_model_dims: int = Field(
1536, description="Dimension of the embedding vector"
)
verify_certs: bool = Field(
False, description="Verify SSL certificates (default False for OpenSearch)"
)
use_ssl: bool = Field(
False, description="Use SSL for connection (default False for OpenSearch)"
)
http_auth: Optional[object] = Field(
None, description="HTTP authentication method / AWS SigV4"
)
user: Optional[str] = Field(None, description="Username for authentication")
password: Optional[str] = Field(None, description="Password for authentication")
api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
connection_class: Optional[Union[str, Type]] = Field(
"RequestsHttpConnection", description="Connection class for OpenSearch"
)
pool_maxsize: int = Field(
20, description="Maximum number of connections in the pool"
)
pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")

@model_validator(mode="before")
@classmethod
@@ -41,7 +25,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided
if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch")


return values

@model_validator(mode="before")
@@ -52,7 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
)
return values

@@ -23,12 +23,12 @@ class AWSBedrockEmbedding(EmbeddingBase):
super().__init__(config)

self.config.model = self.config.model or "amazon.titan-embed-text-v1"


# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")


# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
@@ -36,7 +36,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region


self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,

@@ -11,6 +11,7 @@ logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)


class HuggingFaceEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

@@ -22,7 +22,8 @@ class Neo4jConfig(BaseModel):
if not url or not username or not password:
raise ValueError("Please provide 'url', 'username' and 'password'.")
return values



class MemgraphConfig(BaseModel):
url: Optional[str] = Field(None, description="Host address for the graph database")
username: Optional[str] = Field(None, description="Username for the graph database")

@@ -20,18 +20,19 @@ def extract_provider(model: str) -> str:
return provider
raise ValueError(f"Unknown provider in model: {model}")


class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)

if not self.config.model:
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"


# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")


# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
@@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region


self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
)


self.model_kwargs = {
"temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens,
@@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase):
input_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"]
or self.model_kwargs["max_tokens"]
or 5000,
"topP": self.model_kwargs["top_p"] or 0.9,
"temperature": self.model_kwargs["temperature"] or 0.1,
},
@@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase):
body = json.dumps(input_body)

if provider == "anthropic" or provider == "deepseek":

input_body = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
}
],
"messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"temperature": self.model_kwargs["temperature"] or 0.1,
"top_p": self.model_kwargs["top_p"] or 0.9,
"anthropic_version": "bedrock-2023-05-31",
}


body = json.dumps(input_body)


response = self.client.invoke_model(
body=body,
@@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase):
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)
)

return self._parse_response(response, tools)

@@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory


def _build_filters_and_metadata(
*,  # Enforce keyword-only arguments
*,  # Enforce keyword-only arguments
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
actor_id: Optional[str] = None,  # For query-time filtering
actor_id: Optional[str] = None,  # For query-time filtering
input_metadata: Optional[Dict[str, Any]] = None,
input_filters: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, Any], Dict[str, Any]]:
"""
Constructs metadata for storage and filters for querying based on session and actor identifiers.


This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:


@@ -78,10 +78,10 @@ def _build_filters_and_metadata(
- effective_query_filters (Dict[str, Any]): Filters for querying memories,
scoped to the determined session and potentially a resolved actor.
"""


base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
effective_query_filters = deepcopy(input_filters) if input_filters else {}


# ---------- resolve session id (mandatory) ----------
session_key, session_val = None, None
if user_id:
@@ -90,20 +90,20 @@
session_key, session_val = "agent_id", agent_id
elif run_id:
session_key, session_val = "run_id", run_id


if session_key is None:
raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")


base_metadata_template[session_key] = session_val
effective_query_filters[session_key] = session_val


# ---------- optional actor filter ----------
resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
if resolved_actor_id:
effective_query_filters["actor_id"] = resolved_actor_id


return base_metadata_template, effective_query_filters




setup_config()
logger = logging.getLogger(__name__)
@@ -189,7 +189,7 @@ class Memory(MemoryBase):
):
"""
Create a new memory.


Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.

Args:
@@ -208,7 +208,7 @@ class Memory(MemoryBase):
creating procedural memories (typically requires 'agent_id'). Otherwise, memories
are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.



Returns:
dict: A dictionary containing the result of the memory addition operation, typically
@@ -216,14 +216,14 @@ class Memory(MemoryBase):
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
"""


processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata,
)


if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -231,10 +231,10 @@

if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]


elif isinstance(messages, dict):
messages = [messages]


elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")

@@ -255,7 +255,7 @@

vector_store_result = future1.result()
graph_result = future2.result()


if self.api_version == "v1.0":
warnings.warn(
"The current add API output format is deprecated. "
@@ -277,21 +277,21 @@
def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
for message_dict in messages:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format: {message_dict}")
continue

if message_dict["role"] == "system":
continue
continue


per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]


actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name
@@ -311,8 +311,8 @@
)
return returned_memories

parsed_messages = parse_messages(messages)

parsed_messages = parse_messages(messages)

if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}"
@@ -336,7 +336,7 @@

retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
@@ -347,7 +347,7 @@
)
for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})


unique_data = {}
for item in retrieved_old_memory:
unique_data[item["id"]] = item
@@ -389,7 +389,7 @@
if not action_text:
logging.info("Skipping memory entry because of empty `text` field.")
continue


event_type = resp.get("event")
if event_type == "ADD":
memory_id = self._create_memory(
@@ -405,16 +405,23 @@
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type, "previous_memory": resp.get("old_memory"),
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type,
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
}
)
elif event_type == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -462,11 +469,8 @@
"actor_id",
"role",
]

core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

result_item = MemoryItem(
id=memory.id,
@@ -479,18 +483,16 @@
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]

additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata


return result_item

def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -505,7 +507,7 @@
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.

@@ -515,21 +517,16 @@
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""


_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)


if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")

capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
)

with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -542,9 +539,9 @@
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)

all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None


if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}

@@ -556,26 +553,27 @@
category=DeprecationWarning,
stacklevel=2,
)
return all_memories_result
return all_memories_result
else:
return {"results": all_memories_result}

def _get_all_from_vector_store(self, filters, limit):
memories_result = self.vector_store.list(filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)

promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -587,15 +585,13 @@
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]

additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata


formatted_memories.append(memory_item_dict)


return formatted_memories

def search(
@@ -624,12 +620,9 @@
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)


if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")

@@ -651,7 +644,7 @@

original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None


if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}

@@ -678,11 +671,8 @@
"actor_id",
"role",
]

core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

original_memories = []
for mem in memories:
@@ -693,18 +683,16 @@
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
).model_dump()

for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]

additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata


original_memories.append(memory_item_dict)

return original_memories
@@ -738,7 +726,7 @@
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}

def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None):
def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
"""
Delete all memories.

@@ -860,11 +848,11 @@
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")


prev_value = existing_memory.payload.get("data")

new_metadata = deepcopy(metadata) if metadata is not None else {}


new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -875,7 +863,7 @@
if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -885,14 +873,14 @@
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data, "update")


self.vector_store.update(
vector_id=memory_id,
vector=embeddings,
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")


self.db.add_history(
memory_id,
prev_value,
@@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase):
dict: A dictionary containing the result of the memory addition operation.
"""
processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata
user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
)


if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -1050,15 +1035,17 @@

if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]


elif isinstance(messages, dict):
messages = [messages]


elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")

if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm)
results = await self._create_procedural_memory(
messages, metadata=processed_metadata, prompt=prompt, llm=llm
)
return results

if self.config.llm.config.get("enable_vision"):
@@ -1066,7 +1053,9 @@
else:
messages = parse_vision_messages(messages)

vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer))
vector_store_task = asyncio.create_task(
self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)
)
graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters))

vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
@@ -1090,8 +1079,8 @@
return {"results": vector_store_result}

async def _add_to_vector_store(
self,
messages: list,
self,
messages: list,
metadata: dict,
filters: dict,
infer: bool,
@@ -1099,9 +1088,11 @@
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format (async): {message_dict}")
continue

@@ -1110,20 +1101,24 @@

per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]


actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name


msg_content = message_dict["content"]
msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta)

returned_memories.append({
"id": mem_id, "memory": msg_content, "event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
})

returned_memories.append(
{
"id": mem_id,
"memory": msg_content,
"event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
}
)
return returned_memories

parsed_messages = parse_messages(messages)
@@ -1142,17 +1137,21 @@
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = []
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []

retrieved_old_memory = []
new_message_embeddings = {}


async def process_fact_for_search(new_mem_content):
embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add")
new_message_embeddings[new_mem_content] = embeddings
existing_mems = await asyncio.to_thread(
self.vector_store.search, query=new_mem_content, vectors=embeddings,
limit=5, filters=filters,  # 'filters' is query_filters_for_inference
self.vector_store.search,
query=new_mem_content,
vectors=embeddings,
limit=5,
filters=filters,  # 'filters' is query_filters_for_inference
)
return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]

@@ -1160,9 +1159,10 @@
search_results_list = await asyncio.gather(*search_tasks)
for result_group in search_results_list:
retrieved_old_memory.extend(result_group)


unique_data = {}
for item in retrieved_old_memory: unique_data[item["id"]] = item
for item in retrieved_old_memory:
unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
temp_uuid_mapping = {}
@@ -1180,35 +1180,45 @@
response_format={"type": "json_object"},
)
except Exception as e:
logging.error(f"Error in new memory actions response: {e}"); response = ""

logging.error(f"Error in new memory actions response: {e}")
response = ""

try:
response = remove_code_blocks(response)
new_memories_with_actions = json.loads(response)
except Exception as e:
logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {}
logging.error(f"Invalid JSON response: {e}")
new_memories_with_actions = {}

returned_memories = []
returned_memories = []
try:
memory_tasks = []
for resp in new_memories_with_actions.get("memory", []):
logging.info(resp)
try:
action_text = resp.get("text")
if not action_text: continue
if not action_text:
continue
event_type = resp.get("event")

if event_type == "ADD":
task = asyncio.create_task(self._create_memory(
data=action_text, existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._create_memory(
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "ADD", None))
elif event_type == "UPDATE":
task = asyncio.create_task(self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]], data=action_text,
existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]],
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif event_type == "DELETE":
task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
@@ -1217,31 +1227,30 @@
logging.info("NOOP for Memory (async).")
except Exception as e:
logging.error(f"Error processing memory action (async): {resp}, Error: {e}")


for task, resp, event_type, mem_id in memory_tasks:
try:
result_id = await task
if event_type == "ADD":
returned_memories.append({
"id": result_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type})
elif event_type == "UPDATE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"),
"event": event_type, "previous_memory": resp.get("old_memory")
})
returned_memories.append(
{
"id": mem_id,
"memory": resp.get("text"),
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type})
except Exception as e:
logging.error(f"Error awaiting memory task (async): {e}")
except Exception as e:
logging.error(f"Error in memory processing loop (async): {e}")


capture_event(
"mem0.add", self,
{"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
"mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
)
return returned_memories

@@ -1272,17 +1281,14 @@
return None

promoted_payload_keys = [
"user_id",
"agent_id",
"run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]

core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

result_item = MemoryItem(
id=memory.id,
@@ -1295,18 +1301,16 @@
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]

additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata


return result_item

async def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -1314,41 +1318,36 @@
limit: int = 100,
):
"""
List all memories.
List all memories.

Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.

Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""


_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)

if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)

capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
)

with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1361,9 +1360,9 @@
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)

all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None


if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}

@@ -1381,20 +1380,21 @@

async def _get_all_from_vector_store(self, filters, limit):
memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)

promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -1406,15 +1406,13 @@
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]

additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata


formatted_memories.append(memory_item_dict)


return formatted_memories

async def search(
@@ -1442,16 +1440,13 @@
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""


_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)

if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")

capture_event(
"mem0.search",
@@ -1460,22 +1455,20 @@
)

vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))


graph_task = None
if self.enable_graph:
if hasattr(self.graph.search, "__await__"):  # Check if graph search is async
graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit))
else:
graph_task = asyncio.create_task(
asyncio.to_thread(self.graph.search, query, effective_filters, limit)
)

graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit))

if graph_task:
original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else:
original_memories = await vector_store_task
graph_entities = None


if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}

@@ -1504,11 +1497,8 @@
"actor_id",
"role",
]

core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

original_memories = []
for mem in memories:
@@ -1518,19 +1508,17 @@
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
score=mem.score,
).model_dump()

for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]

additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}

additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata


original_memories.append(memory_item_dict)

return original_memories
@@ -1650,7 +1638,7 @@
capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"})
return memory_id

async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None):
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
"""
Create a procedural memory asynchronously

@@ -1709,11 +1697,11 @@
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")


prev_value = existing_memory.payload.get("data")

new_metadata = deepcopy(metadata) if metadata is not None else {}


new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -1725,8 +1713,7 @@
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]



if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -1736,7 +1723,7 @@
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")


await asyncio.to_thread(
self.vector_store.update,
vector_id=memory_id,
@@ -1744,7 +1731,7 @@
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")


await asyncio.to_thread(
self.db.add_history,
memory_id,

@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try:
    from langchain_memgraph import Memgraph
except ImportError:
    raise ImportError(
        "langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
    )
    raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    raise ImportError(
        "rank_bm25 is not installed. Please install it using pip install rank-bm25"
    )
    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")

from mem0.graphs.tools import (
    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
            filters (dict): A dictionary containing filters to be applied during the addition.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(
            data, filters, entity_type_map
        )
        search_output = self._search_graph_db(
            node_list=list(entity_type_map.keys()), filters=filters
        )
        to_be_deleted = self._get_delete_entities_from_search_output(
            search_output, data, filters
        )
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        # TODO: Batch queries with APOC plugin
        # TODO: Add more filter support
        deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
        added_entities = self._add_entities(
            to_be_added, filters["user_id"], entity_type_map
        )
        added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

@@ -108,16 +96,13 @@ class MemoryGraph:
            - "entities": List of related graph data based on the query.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(
            node_list=list(entity_type_map.keys()), filters=filters
        )
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]]
            for item in search_output
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)
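For context on the reranking step, a short runnable sketch of rank_bm25 over triplets like those built above (corpus and query are illustrative):

from rank_bm25 import BM25Okapi

# Each [source, relationship, destination] triplet is one tokenized document.
corpus = [
    ["alice", "likes", "bob"],
    ["alice", "works_at", "acme"],
    ["bob", "lives_in", "berlin"],
]
bm25 = BM25Okapi(corpus)
# get_top_n scores every document against the query tokens and returns the n best.
reranked_results = bm25.get_top_n(["alice", "bob"], corpus, n=2)
print(reranked_results)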
@@ -126,9 +111,7 @@ class MemoryGraph:

        search_results = []
        for item in reranked_results:
            search_results.append(
                {"source": item[0], "relationship": item[1], "destination": item[2]}
            )
            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})

        logger.info(f"Returned {len(search_results)} search results")

@@ -161,9 +144,7 @@ class MemoryGraph:
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        LIMIT $limit
        """
        results = self.graph.query(
            query, params={"user_id": filters["user_id"], "limit": limit}
        )
        results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})

        final_results = []
        for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {
            k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
            for k, v in entity_type_map.items()
        }
        logger.debug(
            f"Entity type map: {entity_type_map}\n search_results={search_results}"
        )
        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
        return entity_type_map
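The normalization above is a plain dict comprehension; a quick runnable example:

entity_type_map = {"Alice Smith": "Person Name", "ACME Corp": "Organization"}
normalized = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
assert normalized == {"alice_smith": "person_name", "acme_corp": "organization"}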
    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace(
                        "USER_ID", filters["user_id"]
                    ).replace(
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
                        "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
                    ),
                },
@@ -235,9 +209,7 @@ class MemoryGraph:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace(
                        "USER_ID", filters["user_id"]
                    ),
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
                },
                {
                    "role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """Get the entities to be deleted from the search output."""
        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(
            search_output_string, data, filters["user_id"]
        )
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
        # search for the nodes with the closest embeddings; this is basically
        # comparison of one embedding to all embeddings in a graph -> vector
        # search with cosine similarity metric
        source_node_search_result = self._search_source_node(
            source_embedding, user_id, threshold=0.9
        )
        destination_node_search_result = self._search_destination_node(
            dest_embedding, user_id, threshold=0.9
        )
        source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
        destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)

        # TODO: Create a cypher query and common params for all the cases
        if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
            """

            params = {
                "destination_id": destination_node_search_result[0][
                    "id(destination_candidate)"
                ],
                "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                "source_name": source,
                "source_embedding": source_embedding,
                "user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
            """
            params = {
                "source_id": source_node_search_result[0]["id(source_candidate)"],
                "destination_id": destination_node_search_result[0][
                    "id(destination_candidate)"
                ],
                "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                "user_id": user_id,
            }
        else:
@@ -1,8 +1,8 @@
import logging
import sqlite3
import threading
import uuid
import logging
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

@@ -23,9 +23,7 @@ class SQLiteManager:
        """
        with self._lock, self.connection:
            cur = self.connection.cursor()
            cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
            )
            cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
            if cur.fetchone() is None:
                return  # nothing to migrate

@@ -51,13 +49,11 @@ class SQLiteManager:
            logger.info("Migrating history table to new schema (no convo columns).")
            cur.execute("ALTER TABLE history RENAME TO history_old")

            self._create_history_table()
            self._create_history_table()

            intersecting = list(expected_cols & old_cols)
            cols_csv = ", ".join(intersecting)
            cur.execute(
                f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old"
            )
            cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
            cur.execute("DROP TABLE history_old")

    def _create_history_table(self) -> None:
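A self-contained sketch of the rename -> recreate -> copy-intersection -> drop migration pattern used above, against an in-memory database with illustrative columns:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE history (id TEXT, memory_id TEXT, convo TEXT)")  # old schema
cur.execute("INSERT INTO history VALUES ('1', 'm1', 'legacy')")
cur.execute("ALTER TABLE history RENAME TO history_old")
cur.execute("CREATE TABLE history (id TEXT, memory_id TEXT)")  # new schema, no convo column
old_cols = {row[1] for row in cur.execute("PRAGMA table_info(history_old)")}
expected_cols = {row[1] for row in cur.execute("PRAGMA table_info(history)")}
cols_csv = ", ".join(expected_cols & old_cols)
cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
cur.execute("DROP TABLE history_old")
conn.commit()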
@@ -9,8 +9,8 @@ import mem0
from mem0.memory.setup import get_or_create_user_id

MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST="https://us.i.posthog.com"
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"

if isinstance(MEM0_TELEMETRY, str):
    MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
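Per the parsing above, telemetry stays on only for a case-insensitive "true", "1", or "yes"; any other value opts out. A quick sketch:

import os

os.environ["MEM0_TELEMETRY"] = "false"  # e.g. exported in the shell before launch
enabled = os.environ.get("MEM0_TELEMETRY", "True").lower() in ("true", "1", "yes")
assert enabled is False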
@@ -98,9 +98,8 @@ class VectorStoreFactory:
            return vector_store_instance(**config)
        else:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")

    @classmethod
    def reset(cls, instance):
        instance.reset()
        return instance
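A hedged usage sketch of the reset helper above; VectorStoreFactory.create, its import path, and the provider/config shown are assumptions based on typical mem0 configuration, and a running backend is still required:

from mem0.utils.factory import VectorStoreFactory  # assumed import path

store = VectorStoreFactory.create("chroma", {"collection_name": "mem0"})  # assumed provider/config
store = VectorStoreFactory.reset(store)  # drops and recreates the collection, returns the instance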
@@ -377,4 +377,3 @@ class AzureAISearch(VectorStoreBase):
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise

@@ -51,7 +51,7 @@ class VectorStoreBase(ABC):
    def list(self, filters=None, limit=None):
        """List all memories."""
        pass

    @abstractmethod
    def reset(self):
        """Reset by deleting the collection and recreating it."""

@@ -221,7 +221,7 @@ class ChromaDB(VectorStoreBase):
        """
        results = self.collection.get(where=filters, limit=limit)
        return [self._parse_output(results)]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -58,7 +58,12 @@ class ElasticsearchDB(VectorStoreBase):
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    "vector": {"type": "dense_vector", "dims": self.embedding_model_dims, "index": True, "similarity": "cosine"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": self.embedding_model_dims,
                        "index": True,
                        "similarity": "cosine",
                    },
                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                }
            },
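For reference, a sketch of creating an index with a dense_vector mapping like the one above, using the official Elasticsearch Python client (host, index name, and dims are illustrative):

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
es.indices.create(
    index="mem0",
    mappings={
        "properties": {
            "text": {"type": "text"},
            "vector": {"type": "dense_vector", "dims": 1536, "index": True, "similarity": "cosine"},
            "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
        }
    },
)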
@@ -222,7 +227,7 @@ class ElasticsearchDB(VectorStoreBase):
        )

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -465,7 +465,7 @@ class FAISS(VectorStoreBase):
                break

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
@@ -162,10 +163,7 @@ class Langchain(VectorStoreBase):
        if filters and "user_id" in filters:
            where_clause = {"user_id": filters["user_id"]}

        result = self.client._collection.get(
            where=where_clause,
            limit=limit
        )
        result = self.client._collection.get(where=where_clause, limit=limit)

        # Convert the result to the expected format
        if result and isinstance(result, dict):

@@ -237,7 +237,7 @@ class MilvusDB(VectorStoreBase):
            obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
            memories.append(obj)
        return [memories]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -1,6 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
import time
from typing import Any, Dict, List, Optional

try:
    from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase):
            use_ssl=config.use_ssl,
            verify_certs=config.verify_certs,
            connection_class=RequestsHttpConnection,
            pool_maxsize=20
            pool_maxsize=20,
        )

        self.collection_name = config.collection_name
@@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase):
    def create_col(self, name: str, vector_size: int) -> None:
        """Create a new collection (index in OpenSearch)."""
        index_settings = {
            "settings": {
                "index.knn": True
            },
            "settings": {"index.knn": True},
            "mappings": {
                "properties": {
                    "vector_field": {
@@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase):
                    "payload": {"type": "object"},
                    "id": {"type": "keyword"},
                }
            }
            },
        }

        if not self.client.indices.exists(index=name):
@@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase):
            except Exception:
                retry_count += 1
                if retry_count == max_retries:
                    raise TimeoutError(
                        f"Index {name} creation timed out after {max_retries} seconds"
                    )
                    raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
                time.sleep(0.5)

    def insert(
@@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase):
        }

        # Start building the full query
        query_body = {
            "size": limit * 2,
            "query": None
        }
        query_body = {"size": limit * 2, "query": None}

        # Prepare filter conditions if applicable
        filter_clauses = []
@@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase):
            for key in ["user_id", "run_id", "agent_id"]:
                value = filters.get(key)
                if value:
                    filter_clauses.append({
                        "term": {f"payload.{key}.keyword": value}
                    })
                    filter_clauses.append({"term": {f"payload.{key}.keyword": value}})

        # Combine knn with filters if needed
        if filter_clauses:
            query_body["query"] = {
                "bool": {
                    "must": knn_query,
                    "filter": filter_clauses
                }
            }
            query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
        else:
            query_body["query"] = knn_query
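A runnable sketch of the query-body composition above, with an illustrative knn clause and filter values:

# Illustrative knn clause; in OpenSearch this carries the query vector.
knn_query = {"knn": {"vector_field": {"vector": [0.1, 0.2, 0.3], "k": 5}}}
filters = {"user_id": "u1"}
limit = 5

query_body = {"size": limit * 2, "query": None}
filter_clauses = []
for key in ["user_id", "run_id", "agent_id"]:
    value = filters.get(key)
    if value:
        filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}} if filter_clauses else knn_query
print(query_body)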
@@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase):

        hits = response["hits"]["hits"]
        results = [
            OutputData(
                id=hit["_source"].get("id"),
                score=hit["_score"],
                payload=hit["_source"].get("payload", {})
            )
            OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
            for hit in hits
        ]
        return results
@@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase):
    def delete(self, vector_id: str) -> None:
        """Delete a vector by custom ID."""
        # First, find the document by custom ID
        search_query = {
            "query": {
                "term": {
                    "id": vector_id
                }
            }
        }
        search_query = {"query": {"term": {"id": vector_id}}}

        response = self.client.search(index=self.collection_name, body=search_query)
        hits = response.get("hits", {}).get("hits", [])
@@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase):
        # Delete using the actual document ID
        self.client.delete(index=self.collection_name, id=opensearch_id)
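A sketch of the same lookup-then-delete round trip against a local cluster; the host and index name are assumptions:

from opensearchpy import OpenSearch

client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])
# The custom "id" field is mapped as a keyword, so an exact term match finds it.
resp = client.search(index="mem0", body={"query": {"term": {"id": "vec-123"}}})
hits = resp.get("hits", {}).get("hits", [])
if hits:
    # Delete by the matched document's internal _id, as the delete() method above does.
    client.delete(index="mem0", id=hits[0]["_id"])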
    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """Update a vector and its payload using the custom 'id' field."""

        # First, find the document by custom ID
        search_query = {
            "query": {
                "term": {
                    "id": vector_id
                }
            }
        }
        search_query = {"query": {"term": {"id": vector_id}}}

        response = self.client.search(index=self.collection_name, body=search_query)
        hits = response.get("hits", {}).get("hits", [])
@@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase):
            except Exception:
                pass

    def get(self, vector_id: str) -> Optional[OutputData]:
        """Retrieve a vector by ID."""
        try:
@@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase):
                self.create_col(self.collection_name, self.embedding_model_dims)
                return None

            search_query = {
                "query": {
                    "term": {
                        "id": vector_id
                    }
                }
            }
            search_query = {"query": {"term": {"id": vector_id}}}
            response = self.client.search(index=self.collection_name, body=search_query)

            hits = response["hits"]["hits"]
@@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase):
            if not hits:
                return None

            return OutputData(
                id=hits[0]["_source"].get("id"),
                score=1.0,
                payload=hits[0]["_source"].get("payload", {})
            )
            return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
            return None
@@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase):
        return self.client.indices.get(index=name)

    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:

        try:
            """List all memories with optional filters."""
            query: Dict = {
                "query": {
                    "match_all": {}
                }
            }
            query: Dict = {"query": {"match_all": {}}}

            filter_clauses = []
            if filters:
                for key in ["user_id", "run_id", "agent_id"]:
                    value = filters.get(key)
                    if value:
                        filter_clauses.append({
                            "term": {f"payload.{key}.keyword": value}
                        })
                        filter_clauses.append({"term": {f"payload.{key}.keyword": value}})

            if filter_clauses:
                query["query"] = {
                    "bool": {
                        "filter": filter_clauses
                    }
                }
                query["query"] = {"bool": {"filter": filter_clauses}}

            if limit:
                query["size"] = limit
@@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase):
            response = self.client.search(index=self.collection_name, body=query)
            hits = response["hits"]["hits"]

            return [[
                OutputData(
                    id=hit["_source"].get("id"),
                    score=1.0,
                    payload=hit["_source"].get("payload", {})
                )
                for hit in hits
            ]]
            return [
                [
                    OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
                    for hit in hits
                ]
            ]
        except Exception:
            return []

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
@@ -286,7 +286,7 @@ class PGVector(VectorStoreBase):
            self.cur.close()
        if hasattr(self, "conn"):
            self.conn.close()

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -232,7 +232,7 @@ class Qdrant(VectorStoreBase):
            with_vectors=False,
        )
        return result

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -88,7 +88,7 @@ class RedisDB(VectorStoreBase):
            The created index object.
        """
        # Use provided parameters or fall back to instance attributes
        collection_name = name or self.schema['index']['name']
        collection_name = name or self.schema["index"]["name"]
        embedding_dims = vector_size or self.embedding_model_dims
        distance_metric = distance or "cosine"

@@ -237,17 +237,16 @@ class RedisDB(VectorStoreBase):
        """
        Reset the index by deleting and recreating it.
        """
        collection_name = self.schema['index']['name']
        collection_name = self.schema["index"]["name"]
        logger.warning(f"Resetting index {collection_name}...")
        self.delete_col()

        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

        #or use
        #self.create_col(collection_name, self.embedding_model_dims)

        # or use
        # self.create_col(collection_name, self.embedding_model_dims)

        # Recreate the index with the same parameters
        self.create_col(collection_name, self.embedding_model_dims)

@@ -229,7 +229,7 @@ class Supabase(VectorStoreBase):
        records = self.collection.fetch(ids=ids)

        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -285,10 +285,9 @@ class UpstashVector(VectorStoreBase):
            - Per-namespace vector and pending vector counts
        """
        return self.client.info()

    def reset(self):
        """
        Reset the Upstash Vector index.
        """
        self.delete_col()

@@ -308,7 +308,7 @@ class Weaviate(VectorStoreBase):
            payload["id"] = str(obj.uuid).split("'")[0]
            results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")