[improvement]: Graph memory support for non-structured models. (#1823)

This commit is contained in:
Mayank
2024-09-08 01:56:43 +05:30
committed by GitHub
parent a972d2fb07
commit 51c4f2aae8
5 changed files with 265 additions and 14 deletions

View File

@@ -4,7 +4,6 @@ UPDATE_MEMORY_TOOL_GRAPH = {
"function": { "function": {
"name": "update_graph_memory", "name": "update_graph_memory",
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.", "description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -32,7 +31,6 @@ ADD_MEMORY_TOOL_GRAPH = {
"function": { "function": {
"name": "add_graph_memory", "name": "add_graph_memory",
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.", "description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -69,7 +67,6 @@ NOOP_TOOL = {
"function": { "function": {
"name": "noop", "name": "noop",
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.", "description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": {}, "properties": {},
@@ -85,7 +82,6 @@ ADD_MESSAGE_TOOL = {
"function": { "function": {
"name": "add_query", "name": "add_query",
"description": "Add new entities and relationships to the graph based on the provided query.", "description": "Add new entities and relationships to the graph based on the provided query.",
"strict": True,
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -113,6 +109,148 @@ ADD_MESSAGE_TOOL = {
# OpenAI tool schema for searching nodes/relations in the knowledge graph
# (non-strict variant; see SEARCH_STRUCT_TOOL for the strict counterpart).
# Fix: the assignment line was duplicated in the source ("SEARCH_TOOL = { SEARCH_TOOL = {"),
# which is a syntax error; reconstructed as a single valid assignment.
SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "search",
        "description": "Search for nodes and relations in the graph.",
        "parameters": {
            "type": "object",
            "properties": {
                "nodes": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of nodes to search for.",
                },
                "relations": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of relations to search for.",
                },
            },
            "required": ["nodes", "relations"],
            "additionalProperties": False,
        },
    },
}
# Strict-mode (structured-output) variant of UPDATE_MEMORY_TOOL_GRAPH: keeps
# "strict": True, lists every property in "required", and forbids extras.
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
    "type": "function",
    "function": {
        "name": "update_graph_memory",
        "description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                # Endpoints must already exist in the graph; only the edge changes.
                "source": {
                    "type": "string",
                    "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
                },
                "destination": {
                    "type": "string",
                    "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
                },
                "relationship": {
                    "type": "string",
                    "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
                },
            },
            "required": ["source", "destination", "relationship"],
            "additionalProperties": False,
        },
    },
}
# Strict-mode (structured-output) variant of ADD_MEMORY_TOOL_GRAPH: creates a
# new edge, and new endpoint nodes if needed; strict mode requires every
# property to be listed in "required" with additionalProperties disabled.
ADD_MEMORY_STRUCT_TOOL_GRAPH = {
    "type": "function",
    "function": {
        "name": "add_graph_memory",
        "description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "source": {
                    "type": "string",
                    "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
                },
                "destination": {
                    "type": "string",
                    "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
                },
                "relationship": {
                    "type": "string",
                    "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
                },
                # Node categories help classify/organize nodes in the graph.
                "source_type": {
                    "type": "string",
                    "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
                },
                "destination_type": {
                    "type": "string",
                    "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
                },
            },
            "required": ["source", "destination", "relationship", "source_type", "destination_type"],
            "additionalProperties": False,
        },
    },
}
# Strict-mode (structured-output) variant of NOOP_TOOL: an explicit
# "do nothing" action the model can select when no graph change is needed.
NOOP_STRUCT_TOOL = {
    "type": "function",
    "function": {
        "name": "noop",
        "description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
        "strict": True,
        # No arguments: an empty, closed object schema.
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
            "additionalProperties": False,
        },
    },
}
# Strict-mode (structured-output) variant of ADD_MESSAGE_TOOL: extracts a list
# of (source, relation, destination) entity tuples from the user query.
ADD_MESSAGE_STRUCT_TOOL = {
    "type": "function",
    "function": {
        "name": "add_query",
        "description": "Add new entities and relationships to the graph based on the provided query.",
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "entities": {
                    "type": "array",
                    # Each item is one edge: typed source node, relation, typed destination node.
                    "items": {
                        "type": "object",
                        "properties": {
                            "source_node": {"type": "string"},
                            "source_type": {"type": "string"},
                            "relation": {"type": "string"},
                            "destination_node": {"type": "string"},
                            "destination_type": {"type": "string"},
                        },
                        "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
                        "additionalProperties": False,
                    },
                },
            },
            "required": ["entities"],
            "additionalProperties": False,
        },
    },
}
SEARCH_STRUCT_TOOL = {
"type": "function", "type": "function",
"function": { "function": {
"name": "search", "name": "search",
@@ -140,4 +278,4 @@ SEARCH_TOOL = {
"additionalProperties": False "additionalProperties": False
} }
} }
} }

View File

@@ -0,0 +1,96 @@
import os
import json
from typing import Dict, List, Optional
from openai import AzureOpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class AzureOpenAIStructuredLLM(LLMBase):
    """Azure OpenAI chat-completions backend with tool-calling support.

    Mirrors the OpenAI structured LLM wrapper but targets an Azure
    deployment; credentials/endpoint come from environment variables,
    falling back to the values carried in the config's ``azure_kwargs``.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4o-2024-08-06"

        # Environment variables take precedence over config-provided values.
        api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
        azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
        azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
        api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
        # NOTE(review): a warning could be emitted when api_version and the
        # chosen model/deployment are mismatched — TODO confirm desired behavior.

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
        )

    def _parse_response(self, response, tools):
        """
        Normalize an API completion into the shape callers expect.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request (or None).

        Returns:
            str when no tools were supplied; otherwise a dict with "content"
            and a "tool_calls" list of {"name", "arguments"} entries.
        """
        message = response.choices[0].message
        if not tools:
            return message.content

        return {
            "content": message.content,
            "tool_calls": [
                {
                    "name": call.function.name,
                    # Arguments arrive as a JSON string; decode for callers.
                    "arguments": json.loads(call.function.arguments),
                }
                for call in (message.tool_calls or [])
            ],
        }

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str or dict: The generated response (see ``_parse_response``).
        """
        request = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            request["response_format"] = response_format
        # tool_choice is only meaningful when tools are supplied.
        if tools:
            request["tools"] = tools
            request["tool_choice"] = tool_choice

        completion = self.client.chat.completions.create(**request)
        return self._parse_response(completion, tools)

View File

@@ -23,6 +23,7 @@ class LlmConfig(BaseModel):
"litellm", "litellm",
"azure_openai", "azure_openai",
"openai_structured", "openai_structured",
"azure_openai_structured"
): ):
return v return v
else: else:

View File

@@ -9,6 +9,11 @@ from mem0.graphs.tools import (
NOOP_TOOL, NOOP_TOOL,
SEARCH_TOOL, SEARCH_TOOL,
UPDATE_MEMORY_TOOL_GRAPH, UPDATE_MEMORY_TOOL_GRAPH,
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
ADD_MEMORY_STRUCT_TOOL_GRAPH,
NOOP_STRUCT_TOOL,
ADD_MESSAGE_STRUCT_TOOL,
SEARCH_STRUCT_TOOL
) )
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory from mem0.utils.factory import EmbedderFactory, LlmFactory
@@ -23,14 +28,13 @@ class MemoryGraph:
self.config.embedder.provider, self.config.embedder.config self.config.embedder.provider, self.config.embedder.config
) )
self.llm_provider = "openai_structured"
if self.config.llm.provider: if self.config.llm.provider:
llm_provider = self.config.llm.provider self.llm_provider = self.config.llm.provider
if self.config.graph_store.llm: if self.config.graph_store.llm:
llm_provider = self.config.graph_store.llm.provider self.llm_provider = self.config.graph_store.llm.provider
else:
llm_provider = "openai_structured"
self.llm = LlmFactory.create(llm_provider, self.config.llm.config) self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
self.user_id = None self.user_id = None
self.threshold = 0.7 self.threshold = 0.7
@@ -60,9 +64,13 @@ class MemoryGraph:
{"role": "user", "content": data}, {"role": "user", "content": data},
] ]
_tools = [ADD_MESSAGE_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [ADD_MESSAGE_STRUCT_TOOL]
extracted_entities = self.llm.generate_response( extracted_entities = self.llm.generate_response(
messages=messages, messages=messages,
tools = [ADD_MESSAGE_TOOL], tools = _tools,
) )
if extracted_entities['tool_calls']: if extracted_entities['tool_calls']:
@@ -74,9 +82,13 @@ class MemoryGraph:
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
_tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
if self.llm_provider in ["azure_openai_structured","openai_structured"]:
_tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
memory_updates = self.llm.generate_response( memory_updates = self.llm.generate_response(
messages=update_memory_prompt, messages=update_memory_prompt,
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL], tools=_tools,
) )
to_be_added = [] to_be_added = []
@@ -127,12 +139,15 @@ class MemoryGraph:
def _search(self, query, filters): def _search(self, query, filters):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
search_results = self.llm.generate_response( search_results = self.llm.generate_response(
messages=[ messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."}, {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
{"role": "user", "content": query}, {"role": "user", "content": query},
], ],
tools = [SEARCH_TOOL] tools = _tools
) )
node_list = [] node_list = []

View File

@@ -20,7 +20,8 @@ class LlmFactory:
"litellm": "mem0.llms.litellm.LiteLLM", "litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM", "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM", "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM" "anthropic": "mem0.llms.anthropic.AnthropicLLM",
"azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM"
} }
@classmethod @classmethod