[improvement]: Graph memory support for non-structured models. (#1823)
This commit is contained in:
@@ -4,7 +4,6 @@ UPDATE_MEMORY_TOOL_GRAPH = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "update_graph_memory",
|
"name": "update_graph_memory",
|
||||||
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
|
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -32,7 +31,6 @@ ADD_MEMORY_TOOL_GRAPH = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "add_graph_memory",
|
"name": "add_graph_memory",
|
||||||
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
|
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -69,7 +67,6 @@ NOOP_TOOL = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "noop",
|
"name": "noop",
|
||||||
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
|
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {},
|
"properties": {},
|
||||||
@@ -85,7 +82,6 @@ ADD_MESSAGE_TOOL = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "add_query",
|
"name": "add_query",
|
||||||
"description": "Add new entities and relationships to the graph based on the provided query.",
|
"description": "Add new entities and relationships to the graph based on the provided query.",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -113,6 +109,148 @@ ADD_MESSAGE_TOOL = {
|
|||||||
|
|
||||||
|
|
||||||
SEARCH_TOOL = {
|
SEARCH_TOOL = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "search",
|
||||||
|
"description": "Search for nodes and relations in the graph.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nodes": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "List of nodes to search for."
|
||||||
|
},
|
||||||
|
"relations": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "List of relations to search for."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["nodes", "relations"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "update_graph_memory",
|
||||||
|
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
|
||||||
|
"strict": True,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"source": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph."
|
||||||
|
},
|
||||||
|
"destination": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph."
|
||||||
|
},
|
||||||
|
"relationship": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["source", "destination", "relationship"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ADD_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "add_graph_memory",
|
||||||
|
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
|
||||||
|
"strict": True,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"source": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created."
|
||||||
|
},
|
||||||
|
"destination": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created."
|
||||||
|
},
|
||||||
|
"relationship": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||||
|
},
|
||||||
|
"source_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph."
|
||||||
|
},
|
||||||
|
"destination_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["source", "destination", "relationship", "source_type", "destination_type"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NOOP_STRUCT_TOOL = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "noop",
|
||||||
|
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
|
||||||
|
"strict": True,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ADD_MESSAGE_STRUCT_TOOL = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "add_query",
|
||||||
|
"description": "Add new entities and relationships to the graph based on the provided query.",
|
||||||
|
"strict": True,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"entities": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"source_node": {"type": "string"},
|
||||||
|
"source_type": {"type": "string"},
|
||||||
|
"relation": {"type": "string"},
|
||||||
|
"destination_node": {"type": "string"},
|
||||||
|
"destination_type": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["entities"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SEARCH_STRUCT_TOOL = {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "search",
|
"name": "search",
|
||||||
@@ -140,4 +278,4 @@ SEARCH_TOOL = {
|
|||||||
"additionalProperties": False
|
"additionalProperties": False
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
96
mem0/llms/azure_openai_structured.py
Normal file
96
mem0/llms/azure_openai_structured.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIStructuredLLM(LLMBase):
|
||||||
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Model name should match the custom deployment name chosen for it.
|
||||||
|
if not self.config.model:
|
||||||
|
self.config.model = "gpt-4o-2024-08-06"
|
||||||
|
|
||||||
|
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
|
||||||
|
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
|
||||||
|
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
|
||||||
|
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
|
||||||
|
# Can display a warning if API version is of model and api-version
|
||||||
|
|
||||||
|
self.client = AzureOpenAI(
|
||||||
|
azure_deployment=azure_deployment,
|
||||||
|
azure_endpoint=azure_endpoint,
|
||||||
|
api_version=api_version,
|
||||||
|
api_key=api_key,
|
||||||
|
http_client=self.config.http_client
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_response(self, response, tools):
|
||||||
|
"""
|
||||||
|
Process the response based on whether tools are used or not.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The raw response from API.
|
||||||
|
tools: The list of tools provided in the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or dict: The processed response.
|
||||||
|
"""
|
||||||
|
if tools:
|
||||||
|
processed_response = {
|
||||||
|
"content": response.choices[0].message.content,
|
||||||
|
"tool_calls": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.choices[0].message.tool_calls:
|
||||||
|
for tool_call in response.choices[0].message.tool_calls:
|
||||||
|
processed_response["tool_calls"].append(
|
||||||
|
{
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": json.loads(tool_call.function.arguments),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_response
|
||||||
|
else:
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
def generate_response(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
response_format=None,
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
tool_choice: str = "auto",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a response based on the given messages using Azure OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list): List of message dicts containing 'role' and 'content'.
|
||||||
|
response_format (str or object, optional): Format of the response. Defaults to "text".
|
||||||
|
tools (list, optional): List of tools that the model can call. Defaults to None.
|
||||||
|
tool_choice (str, optional): Tool choice method. Defaults to "auto".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated response.
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": self.config.temperature,
|
||||||
|
"max_tokens": self.config.max_tokens,
|
||||||
|
"top_p": self.config.top_p,
|
||||||
|
}
|
||||||
|
if response_format:
|
||||||
|
params["response_format"] = response_format
|
||||||
|
if tools:
|
||||||
|
params["tools"] = tools
|
||||||
|
params["tool_choice"] = tool_choice
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(**params)
|
||||||
|
return self._parse_response(response, tools)
|
||||||
@@ -23,6 +23,7 @@ class LlmConfig(BaseModel):
|
|||||||
"litellm",
|
"litellm",
|
||||||
"azure_openai",
|
"azure_openai",
|
||||||
"openai_structured",
|
"openai_structured",
|
||||||
|
"azure_openai_structured"
|
||||||
):
|
):
|
||||||
return v
|
return v
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ from mem0.graphs.tools import (
|
|||||||
NOOP_TOOL,
|
NOOP_TOOL,
|
||||||
SEARCH_TOOL,
|
SEARCH_TOOL,
|
||||||
UPDATE_MEMORY_TOOL_GRAPH,
|
UPDATE_MEMORY_TOOL_GRAPH,
|
||||||
|
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||||
|
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||||
|
NOOP_STRUCT_TOOL,
|
||||||
|
ADD_MESSAGE_STRUCT_TOOL,
|
||||||
|
SEARCH_STRUCT_TOOL
|
||||||
)
|
)
|
||||||
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
||||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||||
@@ -23,14 +28,13 @@ class MemoryGraph:
|
|||||||
self.config.embedder.provider, self.config.embedder.config
|
self.config.embedder.provider, self.config.embedder.config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.llm_provider = "openai_structured"
|
||||||
if self.config.llm.provider:
|
if self.config.llm.provider:
|
||||||
llm_provider = self.config.llm.provider
|
self.llm_provider = self.config.llm.provider
|
||||||
if self.config.graph_store.llm:
|
if self.config.graph_store.llm:
|
||||||
llm_provider = self.config.graph_store.llm.provider
|
self.llm_provider = self.config.graph_store.llm.provider
|
||||||
else:
|
|
||||||
llm_provider = "openai_structured"
|
|
||||||
|
|
||||||
self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
|
self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
|
||||||
self.user_id = None
|
self.user_id = None
|
||||||
self.threshold = 0.7
|
self.threshold = 0.7
|
||||||
|
|
||||||
@@ -60,9 +64,13 @@ class MemoryGraph:
|
|||||||
{"role": "user", "content": data},
|
{"role": "user", "content": data},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_tools = [ADD_MESSAGE_TOOL]
|
||||||
|
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||||
|
_tools = [ADD_MESSAGE_STRUCT_TOOL]
|
||||||
|
|
||||||
extracted_entities = self.llm.generate_response(
|
extracted_entities = self.llm.generate_response(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools = [ADD_MESSAGE_TOOL],
|
tools = _tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
if extracted_entities['tool_calls']:
|
if extracted_entities['tool_calls']:
|
||||||
@@ -74,9 +82,13 @@ class MemoryGraph:
|
|||||||
|
|
||||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||||
|
|
||||||
|
_tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||||
|
if self.llm_provider in ["azure_openai_structured","openai_structured"]:
|
||||||
|
_tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
|
||||||
|
|
||||||
memory_updates = self.llm.generate_response(
|
memory_updates = self.llm.generate_response(
|
||||||
messages=update_memory_prompt,
|
messages=update_memory_prompt,
|
||||||
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
|
tools=_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
to_be_added = []
|
to_be_added = []
|
||||||
@@ -127,12 +139,15 @@ class MemoryGraph:
|
|||||||
|
|
||||||
|
|
||||||
def _search(self, query, filters):
|
def _search(self, query, filters):
|
||||||
|
_tools = [SEARCH_TOOL]
|
||||||
|
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||||
|
_tools = [SEARCH_STRUCT_TOOL]
|
||||||
search_results = self.llm.generate_response(
|
search_results = self.llm.generate_response(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
||||||
{"role": "user", "content": query},
|
{"role": "user", "content": query},
|
||||||
],
|
],
|
||||||
tools = [SEARCH_TOOL]
|
tools = _tools
|
||||||
)
|
)
|
||||||
|
|
||||||
node_list = []
|
node_list = []
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ class LlmFactory:
|
|||||||
"litellm": "mem0.llms.litellm.LiteLLM",
|
"litellm": "mem0.llms.litellm.LiteLLM",
|
||||||
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
|
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
|
||||||
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
|
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
|
||||||
"anthropic": "mem0.llms.anthropic.AnthropicLLM"
|
"anthropic": "mem0.llms.anthropic.AnthropicLLM",
|
||||||
|
"azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM"
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user