Improvements to Graph Memory (#1779)

This commit is contained in:
Prateek Chhikara
2024-08-29 22:17:08 -07:00
committed by GitHub
parent 28bc4fe05b
commit 822a8acedb
10 changed files with 246 additions and 79 deletions

View File

@@ -11,7 +11,7 @@ os.environ["TOGETHER_API_KEY"] = "your-api-key"
config = {
"llm": {
"provider": "togetherai",
"provider": "together",
"config": {
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"temperature": 0.2,

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
from mem0.llms.configs import LlmConfig
class Neo4jConfig(BaseModel):
url: Optional[str] = Field(None, description="Host address for the graph database")
@@ -30,6 +30,14 @@ class GraphStoreConfig(BaseModel):
description="Configuration for the specific data store",
default=None
)
llm: Optional[LlmConfig] = Field(
description="LLM configuration for querying the graph store",
default=None
)
custom_prompt: Optional[str] = Field(
description="Custom prompt to fetch entities from the given text",
default=None
)
@field_validator("config")
def validate_config(cls, v, values):
@@ -38,3 +46,4 @@ class GraphStoreConfig(BaseModel):
return Neo4jConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported graph store provider: {provider}")

View File

@@ -78,3 +78,66 @@ NOOP_TOOL = {
}
}
}
ADD_MESSAGE_TOOL = {
"type": "function",
"function": {
"name": "add_query",
"description": "Add new entities and relationships to the graph based on the provided query.",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"source_node": {"type": "string"},
"source_type": {"type": "string"},
"relation": {"type": "string"},
"destination_node": {"type": "string"},
"destination_type": {"type": "string"}
},
"required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
"additionalProperties": False
}
}
},
"required": ["entities"],
"additionalProperties": False
}
}
}
SEARCH_TOOL = {
"type": "function",
"function": {
"name": "search",
"description": "Search for nodes and relations in the graph.",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"nodes": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of nodes to search for."
},
"relations": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of relations to search for."
}
},
"required": ["nodes", "relations"],
"additionalProperties": False
}
}
}

View File

@@ -37,6 +37,7 @@ You are an advanced algorithm designed to extract structured information from te
1. Extract only explicitly stated information from the text.
2. Identify nodes (entities/concepts), their types, and relationships.
3. Use "USER_ID" as the source node for any self-references (I, me, my, etc.) in user messages.
CUSTOM_PROMPT
Nodes and Types:
- Aim for simplicity and clarity in node representation.

View File

@@ -22,6 +22,7 @@ class LlmConfig(BaseModel):
"aws_bedrock",
"litellm",
"azure_openai",
"openai_structured",
):
return v
else:

View File

@@ -0,0 +1,88 @@
import os, json
from typing import Dict, List, Optional
from openai import OpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class OpenAIStructuredLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
self.client = OpenAI(api_key=api_key, base_url=base_url)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
response_format: The format in which the response should be processed.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using OpenAI.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.
"""
params = {
"model": self.config.model,
"messages": messages,
"temperature": self.config.temperature,
}
if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.beta.chat.completions.parse(**params)
return self._parse_response(response, tools)

View File

@@ -3,12 +3,14 @@ ADD_MEMORY_TOOL = {
"function": {
"name": "add_memory",
"description": "Add a memory",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"data": {"type": "string", "description": "Data to add to memory"}
},
"required": ["data"],
"additionalProperties": False
},
},
}
@@ -18,6 +20,7 @@ UPDATE_MEMORY_TOOL = {
"function": {
"name": "update_memory",
"description": "Update memory provided ID and data",
"strict": True,
"parameters": {
"type": "object",
"properties": {
@@ -31,6 +34,7 @@ UPDATE_MEMORY_TOOL = {
},
},
"required": ["memory_id", "data"],
"additionalProperties": False
},
},
}
@@ -40,6 +44,7 @@ DELETE_MEMORY_TOOL = {
"function": {
"name": "delete_memory",
"description": "Delete memory by memory_id",
"strict": True,
"parameters": {
"type": "object",
"properties": {
@@ -49,6 +54,7 @@ DELETE_MEMORY_TOOL = {
}
},
"required": ["memory_id"],
"additionalProperties": False
},
},
}

View File

@@ -1,51 +1,29 @@
from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel, Field
import json
from openai import OpenAI
from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
client = OpenAI()
class GraphData(BaseModel):
source: str = Field(..., description="The source node of the relationship")
target: str = Field(..., description="The target node of the relationship")
relationship: str = Field(..., description="The type of the relationship")
class Entities(BaseModel):
source_node: str
source_type: str
relation: str
destination_node: str
destination_type: str
class ADDQuery(BaseModel):
entities: list[Entities]
class SEARCHQuery(BaseModel):
nodes: list[str]
relations: list[str]
def get_embedding(text):
response = client.embeddings.create(
model="text-embedding-3-small",
input=text
)
return response.data[0].embedding
class MemoryGraph:
def __init__(self, config):
self.config = config
self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config
)
self.llm = OpenAILLM()
self.embedding_model = OpenAIEmbedding()
if self.config.llm.provider:
llm_provider = self.config.llm.provider
if self.config.graph_store.llm:
llm_provider = self.config.graph_store.llm.provider
else:
llm_provider = "openai_structured"
self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
self.user_id = None
self.threshold = 0.7
self.model_name = "gpt-4o-2024-08-06"
def add(self, data):
"""
@@ -61,41 +39,45 @@ class MemoryGraph:
# retrieve the search results
search_output = self._search(data)
extracted_entities = client.beta.chat.completions.parse(
model=self.model_name,
if self.config.graph_store.custom_prompt:
messages=[
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
{"role": "user", "content": data},
]
else:
messages=[
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
{"role": "user", "content": data},
],
response_format=ADDQuery,
temperature=0,
).choices[0].message.parsed.entities
]
extracted_entities = self.llm.generate_response(
messages=messages,
tools = [ADD_MESSAGE_TOOL],
)
if extracted_entities['tool_calls']:
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
else:
extracted_entities = []
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
memory_updates = client.beta.chat.completions.parse(
model=self.model_name,
memory_updates = self.llm.generate_response(
messages=update_memory_prompt,
tools=tools,
temperature=0,
).choices[0].message.tool_calls
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
)
to_be_added = []
for item in memory_updates:
function_name = item.function.name
arguments = json.loads(item.function.arguments)
if function_name == "add_graph_memory":
to_be_added.append(arguments)
elif function_name == "update_graph_memory":
self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
elif function_name == "update_name":
self._update_name(arguments['name'])
elif function_name == "noop":
for item in memory_updates['tool_calls']:
if item['name'] == "add_graph_memory":
to_be_added.append(item['arguments'])
elif item['name'] == "update_graph_memory":
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
elif item['name'] == "noop":
continue
new_relationships_response = []
for item in to_be_added:
source = item['source'].lower().replace(" ", "_")
source_type = item['source_type'].lower().replace(" ", "_")
@@ -104,8 +86,8 @@ class MemoryGraph:
destination_type = item['destination_type'].lower().replace(" ", "_")
# Create embeddings
source_embedding = get_embedding(source)
dest_embedding = get_embedding(destination)
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# Updated Cypher query to include node types and embeddings
cypher = f"""
@@ -127,22 +109,28 @@ class MemoryGraph:
"dest_embedding": dest_embedding
}
result = self.graph.query(cypher, params=params)
_ = self.graph.query(cypher, params=params)
def _search(self, query):
search_results = client.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
search_results = self.llm.generate_response(
messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
{"role": "user", "content": query},
],
response_format=SEARCHQuery,
).choices[0].message
node_list = search_results.parsed.nodes
relation_list = search_results.parsed.relations
tools = [SEARCH_TOOL]
)
node_list = []
relation_list = []
for item in search_results['tool_calls']:
if item['name'] == "search":
node_list.extend(item['arguments']['nodes'])
relation_list.extend(item['arguments']['relations'])
node_list = list(set(node_list))
relation_list = list(set(relation_list))
node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
@@ -150,7 +138,7 @@ class MemoryGraph:
result_relations = []
for node in node_list:
n_embedding = get_embedding(node)
n_embedding = self.embedding_model.embed(node)
cypher_query = """
MATCH (n)
@@ -195,12 +183,22 @@ class MemoryGraph:
"""
search_output = self._search(query)
if not search_output:
return []
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in search_output:
for item in reranked_results:
search_results.append({
"source": item['source'],
"relation": item['relation'],
"destination": item['destination']
"source": item[0],
"relation": item[1],
"destination": item[2]
})
return search_results

View File

@@ -19,6 +19,7 @@ class LlmFactory:
"aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
"litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
}
@classmethod

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.5"
version = "0.1.6"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [