diff --git a/docs/features/langchain-tools.mdx b/docs/features/langchain-tools.mdx index c5ec2469..e1194549 100644 --- a/docs/features/langchain-tools.mdx +++ b/docs/features/langchain-tools.mdx @@ -27,9 +27,9 @@ from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional client = MemoryClient( - "---", - org_id="---", - project_id="---" + api_key=your_api_key, + org_id=your_org_id, + project_id=your_project_id ) ``` diff --git a/mem0/client/main.py b/mem0/client/main.py index ec52cbe1..29349b1d 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -362,11 +362,7 @@ class MemoryClient: Raises: APIError: If the API request fails. """ - response = self.client.request( - "DELETE", - "/v1/batch/", - json={"memories": memories} - ) + response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories}) response.raise_for_status() capture_client_event("client.batch_delete", self) @@ -383,15 +379,12 @@ class MemoryClient: Returns: Dict containing export request ID and status message """ - response = self.client.post( - "/v1/exports/", - json={"schema": schema, **self._prepare_params(kwargs)} - ) + response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)}) response.raise_for_status() capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())}) return response.json() - @api_error_handler + @api_error_handler def get_memory_export(self, **kwargs) -> Dict[str, Any]: """Get a memory export. @@ -401,10 +394,7 @@ class MemoryClient: Returns: Dict containing the exported data """ - response = self.client.get( - "/v1/exports/", - params=self._prepare_params(kwargs) - ) + response = self.client.get("/v1/exports/", params=self._prepare_params(kwargs)) response.raise_for_status() capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())}) return response.json() @@ -456,7 +446,7 @@ class MemoryClient: has_new = bool(self.org_id or self.project_id) has_old = bool(self.organization or self.project) - + if has_new and has_old: raise ValueError( "Please use either org_id/project_id or org_name/project_name, not both. " @@ -480,7 +470,7 @@ class MemoryClient: class AsyncMemoryClient: """Asynchronous client for interacting with the Mem0 API. - + This class provides asynchronous versions of all MemoryClient methods. It uses httpx.AsyncClient for making non-blocking API requests. @@ -498,14 +488,7 @@ class AsyncMemoryClient: org_id: Optional[str] = None, project_id: Optional[str] = None, ): - self.sync_client = MemoryClient( - api_key, - host, - organization, - project, - org_id, - project_id - ) + self.sync_client = MemoryClient(api_key, host, organization, project, org_id, project_id) self.async_client = httpx.AsyncClient( base_url=self.sync_client.host, headers=self.sync_client.client.headers, diff --git a/mem0/embeddings/together.py b/mem0/embeddings/together.py index dd16bd99..76b9124b 100644 --- a/mem0/embeddings/together.py +++ b/mem0/embeddings/together.py @@ -16,7 +16,7 @@ class TogetherEmbedding(EmbeddingBase): # TODO: check if this is correct self.config.embedding_dims = self.config.embedding_dims or 768 self.client = Together(api_key=api_key) - + def embed(self, text): """ Get the embedding for the given text using OpenAI. @@ -28,4 +28,4 @@ class TogetherEmbedding(EmbeddingBase): list: The embedding vector. """ - return self.client.embeddings.create(model=self.config.model, input=text).data[0].embedding \ No newline at end of file + return self.client.embeddings.create(model=self.config.model, input=text).data[0].embedding diff --git a/mem0/graphs/tools.py b/mem0/graphs/tools.py index bd02c35f..95bb32ad 100644 --- a/mem0/graphs/tools.py +++ b/mem0/graphs/tools.py @@ -95,20 +95,17 @@ RELATIONS_TOOL = { "items": { "type": "object", "properties": { - "source": { - "type": "string", - "description": "The source entity of the relationship." - }, + "source": {"type": "string", "description": "The source entity of the relationship."}, "relationship": { "type": "string", - "description": "The relationship between the source and destination entities." + "description": "The relationship between the source and destination entities.", }, "destination": { "type": "string", - "description": "The destination entity of the relationship." + "description": "The destination entity of the relationship.", }, }, - "required": [ + "required": [ "source", "relationship", "destination", @@ -137,25 +134,19 @@ EXTRACT_ENTITIES_TOOL = { "items": { "type": "object", "properties": { - "entity": { - "type": "string", - "description": "The name or identifier of the entity." - }, - "entity_type": { - "type": "string", - "description": "The type or category of the entity." - } + "entity": {"type": "string", "description": "The name or identifier of the entity."}, + "entity_type": {"type": "string", "description": "The type or category of the entity."}, }, "required": ["entity", "entity_type"], - "additionalProperties": False + "additionalProperties": False, }, - "description": "An array of entities with their types." + "description": "An array of entities with their types.", } }, "required": ["entities"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } UPDATE_MEMORY_STRUCT_TOOL_GRAPH = { @@ -260,18 +251,18 @@ RELATIONS_STRUCT_TOOL = { "properties": { "source_entity": { "type": "string", - "description": "The source entity of the relationship." + "description": "The source entity of the relationship.", }, "relatationship": { "type": "string", - "description": "The relationship between the source and destination entities." + "description": "The relationship between the source and destination entities.", }, "destination_entity": { "type": "string", - "description": "The destination entity of the relationship." + "description": "The destination entity of the relationship.", }, }, - "required": [ + "required": [ "source_entity", "relatationship", "destination_entity", @@ -301,25 +292,19 @@ EXTRACT_ENTITIES_STRUCT_TOOL = { "items": { "type": "object", "properties": { - "entity": { - "type": "string", - "description": "The name or identifier of the entity." - }, - "entity_type": { - "type": "string", - "description": "The type or category of the entity." - } + "entity": {"type": "string", "description": "The name or identifier of the entity."}, + "entity_type": {"type": "string", "description": "The type or category of the entity."}, }, "required": ["entity", "entity_type"], - "additionalProperties": False + "additionalProperties": False, }, - "description": "An array of entities with their types." + "description": "An array of entities with their types.", } }, "required": ["entities"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } DELETE_MEMORY_STRUCT_TOOL_GRAPH = { @@ -342,7 +327,7 @@ DELETE_MEMORY_STRUCT_TOOL_GRAPH = { "destination": { "type": "string", "description": "The identifier of the destination node in the relationship.", - } + }, }, "required": [ "source", @@ -373,7 +358,7 @@ DELETE_MEMORY_TOOL_GRAPH = { "destination": { "type": "string", "description": "The identifier of the destination node in the relationship.", - } + }, }, "required": [ "source", @@ -383,4 +368,4 @@ DELETE_MEMORY_TOOL_GRAPH = { "additionalProperties": False, }, }, -} \ No newline at end of file +} diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py index c189bc62..de138477 100644 --- a/mem0/graphs/utils.py +++ b/mem0/graphs/utils.py @@ -90,5 +90,8 @@ source -- relationship -- destination Provide a list of deletion instructions, each specifying the relationship to be deleted. """ + def get_delete_messages(existing_memories_string, data, user_id): - return DELETE_RELATIONS_SYSTEM_PROMPT.replace("USER_ID", user_id), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}" + return DELETE_RELATIONS_SYSTEM_PROMPT.replace( + "USER_ID", user_id + ), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}" diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index 9f6c1b09..8d8bb01d 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -1,4 +1,3 @@ - import json from typing import Any, Dict, List, Optional @@ -11,12 +10,12 @@ from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase -class AWSBedrockLLM(LLMBase): +class AWSBedrockLLM(LLMBase): def __init__(self, config: Optional[BaseLlmConfig] = None): super().__init__(config) if not self.config.model: - self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0" + self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0" self.client = boto3.client("bedrock-runtime") self.model_kwargs = { "temperature": self.config.temperature, diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py index 8091e995..7881cf05 100644 --- a/mem0/llms/gemini.py +++ b/mem0/llms/gemini.py @@ -1,5 +1,4 @@ import os -import json from typing import Dict, List, Optional try: @@ -39,11 +38,7 @@ class GeminiLLM(LLMBase): """ if tools: processed_response = { - "content": ( - content - if (content := response.candidates[0].content.parts[0].text) - else None - ), + "content": (content if (content := response.candidates[0].content.parts[0].text) else None), "tool_calls": [], } @@ -51,13 +46,9 @@ class GeminiLLM(LLMBase): if fn := part.function_call: if isinstance(fn, protos.FunctionCall): fn_call = type(fn).to_dict(fn) - processed_response["tool_calls"].append( - {"name": fn_call["name"], "arguments": fn_call["args"]} - ) + processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]}) continue - processed_response["tool_calls"].append( - {"name": fn.name, "arguments": fn.args} - ) + processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args}) return processed_response else: @@ -77,9 +68,7 @@ class GeminiLLM(LLMBase): for message in messages: if message["role"] == "system": - content = ( - "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] - ) + content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] else: content = message["content"] @@ -121,9 +110,7 @@ class GeminiLLM(LLMBase): if tools: for tool in tools: func = tool["function"].copy() - new_tools.append( - {"function_declarations": [remove_additional_properties(func)]} - ) + new_tools.append({"function_declarations": [remove_additional_properties(func)]}) # TODO: temporarily ignore it to pass tests, will come back to update according to standards later. # return content_types.to_function_library(new_tools) @@ -168,9 +155,7 @@ class GeminiLLM(LLMBase): "function_calling_config": { "mode": tool_choice, "allowed_function_names": ( - [tool["function"]["name"] for tool in tools] - if tool_choice == "any" - else None + [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None ), } } diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 3074d23d..c1cf6fa5 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -58,12 +58,12 @@ class MemoryGraph: to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map) search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters) to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters) - - #TODO: Batch queries with APOC plugin - #TODO: Add more filter support + + # TODO: Batch queries with APOC plugin + # TODO: Add more filter support deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"]) added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map) - + return {"deleted_entities": deleted_entities, "added_entities": added_entities} def search(self, query, filters, limit=100): @@ -86,7 +86,9 @@ class MemoryGraph: if not search_output: return [] - search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output] + search_outputs_sequence = [ + [item["source"], item["relatationship"], item["destination"]] for item in search_output + ] bm25 = BM25Okapi(search_outputs_sequence) tokenized_query = query.split(" ") @@ -142,7 +144,7 @@ class MemoryGraph: logger.info(f"Retrieved {len(final_results)} relationships") return final_results - + def _retrieve_nodes_from_data(self, data, filters): """Extracts all the entities mentioned in the query.""" _tools = [EXTRACT_ENTITIES_TOOL] @@ -170,7 +172,7 @@ class MemoryGraph: entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()} logger.debug(f"Entity type map: {entity_type_map}") return entity_type_map - + def _establish_nodes_relations_from_data(self, data, filters, entity_type_map): """Eshtablish relations among the extracted nodes.""" if self.config.graph_store.custom_prompt: @@ -209,7 +211,7 @@ class MemoryGraph: extracted_entities = self._remove_spaces_from_entities(extracted_entities) logger.debug(f"Extracted entities: {extracted_entities}") return extracted_entities - + def _search_graph_db(self, node_list, filters, limit=100): """Search similar nodes among and their respective incoming and outgoing relations.""" result_relations = [] @@ -250,7 +252,7 @@ class MemoryGraph: result_relations.extend(ans) return result_relations - + def _get_delete_entities_from_search_output(self, search_output, data, filters): """Get the entities to be deleted from the search output.""" search_output_string = format_entities(search_output) @@ -273,11 +275,11 @@ class MemoryGraph: for item in memory_updates["tool_calls"]: if item["name"] == "delete_graph_memory": to_be_deleted.append(item["arguments"]) - #in case if it is not in the correct format + # in case if it is not in the correct format to_be_deleted = self._remove_spaces_from_entities(to_be_deleted) logger.debug(f"Deleted relationships: {to_be_deleted}") return to_be_deleted - + def _delete_entities(self, to_be_deleted, user_id): """Delete the entities from the graph.""" results = [] @@ -285,7 +287,7 @@ class MemoryGraph: source = item["source"] destination = item["destination"] relatationship = item["relationship"] - + # Delete the specific relationship between nodes cypher = f""" MATCH (n {{name: $source_name, user_id: $user_id}}) @@ -305,29 +307,29 @@ class MemoryGraph: result = self.graph.query(cypher, params=params) results.append(result) return results - + def _add_entities(self, to_be_added, user_id, entity_type_map): """Add the new entities to the graph. Merge the nodes if they already exist.""" results = [] for item in to_be_added: - #entities + # entities source = item["source"] destination = item["destination"] relationship = item["relationship"] - #types + # types source_type = entity_type_map.get(source, "unknown") destination_type = entity_type_map.get(destination, "unknown") - - #embeddings + + # embeddings source_embedding = self.embedding_model.embed(source) dest_embedding = self.embedding_model.embed(destination) - - #search for the nodes with the closest embeddings + + # search for the nodes with the closest embeddings source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9) destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9) - #TODO: Create a cypher query and common params for all the cases + # TODO: Create a cypher query and common params for all the cases if not destination_node_search_result and source_node_search_result: cypher = f""" MATCH (source) @@ -343,7 +345,7 @@ class MemoryGraph: """ params = { - "source_id": source_node_search_result[0]['elementId(source_candidate)'], + "source_id": source_node_search_result[0]["elementId(source_candidate)"], "destination_name": destination, "relationship": relationship, "destination_type": destination_type, @@ -366,9 +368,9 @@ class MemoryGraph: r.created = timestamp() RETURN source.name AS source, type(r) AS relationship, destination.name AS target """ - + params = { - "destination_id": destination_node_search_result[0]['elementId(destination_candidate)'], + "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"], "source_name": source, "relationship": relationship, "source_type": source_type, @@ -377,7 +379,7 @@ class MemoryGraph: } resp = self.graph.query(cypher, params=params) results.append(resp) - + elif source_node_search_result and destination_node_search_result: cypher = f""" MATCH (source) @@ -391,8 +393,8 @@ class MemoryGraph: RETURN source.name AS source, type(r) AS relationship, destination.name AS target """ params = { - "source_id": source_node_search_result[0]['elementId(source_candidate)'], - "destination_id": destination_node_search_result[0]['elementId(destination_candidate)'], + "source_id": source_node_search_result[0]["elementId(source_candidate)"], + "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"], "user_id": user_id, "relationship": relationship, } @@ -432,7 +434,7 @@ class MemoryGraph: return entity_list def _search_source_node(self, source_embedding, user_id, threshold=0.9): - cypher = f""" + cypher = """ MATCH (source_candidate) WHERE source_candidate.embedding IS NOT NULL AND source_candidate.user_id = $user_id @@ -454,7 +456,7 @@ class MemoryGraph: RETURN elementId(source_candidate) """ - + params = { "source_embedding": source_embedding, "user_id": user_id, @@ -465,7 +467,7 @@ class MemoryGraph: return result def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): - cypher = f""" + cypher = """ MATCH (destination_candidate) WHERE destination_candidate.embedding IS NOT NULL AND destination_candidate.user_id = $user_id @@ -494,4 +496,4 @@ class MemoryGraph: } result = self.graph.query(cypher, params=params) - return result \ No newline at end of file + return result diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 59baf93d..2e78d06a 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -249,7 +249,7 @@ class Memory(MemoryBase): if self.api_version == "v1.1" and self.enable_graph: if filters.get("user_id") is None: filters["user_id"] = "user" - + data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) added_entities = self.graph.add(data, filters) diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py index 18e943be..290660db 100644 --- a/mem0/memory/utils.py +++ b/mem0/memory/utils.py @@ -1,5 +1,4 @@ import re -import json from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT @@ -19,6 +18,7 @@ def parse_messages(messages): response += f"assistant: {msg['content']}\n" return response + def format_entities(entities): if not entities: return ""