Code formatting and doc update (#2130)
This commit is contained in:
@@ -27,9 +27,9 @@ from pydantic import BaseModel, Field
|
|||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
client = MemoryClient(
|
client = MemoryClient(
|
||||||
"---",
|
api_key=your_api_key,
|
||||||
org_id="---",
|
org_id=your_org_id,
|
||||||
project_id="---"
|
project_id=your_project_id
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -362,11 +362,7 @@ class MemoryClient:
|
|||||||
Raises:
|
Raises:
|
||||||
APIError: If the API request fails.
|
APIError: If the API request fails.
|
||||||
"""
|
"""
|
||||||
response = self.client.request(
|
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||||
"DELETE",
|
|
||||||
"/v1/batch/",
|
|
||||||
json={"memories": memories}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event("client.batch_delete", self)
|
capture_client_event("client.batch_delete", self)
|
||||||
@@ -383,15 +379,12 @@ class MemoryClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing export request ID and status message
|
Dict containing export request ID and status message
|
||||||
"""
|
"""
|
||||||
response = self.client.post(
|
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
|
||||||
"/v1/exports/",
|
|
||||||
json={"schema": schema, **self._prepare_params(kwargs)}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())})
|
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())})
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
def get_memory_export(self, **kwargs) -> Dict[str, Any]:
|
def get_memory_export(self, **kwargs) -> Dict[str, Any]:
|
||||||
"""Get a memory export.
|
"""Get a memory export.
|
||||||
|
|
||||||
@@ -401,10 +394,7 @@ class MemoryClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing the exported data
|
Dict containing the exported data
|
||||||
"""
|
"""
|
||||||
response = self.client.get(
|
response = self.client.get("/v1/exports/", params=self._prepare_params(kwargs))
|
||||||
"/v1/exports/",
|
|
||||||
params=self._prepare_params(kwargs)
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())})
|
capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())})
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -456,7 +446,7 @@ class MemoryClient:
|
|||||||
|
|
||||||
has_new = bool(self.org_id or self.project_id)
|
has_new = bool(self.org_id or self.project_id)
|
||||||
has_old = bool(self.organization or self.project)
|
has_old = bool(self.organization or self.project)
|
||||||
|
|
||||||
if has_new and has_old:
|
if has_new and has_old:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please use either org_id/project_id or org_name/project_name, not both. "
|
"Please use either org_id/project_id or org_name/project_name, not both. "
|
||||||
@@ -480,7 +470,7 @@ class MemoryClient:
|
|||||||
|
|
||||||
class AsyncMemoryClient:
|
class AsyncMemoryClient:
|
||||||
"""Asynchronous client for interacting with the Mem0 API.
|
"""Asynchronous client for interacting with the Mem0 API.
|
||||||
|
|
||||||
This class provides asynchronous versions of all MemoryClient methods.
|
This class provides asynchronous versions of all MemoryClient methods.
|
||||||
It uses httpx.AsyncClient for making non-blocking API requests.
|
It uses httpx.AsyncClient for making non-blocking API requests.
|
||||||
|
|
||||||
@@ -498,14 +488,7 @@ class AsyncMemoryClient:
|
|||||||
org_id: Optional[str] = None,
|
org_id: Optional[str] = None,
|
||||||
project_id: Optional[str] = None,
|
project_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.sync_client = MemoryClient(
|
self.sync_client = MemoryClient(api_key, host, organization, project, org_id, project_id)
|
||||||
api_key,
|
|
||||||
host,
|
|
||||||
organization,
|
|
||||||
project,
|
|
||||||
org_id,
|
|
||||||
project_id
|
|
||||||
)
|
|
||||||
self.async_client = httpx.AsyncClient(
|
self.async_client = httpx.AsyncClient(
|
||||||
base_url=self.sync_client.host,
|
base_url=self.sync_client.host,
|
||||||
headers=self.sync_client.client.headers,
|
headers=self.sync_client.client.headers,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class TogetherEmbedding(EmbeddingBase):
|
|||||||
# TODO: check if this is correct
|
# TODO: check if this is correct
|
||||||
self.config.embedding_dims = self.config.embedding_dims or 768
|
self.config.embedding_dims = self.config.embedding_dims or 768
|
||||||
self.client = Together(api_key=api_key)
|
self.client = Together(api_key=api_key)
|
||||||
|
|
||||||
def embed(self, text):
|
def embed(self, text):
|
||||||
"""
|
"""
|
||||||
Get the embedding for the given text using OpenAI.
|
Get the embedding for the given text using OpenAI.
|
||||||
@@ -28,4 +28,4 @@ class TogetherEmbedding(EmbeddingBase):
|
|||||||
list: The embedding vector.
|
list: The embedding vector.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.client.embeddings.create(model=self.config.model, input=text).data[0].embedding
|
return self.client.embeddings.create(model=self.config.model, input=text).data[0].embedding
|
||||||
|
|||||||
@@ -95,20 +95,17 @@ RELATIONS_TOOL = {
|
|||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"source": {
|
"source": {"type": "string", "description": "The source entity of the relationship."},
|
||||||
"type": "string",
|
|
||||||
"description": "The source entity of the relationship."
|
|
||||||
},
|
|
||||||
"relationship": {
|
"relationship": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The relationship between the source and destination entities."
|
"description": "The relationship between the source and destination entities.",
|
||||||
},
|
},
|
||||||
"destination": {
|
"destination": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The destination entity of the relationship."
|
"description": "The destination entity of the relationship.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"source",
|
"source",
|
||||||
"relationship",
|
"relationship",
|
||||||
"destination",
|
"destination",
|
||||||
@@ -137,25 +134,19 @@ EXTRACT_ENTITIES_TOOL = {
|
|||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"entity": {
|
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||||
"type": "string",
|
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||||
"description": "The name or identifier of the entity."
|
|
||||||
},
|
|
||||||
"entity_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The type or category of the entity."
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["entity", "entity_type"],
|
"required": ["entity", "entity_type"],
|
||||||
"additionalProperties": False
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
"description": "An array of entities with their types."
|
"description": "An array of entities with their types.",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["entities"],
|
"required": ["entities"],
|
||||||
"additionalProperties": False
|
"additionalProperties": False,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||||
@@ -260,18 +251,18 @@ RELATIONS_STRUCT_TOOL = {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"source_entity": {
|
"source_entity": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The source entity of the relationship."
|
"description": "The source entity of the relationship.",
|
||||||
},
|
},
|
||||||
"relatationship": {
|
"relatationship": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The relationship between the source and destination entities."
|
"description": "The relationship between the source and destination entities.",
|
||||||
},
|
},
|
||||||
"destination_entity": {
|
"destination_entity": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The destination entity of the relationship."
|
"description": "The destination entity of the relationship.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"source_entity",
|
"source_entity",
|
||||||
"relatationship",
|
"relatationship",
|
||||||
"destination_entity",
|
"destination_entity",
|
||||||
@@ -301,25 +292,19 @@ EXTRACT_ENTITIES_STRUCT_TOOL = {
|
|||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"entity": {
|
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||||
"type": "string",
|
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||||
"description": "The name or identifier of the entity."
|
|
||||||
},
|
|
||||||
"entity_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The type or category of the entity."
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["entity", "entity_type"],
|
"required": ["entity", "entity_type"],
|
||||||
"additionalProperties": False
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
"description": "An array of entities with their types."
|
"description": "An array of entities with their types.",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["entities"],
|
"required": ["entities"],
|
||||||
"additionalProperties": False
|
"additionalProperties": False,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||||
@@ -342,7 +327,7 @@ DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
|||||||
"destination": {
|
"destination": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The identifier of the destination node in the relationship.",
|
"description": "The identifier of the destination node in the relationship.",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"source",
|
"source",
|
||||||
@@ -373,7 +358,7 @@ DELETE_MEMORY_TOOL_GRAPH = {
|
|||||||
"destination": {
|
"destination": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The identifier of the destination node in the relationship.",
|
"description": "The identifier of the destination node in the relationship.",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"source",
|
"source",
|
||||||
@@ -383,4 +368,4 @@ DELETE_MEMORY_TOOL_GRAPH = {
|
|||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,5 +90,8 @@ source -- relationship -- destination
|
|||||||
Provide a list of deletion instructions, each specifying the relationship to be deleted.
|
Provide a list of deletion instructions, each specifying the relationship to be deleted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_delete_messages(existing_memories_string, data, user_id):
|
def get_delete_messages(existing_memories_string, data, user_id):
|
||||||
return DELETE_RELATIONS_SYSTEM_PROMPT.replace("USER_ID", user_id), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"
|
return DELETE_RELATIONS_SYSTEM_PROMPT.replace(
|
||||||
|
"USER_ID", user_id
|
||||||
|
), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
@@ -11,12 +10,12 @@ from mem0.configs.llms.base import BaseLlmConfig
|
|||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
|
||||||
|
|
||||||
class AWSBedrockLLM(LLMBase):
|
class AWSBedrockLLM(LLMBase):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
if not self.config.model:
|
if not self.config.model:
|
||||||
self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
|
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||||
self.client = boto3.client("bedrock-runtime")
|
self.client = boto3.client("bedrock-runtime")
|
||||||
self.model_kwargs = {
|
self.model_kwargs = {
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -39,11 +38,7 @@ class GeminiLLM(LLMBase):
|
|||||||
"""
|
"""
|
||||||
if tools:
|
if tools:
|
||||||
processed_response = {
|
processed_response = {
|
||||||
"content": (
|
"content": (content if (content := response.candidates[0].content.parts[0].text) else None),
|
||||||
content
|
|
||||||
if (content := response.candidates[0].content.parts[0].text)
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
"tool_calls": [],
|
"tool_calls": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,13 +46,9 @@ class GeminiLLM(LLMBase):
|
|||||||
if fn := part.function_call:
|
if fn := part.function_call:
|
||||||
if isinstance(fn, protos.FunctionCall):
|
if isinstance(fn, protos.FunctionCall):
|
||||||
fn_call = type(fn).to_dict(fn)
|
fn_call = type(fn).to_dict(fn)
|
||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
|
||||||
{"name": fn_call["name"], "arguments": fn_call["args"]}
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
|
||||||
{"name": fn.name, "arguments": fn.args}
|
|
||||||
)
|
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
else:
|
else:
|
||||||
@@ -77,9 +68,7 @@ class GeminiLLM(LLMBase):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message["role"] == "system":
|
if message["role"] == "system":
|
||||||
content = (
|
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||||
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
@@ -121,9 +110,7 @@ class GeminiLLM(LLMBase):
|
|||||||
if tools:
|
if tools:
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
func = tool["function"].copy()
|
func = tool["function"].copy()
|
||||||
new_tools.append(
|
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
|
||||||
{"function_declarations": [remove_additional_properties(func)]}
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
|
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
|
||||||
# return content_types.to_function_library(new_tools)
|
# return content_types.to_function_library(new_tools)
|
||||||
@@ -168,9 +155,7 @@ class GeminiLLM(LLMBase):
|
|||||||
"function_calling_config": {
|
"function_calling_config": {
|
||||||
"mode": tool_choice,
|
"mode": tool_choice,
|
||||||
"allowed_function_names": (
|
"allowed_function_names": (
|
||||||
[tool["function"]["name"] for tool in tools]
|
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
|
||||||
if tool_choice == "any"
|
|
||||||
else None
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,12 +58,12 @@ class MemoryGraph:
|
|||||||
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
|
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
|
||||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||||
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
|
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
|
||||||
|
|
||||||
#TODO: Batch queries with APOC plugin
|
# TODO: Batch queries with APOC plugin
|
||||||
#TODO: Add more filter support
|
# TODO: Add more filter support
|
||||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
||||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
||||||
|
|
||||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||||
|
|
||||||
def search(self, query, filters, limit=100):
|
def search(self, query, filters, limit=100):
|
||||||
@@ -86,7 +86,9 @@ class MemoryGraph:
|
|||||||
if not search_output:
|
if not search_output:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output]
|
search_outputs_sequence = [
|
||||||
|
[item["source"], item["relatationship"], item["destination"]] for item in search_output
|
||||||
|
]
|
||||||
bm25 = BM25Okapi(search_outputs_sequence)
|
bm25 = BM25Okapi(search_outputs_sequence)
|
||||||
|
|
||||||
tokenized_query = query.split(" ")
|
tokenized_query = query.split(" ")
|
||||||
@@ -142,7 +144,7 @@ class MemoryGraph:
|
|||||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||||
|
|
||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
def _retrieve_nodes_from_data(self, data, filters):
|
def _retrieve_nodes_from_data(self, data, filters):
|
||||||
"""Extracts all the entities mentioned in the query."""
|
"""Extracts all the entities mentioned in the query."""
|
||||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||||
@@ -170,7 +172,7 @@ class MemoryGraph:
|
|||||||
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
|
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
|
||||||
logger.debug(f"Entity type map: {entity_type_map}")
|
logger.debug(f"Entity type map: {entity_type_map}")
|
||||||
return entity_type_map
|
return entity_type_map
|
||||||
|
|
||||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||||
"""Eshtablish relations among the extracted nodes."""
|
"""Eshtablish relations among the extracted nodes."""
|
||||||
if self.config.graph_store.custom_prompt:
|
if self.config.graph_store.custom_prompt:
|
||||||
@@ -209,7 +211,7 @@ class MemoryGraph:
|
|||||||
extracted_entities = self._remove_spaces_from_entities(extracted_entities)
|
extracted_entities = self._remove_spaces_from_entities(extracted_entities)
|
||||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||||
return extracted_entities
|
return extracted_entities
|
||||||
|
|
||||||
def _search_graph_db(self, node_list, filters, limit=100):
|
def _search_graph_db(self, node_list, filters, limit=100):
|
||||||
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
||||||
result_relations = []
|
result_relations = []
|
||||||
@@ -250,7 +252,7 @@ class MemoryGraph:
|
|||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
return result_relations
|
return result_relations
|
||||||
|
|
||||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||||
"""Get the entities to be deleted from the search output."""
|
"""Get the entities to be deleted from the search output."""
|
||||||
search_output_string = format_entities(search_output)
|
search_output_string = format_entities(search_output)
|
||||||
@@ -273,11 +275,11 @@ class MemoryGraph:
|
|||||||
for item in memory_updates["tool_calls"]:
|
for item in memory_updates["tool_calls"]:
|
||||||
if item["name"] == "delete_graph_memory":
|
if item["name"] == "delete_graph_memory":
|
||||||
to_be_deleted.append(item["arguments"])
|
to_be_deleted.append(item["arguments"])
|
||||||
#in case if it is not in the correct format
|
# in case if it is not in the correct format
|
||||||
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
||||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||||
return to_be_deleted
|
return to_be_deleted
|
||||||
|
|
||||||
def _delete_entities(self, to_be_deleted, user_id):
|
def _delete_entities(self, to_be_deleted, user_id):
|
||||||
"""Delete the entities from the graph."""
|
"""Delete the entities from the graph."""
|
||||||
results = []
|
results = []
|
||||||
@@ -285,7 +287,7 @@ class MemoryGraph:
|
|||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
relatationship = item["relationship"]
|
relatationship = item["relationship"]
|
||||||
|
|
||||||
# Delete the specific relationship between nodes
|
# Delete the specific relationship between nodes
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n {{name: $source_name, user_id: $user_id}})
|
MATCH (n {{name: $source_name, user_id: $user_id}})
|
||||||
@@ -305,29 +307,29 @@ class MemoryGraph:
|
|||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
||||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||||
results = []
|
results = []
|
||||||
for item in to_be_added:
|
for item in to_be_added:
|
||||||
#entities
|
# entities
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
relationship = item["relationship"]
|
relationship = item["relationship"]
|
||||||
|
|
||||||
#types
|
# types
|
||||||
source_type = entity_type_map.get(source, "unknown")
|
source_type = entity_type_map.get(source, "unknown")
|
||||||
destination_type = entity_type_map.get(destination, "unknown")
|
destination_type = entity_type_map.get(destination, "unknown")
|
||||||
|
|
||||||
#embeddings
|
# embeddings
|
||||||
source_embedding = self.embedding_model.embed(source)
|
source_embedding = self.embedding_model.embed(source)
|
||||||
dest_embedding = self.embedding_model.embed(destination)
|
dest_embedding = self.embedding_model.embed(destination)
|
||||||
|
|
||||||
#search for the nodes with the closest embeddings
|
# search for the nodes with the closest embeddings
|
||||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
||||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
||||||
|
|
||||||
#TODO: Create a cypher query and common params for all the cases
|
# TODO: Create a cypher query and common params for all the cases
|
||||||
if not destination_node_search_result and source_node_search_result:
|
if not destination_node_search_result and source_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source)
|
MATCH (source)
|
||||||
@@ -343,7 +345,7 @@ class MemoryGraph:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||||
"destination_name": destination,
|
"destination_name": destination,
|
||||||
"relationship": relationship,
|
"relationship": relationship,
|
||||||
"destination_type": destination_type,
|
"destination_type": destination_type,
|
||||||
@@ -366,9 +368,9 @@ class MemoryGraph:
|
|||||||
r.created = timestamp()
|
r.created = timestamp()
|
||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||||
"source_name": source,
|
"source_name": source,
|
||||||
"relationship": relationship,
|
"relationship": relationship,
|
||||||
"source_type": source_type,
|
"source_type": source_type,
|
||||||
@@ -377,7 +379,7 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
resp = self.graph.query(cypher, params=params)
|
resp = self.graph.query(cypher, params=params)
|
||||||
results.append(resp)
|
results.append(resp)
|
||||||
|
|
||||||
elif source_node_search_result and destination_node_search_result:
|
elif source_node_search_result and destination_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source)
|
MATCH (source)
|
||||||
@@ -391,8 +393,8 @@ class MemoryGraph:
|
|||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"relationship": relationship,
|
"relationship": relationship,
|
||||||
}
|
}
|
||||||
@@ -432,7 +434,7 @@ class MemoryGraph:
|
|||||||
return entity_list
|
return entity_list
|
||||||
|
|
||||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||||
cypher = f"""
|
cypher = """
|
||||||
MATCH (source_candidate)
|
MATCH (source_candidate)
|
||||||
WHERE source_candidate.embedding IS NOT NULL
|
WHERE source_candidate.embedding IS NOT NULL
|
||||||
AND source_candidate.user_id = $user_id
|
AND source_candidate.user_id = $user_id
|
||||||
@@ -454,7 +456,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
RETURN elementId(source_candidate)
|
RETURN elementId(source_candidate)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_embedding": source_embedding,
|
"source_embedding": source_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@@ -465,7 +467,7 @@ class MemoryGraph:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||||
cypher = f"""
|
cypher = """
|
||||||
MATCH (destination_candidate)
|
MATCH (destination_candidate)
|
||||||
WHERE destination_candidate.embedding IS NOT NULL
|
WHERE destination_candidate.embedding IS NOT NULL
|
||||||
AND destination_candidate.user_id = $user_id
|
AND destination_candidate.user_id = $user_id
|
||||||
@@ -494,4 +496,4 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ class Memory(MemoryBase):
|
|||||||
if self.api_version == "v1.1" and self.enable_graph:
|
if self.api_version == "v1.1" and self.enable_graph:
|
||||||
if filters.get("user_id") is None:
|
if filters.get("user_id") is None:
|
||||||
filters["user_id"] = "user"
|
filters["user_id"] = "user"
|
||||||
|
|
||||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||||
added_entities = self.graph.add(data, filters)
|
added_entities = self.graph.add(data, filters)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
|
||||||
|
|
||||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||||
|
|
||||||
@@ -19,6 +18,7 @@ def parse_messages(messages):
|
|||||||
response += f"assistant: {msg['content']}\n"
|
response += f"assistant: {msg['content']}\n"
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def format_entities(entities):
|
def format_entities(entities):
|
||||||
if not entities:
|
if not entities:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
Reference in New Issue
Block a user