Improvements to Graph Memory (#1779)

Prateek Chhikara
2024-08-29 22:17:08 -07:00
committed by GitHub
parent 28bc4fe05b
commit 822a8acedb
10 changed files with 246 additions and 79 deletions

View File

@@ -11,7 +11,7 @@ os.environ["TOGETHER_API_KEY"] = "your-api-key"
 config = {
     "llm": {
-        "provider": "togetherai",
+        "provider": "together",
         "config": {
             "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
             "temperature": 0.2,

View File

@@ -1,6 +1,6 @@
 from typing import Optional
 from pydantic import BaseModel, Field, field_validator, model_validator
+from mem0.llms.configs import LlmConfig

 class Neo4jConfig(BaseModel):
     url: Optional[str] = Field(None, description="Host address for the graph database")
@@ -30,6 +30,14 @@ class GraphStoreConfig(BaseModel):
         description="Configuration for the specific data store",
         default=None
     )
+    llm: Optional[LlmConfig] = Field(
+        description="LLM configuration for querying the graph store",
+        default=None
+    )
+    custom_prompt: Optional[str] = Field(
+        description="Custom prompt to fetch entities from the given text",
+        default=None
+    )

     @field_validator("config")
     def validate_config(cls, v, values):
@@ -38,3 +46,4 @@ class GraphStoreConfig(BaseModel):
             return Neo4jConfig(**v.model_dump())
         else:
             raise ValueError(f"Unsupported graph store provider: {provider}")
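
A hedged sketch of a user config exercising the two new fields (Neo4j credentials are placeholders; the custom prompt text is illustrative):

    config = {
        "graph_store": {
            "provider": "neo4j",
            "config": {
                "url": "neo4j+s://your-instance",
                "username": "neo4j",
                "password": "password",
            },
            # Optional dedicated LLM for graph operations; if omitted,
            # MemoryGraph falls back to the "openai_structured" provider (see below).
            "llm": {"provider": "openai", "config": {"model": "gpt-4o-2024-08-06"}},
            # Injected as rule 4 of the entity-extraction guidelines (see graphs/utils.py below).
            "custom_prompt": "Only extract entities about sports and hobbies.",
        },
    }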

View File

@@ -78,3 +78,66 @@ NOOP_TOOL = {
         }
     }
 }
+
+ADD_MESSAGE_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "add_query",
+        "description": "Add new entities and relationships to the graph based on the provided query.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "entities": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "source_node": {"type": "string"},
+                            "source_type": {"type": "string"},
+                            "relation": {"type": "string"},
+                            "destination_node": {"type": "string"},
+                            "destination_type": {"type": "string"}
+                        },
+                        "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
+                        "additionalProperties": False
+                    }
+                }
+            },
+            "required": ["entities"],
+            "additionalProperties": False
+        }
+    }
+}
+
+SEARCH_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "search",
+        "description": "Search for nodes and relations in the graph.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "nodes": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of nodes to search for."
+                },
+                "relations": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of relations to search for."
+                }
+            },
+            "required": ["nodes", "relations"],
+            "additionalProperties": False
+        }
+    }
+}
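
As a usage sketch, these are standard OpenAI-style tool schemas; with "strict": True the model must return arguments that match the schema exactly (all listed properties present, nothing extra). A hypothetical direct call, not part of this diff:

    import json
    from openai import OpenAI

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": "Alice moved to Berlin."}],
        tools=[ADD_MESSAGE_TOOL],
    )
    for call in response.choices[0].message.tool_calls or []:
        # e.g. name="add_query", arguments={"entities": [{"source_node": "alice", ...}]}
        print(call.function.name, json.loads(call.function.arguments))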

View File

@@ -37,6 +37,7 @@ You are an advanced algorithm designed to extract structured information from te
 1. Extract only explicitly stated information from the text.
 2. Identify nodes (entities/concepts), their types, and relationships.
 3. Use "USER_ID" as the source node for any self-references (I, me, my, etc.) in user messages.
+CUSTOM_PROMPT

 Nodes and Types:
 - Aim for simplicity and clarity in node representation.
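
The placeholder is filled in at call time by MemoryGraph.add (see the graph_memory.py diff below); as committed, it is only replaced when a custom prompt is configured. Roughly:

    system_prompt = EXTRACT_ENTITIES_PROMPT.replace("USER_ID", user_id)
    if custom_prompt:  # taken from graph_store.custom_prompt in the config
        # The user-supplied instruction becomes rule 4 of the numbered guidelines above.
        system_prompt = system_prompt.replace("CUSTOM_PROMPT", f"4. {custom_prompt}")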

View File

@@ -22,6 +22,7 @@ class LlmConfig(BaseModel):
"aws_bedrock", "aws_bedrock",
"litellm", "litellm",
"azure_openai", "azure_openai",
"openai_structured",
): ):
return v return v
else: else:

View File

@@ -0,0 +1,88 @@
+import os, json
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
+
+
+class OpenAIStructuredLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "gpt-4o-2024-08-06"
+
+        api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+        base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using OpenAI.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+        }
+        if response_format:
+            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice
+
+        response = self.client.beta.chat.completions.parse(**params)
+        return self._parse_response(response, tools)
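
A hedged usage sketch of the new wrapper (BaseLlmConfig fields as used in __init__ above; SEARCH_TOOL is the schema added in graphs/tools.py earlier in this commit):

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.openai_structured import OpenAIStructuredLLM
    from mem0.graphs.tools import SEARCH_TOOL

    llm = OpenAIStructuredLLM(BaseLlmConfig(model="gpt-4o-2024-08-06", temperature=0))
    result = llm.generate_response(
        messages=[
            {"role": "system", "content": "Extract the entities."},
            {"role": "user", "content": "I live in San Francisco."},
        ],
        tools=[SEARCH_TOOL],
    )
    # When tools are passed, _parse_response returns a dict:
    # {"content": ..., "tool_calls": [{"name": "search", "arguments": {...}}]}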

View File

@@ -3,12 +3,14 @@ ADD_MEMORY_TOOL = {
"function": { "function": {
"name": "add_memory", "name": "add_memory",
"description": "Add a memory", "description": "Add a memory",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"data": {"type": "string", "description": "Data to add to memory"} "data": {"type": "string", "description": "Data to add to memory"}
}, },
"required": ["data"], "required": ["data"],
"additionalProperties": False
}, },
}, },
} }
@@ -18,6 +20,7 @@ UPDATE_MEMORY_TOOL = {
"function": { "function": {
"name": "update_memory", "name": "update_memory",
"description": "Update memory provided ID and data", "description": "Update memory provided ID and data",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -31,6 +34,7 @@ UPDATE_MEMORY_TOOL = {
}, },
}, },
"required": ["memory_id", "data"], "required": ["memory_id", "data"],
"additionalProperties": False
}, },
}, },
} }
@@ -40,6 +44,7 @@ DELETE_MEMORY_TOOL = {
"function": { "function": {
"name": "delete_memory", "name": "delete_memory",
"description": "Delete memory by memory_id", "description": "Delete memory by memory_id",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -49,6 +54,7 @@ DELETE_MEMORY_TOOL = {
} }
}, },
"required": ["memory_id"], "required": ["memory_id"],
"additionalProperties": False
}, },
}, },
} }
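
With strict mode, the parameters schema must be fully closed: every property listed in required, no extras allowed. A quick way to see what now conforms (jsonschema is used here purely for illustration and is not a dependency of this commit; the module path of these tools is assumed, as the hunk above does not show the filename):

    from jsonschema import validate
    from mem0.llms.utils.tools import UPDATE_MEMORY_TOOL  # path assumed

    schema = UPDATE_MEMORY_TOOL["function"]["parameters"]
    validate({"memory_id": "m1", "data": "Likes hiking"}, schema)  # passes
    validate({"memory_id": "m1", "note": "extra"}, schema)         # raises: "data" missing, "note" not allowed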

View File

@@ -1,51 +1,29 @@
 from langchain_community.graphs import Neo4jGraph
-from pydantic import BaseModel, Field
 import json
-from openai import OpenAI
-from mem0.embeddings.openai import OpenAIEmbedding
-from mem0.llms.openai import OpenAILLM
+from rank_bm25 import BM25Okapi
+from mem0.utils.factory import LlmFactory, EmbedderFactory
 from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
-from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
+from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
-
-client = OpenAI()
-
-class GraphData(BaseModel):
-    source: str = Field(..., description="The source node of the relationship")
-    target: str = Field(..., description="The target node of the relationship")
-    relationship: str = Field(..., description="The type of the relationship")
-
-class Entities(BaseModel):
-    source_node: str
-    source_type: str
-    relation: str
-    destination_node: str
-    destination_type: str
-
-class ADDQuery(BaseModel):
-    entities: list[Entities]
-
-class SEARCHQuery(BaseModel):
-    nodes: list[str]
-    relations: list[str]
-
-def get_embedding(text):
-    response = client.embeddings.create(
-        model="text-embedding-3-small",
-        input=text
-    )
-    return response.data[0].embedding

 class MemoryGraph:
     def __init__(self, config):
         self.config = config
         self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
-        self.llm = OpenAILLM()
-        self.embedding_model = OpenAIEmbedding()
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider, self.config.embedder.config
+        )
+
+        if self.config.llm.provider:
+            llm_provider = self.config.llm.provider
+            if self.config.graph_store.llm:
+                llm_provider = self.config.graph_store.llm.provider
+        else:
+            llm_provider = "openai_structured"
+        self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
+
         self.user_id = None
         self.threshold = 0.7
-        self.model_name = "gpt-4o-2024-08-06"

     def add(self, data):
         """
@@ -61,41 +39,45 @@ class MemoryGraph:
         # retrieve the search results
         search_output = self._search(data)

-        extracted_entities = client.beta.chat.completions.parse(
-            model=self.model_name,
-            messages=[
-                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
-                {"role": "user", "content": data},
-            ],
-            response_format=ADDQuery,
-            temperature=0,
-        ).choices[0].message.parsed.entities
+        if self.config.graph_store.custom_prompt:
+            messages = [
+                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
+                {"role": "user", "content": data},
+            ]
+        else:
+            messages = [
+                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
+                {"role": "user", "content": data},
+            ]
+
+        extracted_entities = self.llm.generate_response(
+            messages=messages,
+            tools=[ADD_MESSAGE_TOOL],
+        )
+
+        if extracted_entities['tool_calls']:
+            extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
+        else:
+            extracted_entities = []

         update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
-        tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
-        memory_updates = client.beta.chat.completions.parse(
-            model=self.model_name,
-            messages=update_memory_prompt,
-            tools=tools,
-            temperature=0,
-        ).choices[0].message.tool_calls
+
+        memory_updates = self.llm.generate_response(
+            messages=update_memory_prompt,
+            tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
+        )

         to_be_added = []
-        for item in memory_updates:
-            function_name = item.function.name
-            arguments = json.loads(item.function.arguments)
-            if function_name == "add_graph_memory":
-                to_be_added.append(arguments)
-            elif function_name == "update_graph_memory":
-                self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
-            elif function_name == "update_name":
-                self._update_name(arguments['name'])
-            elif function_name == "noop":
-                continue
-        new_relationships_response = []
+
+        for item in memory_updates['tool_calls']:
+            if item['name'] == "add_graph_memory":
+                to_be_added.append(item['arguments'])
+            elif item['name'] == "update_graph_memory":
+                self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
+            elif item['name'] == "noop":
+                continue

         for item in to_be_added:
             source = item['source'].lower().replace(" ", "_")
             source_type = item['source_type'].lower().replace(" ", "_")
@@ -104,8 +86,8 @@ class MemoryGraph:
             destination_type = item['destination_type'].lower().replace(" ", "_")

             # Create embeddings
-            source_embedding = get_embedding(source)
-            dest_embedding = get_embedding(destination)
+            source_embedding = self.embedding_model.embed(source)
+            dest_embedding = self.embedding_model.embed(destination)

             # Updated Cypher query to include node types and embeddings
             cypher = f"""
@@ -127,22 +109,28 @@ class MemoryGraph:
"dest_embedding": dest_embedding "dest_embedding": dest_embedding
} }
result = self.graph.query(cypher, params=params) _ = self.graph.query(cypher, params=params)
def _search(self, query): def _search(self, query):
search_results = client.beta.chat.completions.parse( search_results = self.llm.generate_response(
model="gpt-4o-2024-08-06",
messages=[ messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."}, {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
{"role": "user", "content": query}, {"role": "user", "content": query},
], ],
response_format=SEARCHQuery, tools = [SEARCH_TOOL]
).choices[0].message )
node_list = search_results.parsed.nodes node_list = []
relation_list = search_results.parsed.relations relation_list = []
for item in search_results['tool_calls']:
if item['name'] == "search":
node_list.extend(item['arguments']['nodes'])
relation_list.extend(item['arguments']['relations'])
node_list = list(set(node_list))
relation_list = list(set(relation_list))
node_list = [node.lower().replace(" ", "_") for node in node_list] node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list] relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
@@ -150,7 +138,7 @@ class MemoryGraph:
         result_relations = []

         for node in node_list:
-            n_embedding = get_embedding(node)
+            n_embedding = self.embedding_model.embed(node)

             cypher_query = """
             MATCH (n)
@@ -195,12 +183,22 @@ class MemoryGraph:
""" """
search_output = self._search(query) search_output = self._search(query)
if not search_output:
return []
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = [] search_results = []
for item in search_output: for item in reranked_results:
search_results.append({ search_results.append({
"source": item['source'], "source": item[0],
"relation": item['relation'], "relation": item[1],
"destination": item['destination'] "destination": item[2]
}) })
return search_results return search_results

View File

@@ -19,6 +19,7 @@ class LlmFactory:
"aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM", "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
"litellm": "mem0.llms.litellm.LiteLLM", "litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM", "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
} }
@classmethod @classmethod
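
Together with the validator change in llms/configs.py above, this registration makes the new provider reachable through the factory; a hedged sketch, assuming the factory accepts a plain config dict as elsewhere in mem0:

    from mem0.utils.factory import LlmFactory

    # "openai_structured" now resolves to OpenAIStructuredLLM; MemoryGraph
    # defaults to it when no graph-store LLM is configured.
    llm = LlmFactory.create("openai_structured", {"temperature": 0})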

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.5"
+version = "0.1.6"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [