diff --git a/docs/components/llms/models/together.mdx b/docs/components/llms/models/together.mdx
index 98707cee..2cac8d79 100644
--- a/docs/components/llms/models/together.mdx
+++ b/docs/components/llms/models/together.mdx
@@ -11,7 +11,7 @@ os.environ["TOGETHER_API_KEY"] = "your-api-key"
 
 config = {
     "llm": {
-        "provider": "togetherai",
+        "provider": "together",
         "config": {
             "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
             "temperature": 0.2,
diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py
index 48d8c0e2..a93d27d9 100644
--- a/mem0/graphs/configs.py
+++ b/mem0/graphs/configs.py
@@ -1,6 +1,6 @@
 from typing import Optional
-
 from pydantic import BaseModel, Field, field_validator, model_validator
+from mem0.llms.configs import LlmConfig
 
 class Neo4jConfig(BaseModel):
     url: Optional[str] = Field(None, description="Host address for the graph database")
@@ -30,6 +30,14 @@ class GraphStoreConfig(BaseModel):
         description="Configuration for the specific data store",
         default=None
     )
+    llm: Optional[LlmConfig] = Field(
+        description="LLM configuration for querying the graph store",
+        default=None
+    )
+    custom_prompt: Optional[str] = Field(
+        description="Custom prompt to fetch entities from the given text",
+        default=None
+    )
 
     @field_validator("config")
     def validate_config(cls, v, values):
@@ -38,3 +46,4 @@ class GraphStoreConfig(BaseModel):
             return Neo4jConfig(**v.model_dump())
         else:
             raise ValueError(f"Unsupported graph store provider: {provider}")
+
diff --git a/mem0/graphs/tools.py b/mem0/graphs/tools.py
index 01c9d801..c9748281 100644
--- a/mem0/graphs/tools.py
+++ b/mem0/graphs/tools.py
@@ -78,3 +78,66 @@ NOOP_TOOL = {
         }
     }
 }
+
+
+ADD_MESSAGE_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "add_query",
+        "description": "Add new entities and relationships to the graph based on the provided query.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "entities": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "source_node": {"type": "string"},
+                            "source_type": {"type": "string"},
+                            "relation": {"type": "string"},
+                            "destination_node": {"type": "string"},
+                            "destination_type": {"type": "string"}
+                        },
+                        "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
+                        "additionalProperties": False
+                    }
+                }
+            },
+            "required": ["entities"],
+            "additionalProperties": False
+        }
+    }
+}
+
+
+SEARCH_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "search",
+        "description": "Search for nodes and relations in the graph.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "nodes": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of nodes to search for."
+                },
+                "relations": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of relations to search for."
+                }
+            },
+            "required": ["nodes", "relations"],
+            "additionalProperties": False
+        }
+    }
+}
\ No newline at end of file
diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py
index bcd35996..4613952e 100644
--- a/mem0/graphs/utils.py
+++ b/mem0/graphs/utils.py
@@ -37,6 +37,7 @@ You are an advanced algorithm designed to extract structured information from te
 1. Extract only explicitly stated information from the text.
 2. Identify nodes (entities/concepts), their types, and relationships.
 3. Use "USER_ID" as the source node for any self-references (I, me, my, etc.) in user messages.
+CUSTOM_PROMPT
 
 Nodes and Types:
 - Aim for simplicity and clarity in node representation.
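The `GraphStoreConfig` changes above add two user-facing knobs: a dedicated LLM for graph operations and a `custom_prompt` that is spliced into `EXTRACT_ENTITIES_PROMPT` in place of the `CUSTOM_PROMPT` placeholder (as item 4 of the extraction rules). A minimal usage sketch from the caller's side follows; the Neo4j credentials are placeholders and the `Memory.from_config` entry point is assumed from mem0's existing public API rather than shown in this diff:

from mem0 import Memory

config = {
    "graph_store": {
        "provider": "neo4j",
        "config": {
            "url": "neo4j+s://your-instance",  # placeholder credentials
            "username": "neo4j",
            "password": "password",
        },
        # New in this change: route graph queries to a specific LLM provider.
        "llm": {"provider": "openai_structured", "config": {"model": "gpt-4o-2024-08-06"}},
        # New in this change: appended to the entity-extraction system prompt.
        "custom_prompt": "Only extract entities related to the user's hobbies.",
    },
}

memory = Memory.from_config(config_dict=config)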
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
index 04048ccf..78b56b33 100644
--- a/mem0/llms/configs.py
+++ b/mem0/llms/configs.py
@@ -22,6 +22,7 @@ class LlmConfig(BaseModel):
             "aws_bedrock",
             "litellm",
             "azure_openai",
+            "openai_structured",
         ):
             return v
         else:
diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py
new file mode 100644
index 00000000..8f184646
--- /dev/null
+++ b/mem0/llms/openai_structured.py
@@ -0,0 +1,88 @@
+import os, json
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
+
+
+class OpenAIStructuredLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "gpt-4o-2024-08-06"
+
+        api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+        base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+
+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The tools passed in the request; when present, tool calls are parsed.
+
+        Returns:
+            str or dict: The processed response.
+        """
+
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+
+        else:
+            return response.choices[0].message.content
+
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using OpenAI.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+ """ + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + } + + if response_format: + params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + + response = self.client.beta.chat.completions.parse(**params) + + return self._parse_response(response, tools) \ No newline at end of file diff --git a/mem0/llms/utils/tools.py b/mem0/llms/utils/tools.py index acb3d2a8..50031e36 100644 --- a/mem0/llms/utils/tools.py +++ b/mem0/llms/utils/tools.py @@ -3,12 +3,14 @@ ADD_MEMORY_TOOL = { "function": { "name": "add_memory", "description": "Add a memory", + "strict": True, "parameters": { "type": "object", "properties": { "data": {"type": "string", "description": "Data to add to memory"} }, "required": ["data"], + "additionalProperties": False }, }, } @@ -18,6 +20,7 @@ UPDATE_MEMORY_TOOL = { "function": { "name": "update_memory", "description": "Update memory provided ID and data", + "strict": True, "parameters": { "type": "object", "properties": { @@ -31,6 +34,7 @@ UPDATE_MEMORY_TOOL = { }, }, "required": ["memory_id", "data"], + "additionalProperties": False }, }, } @@ -40,6 +44,7 @@ DELETE_MEMORY_TOOL = { "function": { "name": "delete_memory", "description": "Delete memory by memory_id", + "strict": True, "parameters": { "type": "object", "properties": { @@ -49,6 +54,7 @@ DELETE_MEMORY_TOOL = { } }, "required": ["memory_id"], + "additionalProperties": False }, }, } diff --git a/mem0/memory/main_graph.py b/mem0/memory/main_graph.py index 2311b49e..8b84ab21 100644 --- a/mem0/memory/main_graph.py +++ b/mem0/memory/main_graph.py @@ -1,51 +1,29 @@ from langchain_community.graphs import Neo4jGraph -from pydantic import BaseModel, Field import json -from openai import OpenAI - -from mem0.embeddings.openai import OpenAIEmbedding -from mem0.llms.openai import OpenAILLM +from rank_bm25 import BM25Okapi +from mem0.utils.factory import LlmFactory, EmbedderFactory from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT -from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL +from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL -client = OpenAI() - -class GraphData(BaseModel): - source: str = Field(..., description="The source node of the relationship") - target: str = Field(..., description="The target node of the relationship") - relationship: str = Field(..., description="The type of the relationship") - -class Entities(BaseModel): - source_node: str - source_type: str - relation: str - destination_node: str - destination_type: str - -class ADDQuery(BaseModel): - entities: list[Entities] - -class SEARCHQuery(BaseModel): - nodes: list[str] - relations: list[str] - -def get_embedding(text): - response = client.embeddings.create( - model="text-embedding-3-small", - input=text - ) - return response.data[0].embedding class MemoryGraph: def __init__(self, config): self.config = config self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password) + self.embedding_model = EmbedderFactory.create( + self.config.embedder.provider, self.config.embedder.config + ) - self.llm = OpenAILLM() - self.embedding_model = OpenAIEmbedding() + if self.config.llm.provider: + llm_provider = self.config.llm.provider + if self.config.graph_store.llm: + llm_provider = self.config.graph_store.llm.provider + else: + 
llm_provider = "openai_structured" + + self.llm = LlmFactory.create(llm_provider, self.config.llm.config) self.user_id = None self.threshold = 0.7 - self.model_name = "gpt-4o-2024-08-06" def add(self, data): """ @@ -61,41 +39,45 @@ class MemoryGraph: # retrieve the search results search_output = self._search(data) - - extracted_entities = client.beta.chat.completions.parse( - model=self.model_name, + + if self.config.graph_store.custom_prompt: + messages=[ + {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")}, + {"role": "user", "content": data}, + ] + else: messages=[ {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)}, {"role": "user", "content": data}, - ], - response_format=ADDQuery, - temperature=0, - ).choices[0].message.parsed.entities + ] + + extracted_entities = self.llm.generate_response( + messages=messages, + tools = [ADD_MESSAGE_TOOL], + ) + if extracted_entities['tool_calls']: + extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities'] + else: + extracted_entities = [] + update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) - tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL] - memory_updates = client.beta.chat.completions.parse( - model=self.model_name, + memory_updates = self.llm.generate_response( messages=update_memory_prompt, - tools=tools, - temperature=0, - ).choices[0].message.tool_calls + tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL], + ) to_be_added = [] - for item in memory_updates: - function_name = item.function.name - arguments = json.loads(item.function.arguments) - if function_name == "add_graph_memory": - to_be_added.append(arguments) - elif function_name == "update_graph_memory": - self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship']) - elif function_name == "update_name": - self._update_name(arguments['name']) - elif function_name == "noop": + + for item in memory_updates['tool_calls']: + if item['name'] == "add_graph_memory": + to_be_added.append(item['arguments']) + elif item['name'] == "update_graph_memory": + self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship']) + elif item['name'] == "noop": continue - new_relationships_response = [] for item in to_be_added: source = item['source'].lower().replace(" ", "_") source_type = item['source_type'].lower().replace(" ", "_") @@ -104,8 +86,8 @@ class MemoryGraph: destination_type = item['destination_type'].lower().replace(" ", "_") # Create embeddings - source_embedding = get_embedding(source) - dest_embedding = get_embedding(destination) + source_embedding = self.embedding_model.embed(source) + dest_embedding = self.embedding_model.embed(destination) # Updated Cypher query to include node types and embeddings cypher = f""" @@ -127,22 +109,28 @@ class MemoryGraph: "dest_embedding": dest_embedding } - result = self.graph.query(cypher, params=params) - + _ = self.graph.query(cypher, params=params) def _search(self, query): - search_results = client.beta.chat.completions.parse( - model="gpt-4o-2024-08-06", + search_results = self.llm.generate_response( messages=[ {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. 
                 {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
                 {"role": "user", "content": query},
             ],
-            response_format=SEARCHQuery,
-        ).choices[0].message
-
-        node_list = search_results.parsed.nodes
-        relation_list = search_results.parsed.relations
+            tools=[SEARCH_TOOL]
+        )
+
+        node_list = []
+        relation_list = []
+
+        for item in search_results['tool_calls']:
+            if item['name'] == "search":
+                node_list.extend(item['arguments']['nodes'])
+                relation_list.extend(item['arguments']['relations'])
+
+        node_list = list(set(node_list))
+        relation_list = list(set(relation_list))
 
         node_list = [node.lower().replace(" ", "_") for node in node_list]
         relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
@@ -150,7 +138,7 @@
         result_relations = []
 
         for node in node_list:
-            n_embedding = get_embedding(node)
+            n_embedding = self.embedding_model.embed(node)
 
             cypher_query = """
             MATCH (n)
@@ -195,12 +183,22 @@
         """
         search_output = self._search(query)
+
+        if not search_output:
+            return []
+
+        search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
+        bm25 = BM25Okapi(search_outputs_sequence)
+
+        tokenized_query = query.split(" ")
+        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
+
         search_results = []
-        for item in search_output:
+        for item in reranked_results:
             search_results.append({
-                "source": item['source'],
-                "relation": item['relation'],
-                "destination": item['destination']
+                "source": item[0],
+                "relation": item[1],
+                "destination": item[2]
             })
 
         return search_results
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index db1fcbbe..bf012b42 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -19,6 +19,7 @@ class LlmFactory:
         "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
         "litellm": "mem0.llms.litellm.LiteLLM",
         "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
+        "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
     }
 
     @classmethod
diff --git a/pyproject.toml b/pyproject.toml
index c912e9e9..1d9b3338 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.5"
+version = "0.1.6"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 "]
 exclude = [
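For reference, the BM25 reranking step introduced in `search` treats each `[source, relation, destination]` triplet as a three-token document and reranks the triplets against the whitespace-tokenized query. Below is a standalone sketch of that behavior using the same `rank_bm25` calls as the diff, but with made-up triplets:

from rank_bm25 import BM25Okapi

# Hypothetical _search output: [source, relation, destination] triplets.
triplets = [
    ["alice", "likes", "tennis"],
    ["alice", "works_at", "acme_corp"],
    ["bob", "likes", "chess"],
]

bm25 = BM25Okapi(triplets)  # each triplet acts as a tiny document
query_tokens = "what does alice like".split(" ")

# Returns the n highest-scoring triplets for the query tokens.
top = bm25.get_top_n(query_tokens, triplets, n=5)
print(top)

Note that `get_top_n` returns at most `n` documents, so with fewer than five triplets it simply returns them all in score order.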