Improvements to Graph Memory (#1779)
@@ -11,7 +11,7 @@ os.environ["TOGETHER_API_KEY"] = "your-api-key"
 config = {
     "llm": {
-        "provider": "togetherai",
+        "provider": "together",
         "config": {
            "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "temperature": 0.2,
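The docs change above aligns the example with the provider key the library expects ("together" rather than "togetherai"). For context, a minimal end-to-end sketch of the corrected usage, assuming the usual Memory.from_config entry point; the m.add() call is illustrative:

import os

from mem0 import Memory

os.environ["TOGETHER_API_KEY"] = "your-api-key"

config = {
    "llm": {
        "provider": "together",  # must match the registered provider key
        "config": {
            "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "temperature": 0.2,
        },
    },
}

m = Memory.from_config(config)
m.add("I like pizza", user_id="alice")  # illustrative call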
@@ -1,6 +1,6 @@
 from typing import Optional
 
 from pydantic import BaseModel, Field, field_validator, model_validator
+from mem0.llms.configs import LlmConfig
 
 class Neo4jConfig(BaseModel):
     url: Optional[str] = Field(None, description="Host address for the graph database")
@@ -30,6 +30,14 @@ class GraphStoreConfig(BaseModel):
         description="Configuration for the specific data store",
         default=None
     )
+    llm: Optional[LlmConfig] = Field(
+        description="LLM configuration for querying the graph store",
+        default=None
+    )
+    custom_prompt: Optional[str] = Field(
+        description="Custom prompt to fetch entities from the given text",
+        default=None
+    )
 
     @field_validator("config")
     def validate_config(cls, v, values):
@@ -38,3 +46,4 @@ class GraphStoreConfig(BaseModel):
             return Neo4jConfig(**v.model_dump())
         else:
             raise ValueError(f"Unsupported graph store provider: {provider}")
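The two new fields let a deployment point graph operations at their own LLM and append a custom extraction rule. A hypothetical config exercising both (connection values are placeholders, not real credentials):

config = {
    "graph_store": {
        "provider": "neo4j",
        "config": {
            "url": "neo4j+s://your-instance",
            "username": "neo4j",
            "password": "password",
        },
        "llm": {"provider": "openai", "config": {"model": "gpt-4o-2024-08-06"}},
        "custom_prompt": "Only extract entities related to food preferences.",
    },
}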
@@ -78,3 +78,66 @@ NOOP_TOOL = {
         }
     }
 }
+
+
+ADD_MESSAGE_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "add_query",
+        "description": "Add new entities and relationships to the graph based on the provided query.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "entities": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "source_node": {"type": "string"},
+                            "source_type": {"type": "string"},
+                            "relation": {"type": "string"},
+                            "destination_node": {"type": "string"},
+                            "destination_type": {"type": "string"}
+                        },
+                        "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
+                        "additionalProperties": False
+                    }
+                }
+            },
+            "required": ["entities"],
+            "additionalProperties": False
+        }
+    }
+}
+
+
+SEARCH_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "search",
+        "description": "Search for nodes and relations in the graph.",
+        "strict": True,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "nodes": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of nodes to search for."
+                },
+                "relations": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "List of relations to search for."
+                }
+            },
+            "required": ["nodes", "relations"],
+            "additionalProperties": False
+        }
+    }
+}
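Since both tools set "strict": True and "additionalProperties": False, the model must return exactly the declared keys. A conforming add_query arguments payload would look like this (the entity values are illustrative):

{
    "entities": [
        {
            "source_node": "alice",
            "source_type": "person",
            "relation": "likes",
            "destination_node": "pizza",
            "destination_type": "food"
        }
    ]
}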
@@ -37,6 +37,7 @@ You are an advanced algorithm designed to extract structured information from te
 1. Extract only explicitly stated information from the text.
 2. Identify nodes (entities/concepts), their types, and relationships.
 3. Use "USER_ID" as the source node for any self-references (I, me, my, etc.) in user messages.
+CUSTOM_PROMPT
 
 Nodes and Types:
 - Aim for simplicity and clarity in node representation.
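CUSTOM_PROMPT is a literal placeholder token in the prompt string; MemoryGraph.add() (diffed below) substitutes it only when graph_store.custom_prompt is set, numbering the injected rule as instruction 4. The substitution in isolation:

# Sketch of the substitution performed in MemoryGraph.add() below
system_prompt = EXTRACT_ENTITIES_PROMPT.replace("USER_ID", user_id)
if custom_prompt:
    system_prompt = system_prompt.replace("CUSTOM_PROMPT", f"4. {custom_prompt}")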
@@ -22,6 +22,7 @@ class LlmConfig(BaseModel):
             "aws_bedrock",
             "litellm",
             "azure_openai",
+            "openai_structured",
         ):
             return v
         else:
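With the allow-list extended, the new provider passes LlmConfig validation like any other entry; a quick sketch, assuming LlmConfig's usual provider/config fields:

from mem0.llms.configs import LlmConfig

cfg = LlmConfig(provider="openai_structured", config={"model": "gpt-4o-2024-08-06"})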
mem0/llms/openai_structured.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+import os, json
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
+
+
+class OpenAIStructuredLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "gpt-4o-2024-08-06"
+
+        api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+        base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+
+        else:
+            return response.choices[0].message.content
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using OpenAI.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+        }
+
+        if response_format:
+            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice
+
+        response = self.client.beta.chat.completions.parse(**params)
+
+        return self._parse_response(response, tools)
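In normal operation this class is created through LlmFactory (registered further down), but a direct-use sketch shows the two return shapes of generate_response; the message content is illustrative:

from mem0.configs.llms.base import BaseLlmConfig
from mem0.graphs.tools import ADD_MESSAGE_TOOL
from mem0.llms.openai_structured import OpenAIStructuredLLM

llm = OpenAIStructuredLLM(BaseLlmConfig(model="gpt-4o-2024-08-06", temperature=0))
result = llm.generate_response(
    messages=[{"role": "user", "content": "Alice likes pizza."}],
    tools=[ADD_MESSAGE_TOOL],
)
# With tools, result is a dict:
#   {"content": ..., "tool_calls": [{"name": "add_query", "arguments": {...}}]}
# Without tools, result is the plain message content string.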
@@ -3,12 +3,14 @@ ADD_MEMORY_TOOL = {
     "function": {
         "name": "add_memory",
         "description": "Add a memory",
+        "strict": True,
         "parameters": {
             "type": "object",
             "properties": {
                 "data": {"type": "string", "description": "Data to add to memory"}
             },
             "required": ["data"],
+            "additionalProperties": False
         },
     },
 }
@@ -18,6 +20,7 @@ UPDATE_MEMORY_TOOL = {
     "function": {
         "name": "update_memory",
         "description": "Update memory provided ID and data",
+        "strict": True,
         "parameters": {
             "type": "object",
             "properties": {
@@ -31,6 +34,7 @@ UPDATE_MEMORY_TOOL = {
                 },
             },
             "required": ["memory_id", "data"],
+            "additionalProperties": False
         },
     },
 }
@@ -40,6 +44,7 @@ DELETE_MEMORY_TOOL = {
     "function": {
         "name": "delete_memory",
        "description": "Delete memory by memory_id",
+        "strict": True,
         "parameters": {
             "type": "object",
             "properties": {
@@ -49,6 +54,7 @@ DELETE_MEMORY_TOOL = {
                 }
             },
             "required": ["memory_id"],
+            "additionalProperties": False
         },
     },
 }
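OpenAI's strict mode requires each object schema to be fully closed, hence the paired "strict": True / "additionalProperties": False additions across all three tools. A conforming update_memory arguments payload (values illustrative):

{
    "memory_id": "0f1e2d3c-0000-0000-0000-000000000000",
    "data": "Likes pizza"
}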
@@ -1,51 +1,29 @@
 from langchain_community.graphs import Neo4jGraph
-from pydantic import BaseModel, Field
-import json
-from openai import OpenAI
-
-from mem0.embeddings.openai import OpenAIEmbedding
-from mem0.llms.openai import OpenAILLM
+from rank_bm25 import BM25Okapi
+
+from mem0.utils.factory import LlmFactory, EmbedderFactory
 from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
-from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
-
-client = OpenAI()
-
-class GraphData(BaseModel):
-    source: str = Field(..., description="The source node of the relationship")
-    target: str = Field(..., description="The target node of the relationship")
-    relationship: str = Field(..., description="The type of the relationship")
-
-class Entities(BaseModel):
-    source_node: str
-    source_type: str
-    relation: str
-    destination_node: str
-    destination_type: str
-
-class ADDQuery(BaseModel):
-    entities: list[Entities]
-
-class SEARCHQuery(BaseModel):
-    nodes: list[str]
-    relations: list[str]
-
-def get_embedding(text):
-    response = client.embeddings.create(
-        model="text-embedding-3-small",
-        input=text
-    )
-    return response.data[0].embedding
+from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
 
 class MemoryGraph:
     def __init__(self, config):
         self.config = config
         self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider, self.config.embedder.config
+        )
 
-        self.llm = OpenAILLM()
-        self.embedding_model = OpenAIEmbedding()
+        if self.config.llm.provider:
+            llm_provider = self.config.llm.provider
+            if self.config.graph_store.llm:
+                llm_provider = self.config.graph_store.llm.provider
+        else:
+            llm_provider = "openai_structured"
+
+        self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
         self.user_id = None
         self.threshold = 0.7
-        self.model_name = "gpt-4o-2024-08-06"
 
     def add(self, data):
         """
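The constructor now resolves the LLM provider with a clear precedence: a graph-store-specific LLM overrides the top-level one, and openai_structured is the fallback when no top-level provider is set. Restated as a standalone sketch (hypothetical helper, not part of the commit):

# Sketch of the precedence implemented in __init__ above
def resolve_graph_llm_provider(config):
    if config.llm.provider:
        provider = config.llm.provider
        if config.graph_store.llm:
            provider = config.graph_store.llm.provider  # graph-specific override wins
    else:
        provider = "openai_structured"  # structured-output default
    return provider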
@@ -61,41 +39,45 @@ class MemoryGraph:
 
         # retrieve the search results
         search_output = self._search(data)
 
-        extracted_entities = client.beta.chat.completions.parse(
-            model=self.model_name,
-            messages=[
-                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
-                {"role": "user", "content": data},
-            ],
-            response_format=ADDQuery,
-            temperature=0,
-        ).choices[0].message.parsed.entities
+        if self.config.graph_store.custom_prompt:
+            messages = [
+                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
+                {"role": "user", "content": data},
+            ]
+        else:
+            messages = [
+                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
+                {"role": "user", "content": data},
+            ]
+
+        extracted_entities = self.llm.generate_response(
+            messages=messages,
+            tools=[ADD_MESSAGE_TOOL],
+        )
+
+        if extracted_entities['tool_calls']:
+            extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
+        else:
+            extracted_entities = []
 
         update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
-        tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
 
-        memory_updates = client.beta.chat.completions.parse(
-            model=self.model_name,
+        memory_updates = self.llm.generate_response(
             messages=update_memory_prompt,
-            tools=tools,
-            temperature=0,
-        ).choices[0].message.tool_calls
+            tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
+        )
 
         to_be_added = []
-        for item in memory_updates:
-            function_name = item.function.name
-            arguments = json.loads(item.function.arguments)
-            if function_name == "add_graph_memory":
-                to_be_added.append(arguments)
-            elif function_name == "update_graph_memory":
-                self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
-            elif function_name == "update_name":
-                self._update_name(arguments['name'])
-            elif function_name == "noop":
+        for item in memory_updates['tool_calls']:
+            if item['name'] == "add_graph_memory":
+                to_be_added.append(item['arguments'])
+            elif item['name'] == "update_graph_memory":
+                self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
+            elif item['name'] == "noop":
                 continue
 
         new_relationships_response = []
         for item in to_be_added:
             source = item['source'].lower().replace(" ", "_")
             source_type = item['source_type'].lower().replace(" ", "_")
@@ -104,8 +86,8 @@ class MemoryGraph:
             destination_type = item['destination_type'].lower().replace(" ", "_")
 
             # Create embeddings
-            source_embedding = get_embedding(source)
-            dest_embedding = get_embedding(destination)
+            source_embedding = self.embedding_model.embed(source)
+            dest_embedding = self.embedding_model.embed(destination)
 
             # Updated Cypher query to include node types and embeddings
             cypher = f"""
@@ -127,22 +109,28 @@ class MemoryGraph:
                 "dest_embedding": dest_embedding
             }
 
-            result = self.graph.query(cypher, params=params)
+            _ = self.graph.query(cypher, params=params)
 
     def _search(self, query):
-        search_results = client.beta.chat.completions.parse(
-            model="gpt-4o-2024-08-06",
+        search_results = self.llm.generate_response(
             messages=[
                 {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
                 {"role": "user", "content": query},
             ],
-            response_format=SEARCHQuery,
-        ).choices[0].message
+            tools=[SEARCH_TOOL],
+        )
 
-        node_list = search_results.parsed.nodes
-        relation_list = search_results.parsed.relations
+        node_list = []
+        relation_list = []
+
+        for item in search_results['tool_calls']:
+            if item['name'] == "search":
+                node_list.extend(item['arguments']['nodes'])
+                relation_list.extend(item['arguments']['relations'])
+
+        node_list = list(set(node_list))
+        relation_list = list(set(relation_list))
 
         node_list = [node.lower().replace(" ", "_") for node in node_list]
         relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
@@ -150,7 +138,7 @@ class MemoryGraph:
         result_relations = []
 
         for node in node_list:
-            n_embedding = get_embedding(node)
+            n_embedding = self.embedding_model.embed(node)
 
             cypher_query = """
             MATCH (n)
@@ -195,12 +183,22 @@ class MemoryGraph:
         """
 
         search_output = self._search(query)
 
+        if not search_output:
+            return []
+
+        search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
+        bm25 = BM25Okapi(search_outputs_sequence)
+
+        tokenized_query = query.split(" ")
+        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
+
         search_results = []
-        for item in search_output:
+        for item in reranked_results:
             search_results.append({
-                "source": item['source'],
-                "relation": item['relation'],
-                "destination": item['destination']
+                "source": item[0],
+                "relation": item[1],
+                "destination": item[2]
             })
 
         return search_results
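search() now reranks the retrieved triplets with BM25 and keeps the top five instead of returning everything. The rerank step in isolation, using rank_bm25 directly (the sample triplets are illustrative):

from rank_bm25 import BM25Okapi

triplets = [
    ["alice", "likes", "pizza"],
    ["alice", "lives_in", "san_francisco"],
]
bm25 = BM25Okapi(triplets)  # each triplet acts as a tokenized document
top = bm25.get_top_n("where does alice live".split(" "), triplets, n=5)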
@@ -19,6 +19,7 @@ class LlmFactory:
         "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
         "litellm": "mem0.llms.litellm.LiteLLM",
         "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
+        "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
     }
 
     @classmethod
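Registering the class under "openai_structured" makes it reachable by the name graph_memory.py resolves via LlmFactory.create(llm_provider, self.config.llm.config). A direct sketch, assuming the config dict follows the usual shape:

from mem0.utils.factory import LlmFactory

llm = LlmFactory.create("openai_structured", {"model": "gpt-4o-2024-08-06"})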
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.5"
+version = "0.1.6"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [