Fix: Add Google Genai library support (#2941)
This commit is contained in:
@@ -4,7 +4,11 @@ title: Gemini
|
|||||||
|
|
||||||
<Snippet file="paper-release.mdx" />
|
<Snippet file="paper-release.mdx" />
|
||||||
|
|
||||||
To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from the [Google AI Studio](https://aistudio.google.com/app/apikey)
|
To use the Gemini model, set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from [Google AI Studio](https://aistudio.google.com/app/apikey).
|
||||||
|
|
||||||
|
> **Note:** As of the latest release, Mem0 uses the new `google.genai` SDK instead of the deprecated `google.generativeai`. All message formatting and model interaction now use the updated `types` module from `google.genai`.
|
||||||
|
|
||||||
|
> **Note:** Some Gemini models are being deprecated and will retire soon. It is recommended to migrate to the latest stable models like `"gemini-2.0-flash-001"` or `"gemini-2.0-flash-lite-001"` to ensure ongoing support and improvements.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
@@ -12,28 +16,32 @@ To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable.
|
|||||||
import os
|
import os
|
||||||
from mem0 import Memory
|
from mem0 import Memory
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
|
os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Used for embedding model
|
||||||
os.environ["GEMINI_API_KEY"] = "your-api-key"
|
os.environ["GEMINI_API_KEY"] = "your-gemini-api-key"
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "gemini",
|
"provider": "gemini",
|
||||||
"config": {
|
"config": {
|
||||||
"model": "gemini-1.5-flash-latest",
|
"model": "gemini-2.0-flash-001",
|
||||||
"temperature": 0.2,
|
"temperature": 0.2,
|
||||||
"max_tokens": 2000,
|
"max_tokens": 2000,
|
||||||
|
"top_p": 1.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m = Memory.from_config(config)
|
m = Memory.from_config(config)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||||
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
{"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
|
||||||
{"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
|
{"role": "user", "content": "I’m not a big fan of thrillers, but I love sci-fi movies."},
|
||||||
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
{"role": "assistant", "content": "Got it! I'll avoid thrillers and suggest sci-fi movies instead."}
|
||||||
]
|
]
|
||||||
|
|
||||||
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Config
|
## Config
|
||||||
|
|||||||
@@ -238,16 +238,24 @@ The Mem0's graph supports the following operations:
|
|||||||
### Add Memories
|
### Add Memories
|
||||||
|
|
||||||
<Note>
|
<Note>
|
||||||
If you are using Mem0 with Graph Memory, it is recommended to pass `user_id`. Use `userId` in NodeSDK.
|
Mem0 with Graph Memory supports both `user_id` and `agent_id` parameters. You can use either or both to organize your memories. Use `userId` and `agentId` in NodeSDK.
|
||||||
</Note>
|
</Note>
|
||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
```python Python
|
```python Python
|
||||||
|
# Using only user_id
|
||||||
m.add("I like pizza", user_id="alice")
|
m.add("I like pizza", user_id="alice")
|
||||||
|
|
||||||
|
# Using both user_id and agent_id
|
||||||
|
m.add("I like pizza", user_id="alice", agent_id="food-assistant")
|
||||||
```
|
```
|
||||||
|
|
||||||
```typescript TypeScript
|
```typescript TypeScript
|
||||||
|
// Using only userId
|
||||||
memory.add("I like pizza", { userId: "alice" });
|
memory.add("I like pizza", { userId: "alice" });
|
||||||
|
|
||||||
|
// Using both userId and agentId
|
||||||
|
memory.add("I like pizza", { userId: "alice", agentId: "food-assistant" });
|
||||||
```
|
```
|
||||||
|
|
||||||
```json Output
|
```json Output
|
||||||
@@ -260,11 +268,19 @@ memory.add("I like pizza", { userId: "alice" });
|
|||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
```python Python
|
```python Python
|
||||||
|
# Get all memories for a user
|
||||||
m.get_all(user_id="alice")
|
m.get_all(user_id="alice")
|
||||||
|
|
||||||
|
# Get all memories for a specific agent belonging to a user
|
||||||
|
m.get_all(user_id="alice", agent_id="food-assistant")
|
||||||
```
|
```
|
||||||
|
|
||||||
```typescript TypeScript
|
```typescript TypeScript
|
||||||
|
// Get all memories for a user
|
||||||
memory.getAll({ userId: "alice" });
|
memory.getAll({ userId: "alice" });
|
||||||
|
|
||||||
|
// Get all memories for a specific agent belonging to a user
|
||||||
|
memory.getAll({ userId: "alice", agentId: "food-assistant" });
|
||||||
```
|
```
|
||||||
|
|
||||||
```json Output
|
```json Output
|
||||||
@@ -277,7 +293,8 @@ memory.getAll({ userId: "alice" });
|
|||||||
'metadata': None,
|
'metadata': None,
|
||||||
'created_at': '2024-08-20T14:09:27.588719-07:00',
|
'created_at': '2024-08-20T14:09:27.588719-07:00',
|
||||||
'updated_at': None,
|
'updated_at': None,
|
||||||
'user_id': 'alice'
|
'user_id': 'alice',
|
||||||
|
'agent_id': 'food-assistant'
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'entities': [
|
'entities': [
|
||||||
@@ -295,11 +312,19 @@ memory.getAll({ userId: "alice" });
|
|||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
```python Python
|
```python Python
|
||||||
|
# Search memories for a user
|
||||||
m.search("tell me my name.", user_id="alice")
|
m.search("tell me my name.", user_id="alice")
|
||||||
|
|
||||||
|
# Search memories for a specific agent belonging to a user
|
||||||
|
m.search("tell me my name.", user_id="alice", agent_id="food-assistant")
|
||||||
```
|
```
|
||||||
|
|
||||||
```typescript TypeScript
|
```typescript TypeScript
|
||||||
|
// Search memories for a user
|
||||||
memory.search("tell me my name.", { userId: "alice" });
|
memory.search("tell me my name.", { userId: "alice" });
|
||||||
|
|
||||||
|
// Search memories for a specific agent belonging to a user
|
||||||
|
memory.search("tell me my name.", { userId: "alice", agentId: "food-assistant" });
|
||||||
```
|
```
|
||||||
|
|
||||||
```json Output
|
```json Output
|
||||||
@@ -312,7 +337,8 @@ memory.search("tell me my name.", { userId: "alice" });
|
|||||||
'metadata': None,
|
'metadata': None,
|
||||||
'created_at': '2024-08-20T14:09:27.588719-07:00',
|
'created_at': '2024-08-20T14:09:27.588719-07:00',
|
||||||
'updated_at': None,
|
'updated_at': None,
|
||||||
'user_id': 'alice'
|
'user_id': 'alice',
|
||||||
|
'agent_id': 'food-assistant'
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'entities': [
|
'entities': [
|
||||||
@@ -331,11 +357,19 @@ memory.search("tell me my name.", { userId: "alice" });
|
|||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
```python Python
|
```python Python
|
||||||
|
# Delete all memories for a user
|
||||||
m.delete_all(user_id="alice")
|
m.delete_all(user_id="alice")
|
||||||
|
|
||||||
|
# Delete all memories for a specific agent belonging to a user
|
||||||
|
m.delete_all(user_id="alice", agent_id="food-assistant")
|
||||||
```
|
```
|
||||||
|
|
||||||
```typescript TypeScript
|
```typescript TypeScript
|
||||||
|
// Delete all memories for a user
|
||||||
memory.deleteAll({ userId: "alice" });
|
memory.deleteAll({ userId: "alice" });
|
||||||
|
|
||||||
|
// Delete all memories for a specific agent belonging to a user
|
||||||
|
memory.deleteAll({ userId: "alice", agentId: "food-assistant" });
|
||||||
```
|
```
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
@@ -516,6 +550,42 @@ memory.search("Who is spiderman?", { userId: "alice123" });
|
|||||||
|
|
||||||
> **Note:** The Graph Memory implementation is not standalone. You will be adding/retrieving memories to the vector store and the graph store simultaneously.
|
> **Note:** The Graph Memory implementation is not standalone. You will be adding/retrieving memories to the vector store and the graph store simultaneously.
|
||||||
|
|
||||||
|
## Using Multiple Agents with Graph Memory
|
||||||
|
|
||||||
|
When working with multiple agents, you can use the `agent_id` parameter to organize memories by both user and agent. This allows you to:
|
||||||
|
|
||||||
|
1. Create agent-specific knowledge graphs
|
||||||
|
2. Share common knowledge between agents
|
||||||
|
3. Isolate sensitive or specialized information to specific agents
|
||||||
|
|
||||||
|
### Example: Multi-Agent Setup
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
```python Python
|
||||||
|
# Add memories for different agents
|
||||||
|
m.add("I prefer Italian cuisine", user_id="bob", agent_id="food-assistant")
|
||||||
|
m.add("I'm allergic to peanuts", user_id="bob", agent_id="health-assistant")
|
||||||
|
m.add("I live in Seattle", user_id="bob") # Shared across all agents
|
||||||
|
|
||||||
|
# Search within specific agent context
|
||||||
|
food_preferences = m.search("What food do I like?", user_id="bob", agent_id="food-assistant")
|
||||||
|
health_info = m.search("What are my allergies?", user_id="bob", agent_id="health-assistant")
|
||||||
|
location = m.search("Where do I live?", user_id="bob") # Searches across all agents
|
||||||
|
```
|
||||||
|
|
||||||
|
```typescript TypeScript
|
||||||
|
// Add memories for different agents
|
||||||
|
memory.add("I prefer Italian cuisine", { userId: "bob", agentId: "food-assistant" });
|
||||||
|
memory.add("I'm allergic to peanuts", { userId: "bob", agentId: "health-assistant" });
|
||||||
|
memory.add("I live in Seattle", { userId: "bob" }); // Shared across all agents
|
||||||
|
|
||||||
|
// Search within specific agent context
|
||||||
|
const foodPreferences = memory.search("What food do I like?", { userId: "bob", agentId: "food-assistant" });
|
||||||
|
const healthInfo = memory.search("What are my allergies?", { userId: "bob", agentId: "health-assistant" });
|
||||||
|
const location = memory.search("Where do I live?", { userId: "bob" }); // Searches across all agents
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:
|
If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:
|
||||||
|
|
||||||
<Snippet file="get-help.mdx" />
|
<Snippet file="get-help.mdx" />
|
||||||
|
|||||||
4
embedchain/poetry.lock
generated
4
embedchain/poetry.lock
generated
@@ -2552,7 +2552,7 @@ azure = ["adlfs (>=2024.2.0)"]
|
|||||||
clip = ["open-clip", "pillow", "torch"]
|
clip = ["open-clip", "pillow", "torch"]
|
||||||
dev = ["pre-commit", "ruff"]
|
dev = ["pre-commit", "ruff"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"]
|
embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch", "google-genai"]
|
||||||
tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
|
tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -7129,7 +7129,7 @@ cffi = ["cffi (>=1.11)"]
|
|||||||
aws = ["langchain-aws"]
|
aws = ["langchain-aws"]
|
||||||
elasticsearch = ["elasticsearch"]
|
elasticsearch = ["elasticsearch"]
|
||||||
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
|
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
|
||||||
google = ["google-generativeai"]
|
google = ["google-generativeai", "google-genai"]
|
||||||
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
|
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
|
||||||
lancedb = ["lancedb"]
|
lancedb = ["lancedb"]
|
||||||
llama2 = ["replicate"]
|
llama2 = ["replicate"]
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ import os
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import google.generativeai as genai
|
from google import genai
|
||||||
from google.generativeai import GenerativeModel, protos
|
from google.genai import types
|
||||||
from google.generativeai.types import content_types
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
|
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
|
||||||
@@ -22,66 +22,71 @@ class GeminiLLM(LLMBase):
|
|||||||
self.config.model = "gemini-1.5-flash-latest"
|
self.config.model = "gemini-1.5-flash-latest"
|
||||||
|
|
||||||
api_key = self.config.api_key or os.getenv("GEMINI_API_KEY")
|
api_key = self.config.api_key or os.getenv("GEMINI_API_KEY")
|
||||||
genai.configure(api_key=api_key)
|
self.client_gemini = genai.Client(
|
||||||
self.client = GenerativeModel(model_name=self.config.model)
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_response(self, response, tools):
|
def _parse_response(self, response, tools):
|
||||||
"""
|
"""
|
||||||
Process the response based on whether tools are used or not.
|
Process the response based on whether tools are used or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The raw response from API.
|
response: The raw response from the API.
|
||||||
tools: The list of tools provided in the request.
|
tools: The list of tools provided in the request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str or dict: The processed response.
|
str or dict: The processed response.
|
||||||
"""
|
"""
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
content = candidate.content.parts[0].text if candidate.content.parts else None
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
processed_response = {
|
processed_response = {
|
||||||
"content": (content if (content := response.candidates[0].content.parts[0].text) else None),
|
"content": content,
|
||||||
"tool_calls": [],
|
"tool_calls": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for part in response.candidates[0].content.parts:
|
for part in candidate.content.parts:
|
||||||
if fn := part.function_call:
|
fn = getattr(part, "function_call", None)
|
||||||
if isinstance(fn, protos.FunctionCall):
|
if fn:
|
||||||
fn_call = type(fn).to_dict(fn)
|
processed_response["tool_calls"].append({
|
||||||
processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
|
"name": fn.name,
|
||||||
continue
|
"arguments": fn.args,
|
||||||
processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
|
})
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
else:
|
|
||||||
return response.candidates[0].content.parts[0].text
|
|
||||||
|
|
||||||
def _reformat_messages(self, messages: List[Dict[str, str]]):
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
|
||||||
"""
|
"""
|
||||||
Reformat messages for Gemini.
|
Reformat messages for Gemini using google.genai.types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: The list of messages provided in the request.
|
messages: The list of messages provided in the request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: The list of messages in the required format.
|
list: A list of types.Content objects with proper role and parts.
|
||||||
"""
|
"""
|
||||||
new_messages = []
|
new_messages = []
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message["role"] == "system":
|
if message["role"] == "system":
|
||||||
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
|
|
||||||
new_messages.append(
|
new_messages.append(
|
||||||
{
|
types.Content(
|
||||||
"parts": content,
|
role="model" if message["role"] == "model" else "user",
|
||||||
"role": "model" if message["role"] == "model" else "user",
|
parts=[types.Part(text=content)]
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
|
|
||||||
def _reformat_tools(self, tools: Optional[List[Dict]]):
|
def _reformat_tools(self, tools: Optional[List[Dict]]):
|
||||||
"""
|
"""
|
||||||
Reformat tools for Gemini.
|
Reformat tools for Gemini.
|
||||||
@@ -126,6 +131,7 @@ class GeminiLLM(LLMBase):
|
|||||||
tools: Optional[List[Dict]] = None,
|
tools: Optional[List[Dict]] = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
):
|
):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Generate a response based on the given messages using Gemini.
|
Generate a response based on the given messages using Gemini.
|
||||||
|
|
||||||
@@ -149,23 +155,37 @@ class GeminiLLM(LLMBase):
|
|||||||
params["response_mime_type"] = "application/json"
|
params["response_mime_type"] = "application/json"
|
||||||
if "schema" in response_format:
|
if "schema" in response_format:
|
||||||
params["response_schema"] = response_format["schema"]
|
params["response_schema"] = response_format["schema"]
|
||||||
|
|
||||||
|
tool_config = None
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
tool_config = content_types.to_tool_config(
|
tool_config = types.ToolConfig(
|
||||||
{
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
"function_calling_config": {
|
mode=tool_choice.upper(), # Assuming 'any' should become 'ANY', etc.
|
||||||
"mode": tool_choice,
|
allowed_function_names=[
|
||||||
"allowed_function_names": (
|
tool["function"]["name"] for tool in tools
|
||||||
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
|
] if tool_choice == "any" else None
|
||||||
),
|
)
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.client.generate_content(
|
print(f"Tool config: {tool_config}")
|
||||||
contents=self._reformat_messages(messages),
|
print(f"Params: {params}" )
|
||||||
tools=self._reformat_tools(tools),
|
print(f"Messages: {messages}")
|
||||||
generation_config=genai.GenerationConfig(**params),
|
print(f"Tools: {tools}")
|
||||||
tool_config=tool_config,
|
print(f"Reformatted messages: {self._reformat_messages(messages)}")
|
||||||
)
|
print(f"Reformatted tools: {self._reformat_tools(tools)}")
|
||||||
|
|
||||||
|
response = self.client_gemini.models.generate_content(
|
||||||
|
model=self.config.model,
|
||||||
|
contents=self._reformat_messages(messages),
|
||||||
|
config=types.GenerateContentConfig(
|
||||||
|
temperature= self.config.temperature,
|
||||||
|
max_output_tokens= self.config.max_tokens,
|
||||||
|
top_p= self.config.top_p,
|
||||||
|
tools=self._reformat_tools(tools),
|
||||||
|
tool_config=tool_config,
|
||||||
|
|
||||||
|
),
|
||||||
|
)
|
||||||
|
print(f"Response test: {response}")
|
||||||
|
|
||||||
return self._parse_response(response, tools)
|
return self._parse_response(response, tools)
|
||||||
|
|||||||
@@ -80,8 +80,8 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# TODO: Batch queries with APOC plugin
|
# TODO: Batch queries with APOC plugin
|
||||||
# TODO: Add more filter support
|
# TODO: Add more filter support
|
||||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
deleted_entities = self._delete_entities(to_be_deleted, filters)
|
||||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
added_entities = self._add_entities(to_be_added, filters, entity_type_map)
|
||||||
|
|
||||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||||
|
|
||||||
@@ -122,32 +122,35 @@ class MemoryGraph:
|
|||||||
return search_results
|
return search_results
|
||||||
|
|
||||||
def delete_all(self, filters):
|
def delete_all(self, filters):
|
||||||
cypher = f"""
|
if filters.get("agent_id"):
|
||||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
cypher = f"""
|
||||||
DETACH DELETE n
|
MATCH (n {self.node_label} {{user_id: $user_id, agent_id: $agent_id}})
|
||||||
"""
|
DETACH DELETE n
|
||||||
params = {"user_id": filters["user_id"]}
|
"""
|
||||||
|
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
|
||||||
|
else:
|
||||||
|
cypher = f"""
|
||||||
|
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||||
|
DETACH DELETE n
|
||||||
|
"""
|
||||||
|
params = {"user_id": filters["user_id"]}
|
||||||
self.graph.query(cypher, params=params)
|
self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
def get_all(self, filters, limit=100):
|
|
||||||
"""
|
|
||||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
|
||||||
|
|
||||||
Args:
|
def get_all(self, filters, limit=100):
|
||||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
agent_filter = ""
|
||||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
params = {"user_id": filters["user_id"], "limit": limit}
|
||||||
Returns:
|
if filters.get("agent_id"):
|
||||||
list: A list of dictionaries, each containing:
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
- 'contexts': The base data store response for each memory.
|
params["agent_id"] = filters["agent_id"]
|
||||||
- 'entities': A list of strings representing the nodes and relationships
|
|
||||||
"""
|
|
||||||
# return all nodes and relationships
|
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||||
|
WHERE 1=1 {agent_filter}
|
||||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
results = self.graph.query(query, params=params)
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -163,6 +166,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_nodes_from_data(self, data, filters):
|
def _retrieve_nodes_from_data(self, data, filters):
|
||||||
"""Extracts all the entities mentioned in the query."""
|
"""Extracts all the entities mentioned in the query."""
|
||||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||||
@@ -197,23 +201,27 @@ class MemoryGraph:
|
|||||||
return entity_type_map
|
return entity_type_map
|
||||||
|
|
||||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||||
"""Eshtablish relations among the extracted nodes."""
|
"""Establish relations among the extracted nodes."""
|
||||||
|
|
||||||
|
# Compose user identification string for prompt
|
||||||
|
user_identity = f"user_id: {filters['user_id']}"
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
user_identity += f", agent_id: {filters['agent_id']}"
|
||||||
|
|
||||||
if self.config.graph_store.custom_prompt:
|
if self.config.graph_store.custom_prompt:
|
||||||
|
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||||
|
# Add the custom prompt line if configured
|
||||||
|
system_content = system_content.replace(
|
||||||
|
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||||
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": system_content},
|
||||||
"role": "system",
|
|
||||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
|
|
||||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": data},
|
{"role": "user", "content": data},
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "system", "content": system_content},
|
||||||
"role": "system",
|
|
||||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
|
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -227,8 +235,8 @@ class MemoryGraph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
entities = []
|
entities = []
|
||||||
if extracted_entities["tool_calls"]:
|
if extracted_entities.get("tool_calls"):
|
||||||
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])
|
||||||
|
|
||||||
entities = self._remove_spaces_from_entities(entities)
|
entities = self._remove_spaces_from_entities(entities)
|
||||||
logger.debug(f"Extracted entities: {entities}")
|
logger.debug(f"Extracted entities: {entities}")
|
||||||
@@ -237,32 +245,43 @@ class MemoryGraph:
|
|||||||
def _search_graph_db(self, node_list, filters, limit=100):
|
def _search_graph_db(self, node_list, filters, limit=100):
|
||||||
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
||||||
result_relations = []
|
result_relations = []
|
||||||
|
agent_filter = ""
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
|
|
||||||
for node in node_list:
|
for node in node_list:
|
||||||
n_embedding = self.embedding_model.embed(node)
|
n_embedding = self.embedding_model.embed(node)
|
||||||
|
|
||||||
cypher_query = f"""
|
cypher_query = f"""
|
||||||
MATCH (n {self.node_label})
|
MATCH (n {self.node_label})
|
||||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||||
|
{agent_filter}
|
||||||
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
||||||
WHERE similarity >= $threshold
|
WHERE similarity >= $threshold
|
||||||
CALL (n) {{
|
CALL {{
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n)-[r]->(m)
|
||||||
|
WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
|
||||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
||||||
UNION
|
UNION
|
||||||
MATCH (m)-[r]->(n)
|
MATCH (m)-[r]->(n)
|
||||||
|
WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
|
||||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
||||||
}}
|
}}
|
||||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
|
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||||
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||||
ORDER BY similarity DESC
|
ORDER BY similarity DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"n_embedding": n_embedding,
|
"n_embedding": n_embedding,
|
||||||
"threshold": self.threshold,
|
"threshold": self.threshold,
|
||||||
"user_id": filters["user_id"],
|
"user_id": filters["user_id"],
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
}
|
}
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
params["agent_id"] = filters["agent_id"]
|
||||||
|
|
||||||
ans = self.graph.query(cypher_query, params=params)
|
ans = self.graph.query(cypher_query, params=params)
|
||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
@@ -271,7 +290,13 @@ class MemoryGraph:
|
|||||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||||
"""Get the entities to be deleted from the search output."""
|
"""Get the entities to be deleted from the search output."""
|
||||||
search_output_string = format_entities(search_output)
|
search_output_string = format_entities(search_output)
|
||||||
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
|
|
||||||
|
# Compose user identification string for prompt
|
||||||
|
user_identity = f"user_id: {filters['user_id']}"
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
user_identity += f", agent_id: {filters['agent_id']}"
|
||||||
|
|
||||||
|
system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)
|
||||||
|
|
||||||
_tools = [DELETE_MEMORY_TOOL_GRAPH]
|
_tools = [DELETE_MEMORY_TOOL_GRAPH]
|
||||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||||
@@ -288,44 +313,59 @@ class MemoryGraph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
to_be_deleted = []
|
to_be_deleted = []
|
||||||
for item in memory_updates["tool_calls"]:
|
for item in memory_updates.get("tool_calls", []):
|
||||||
if item["name"] == "delete_graph_memory":
|
if item.get("name") == "delete_graph_memory":
|
||||||
to_be_deleted.append(item["arguments"])
|
to_be_deleted.append(item.get("arguments"))
|
||||||
# in case if it is not in the correct format
|
# Clean entities formatting
|
||||||
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
||||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||||
return to_be_deleted
|
return to_be_deleted
|
||||||
|
|
||||||
def _delete_entities(self, to_be_deleted, user_id):
|
def _delete_entities(self, to_be_deleted, filters):
|
||||||
"""Delete the entities from the graph."""
|
"""Delete the entities from the graph."""
|
||||||
|
user_id = filters["user_id"]
|
||||||
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_deleted:
|
for item in to_be_deleted:
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
relationship = item["relationship"]
|
relationship = item["relationship"]
|
||||||
|
|
||||||
|
# Build the agent filter for the query
|
||||||
|
agent_filter = ""
|
||||||
|
params = {
|
||||||
|
"source_name": source,
|
||||||
|
"dest_name": destination,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if agent_id:
|
||||||
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
# Delete the specific relationship between nodes
|
# Delete the specific relationship between nodes
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||||
-[r:{relationship}]->
|
-[r:{relationship}]->
|
||||||
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||||
|
WHERE 1=1 {agent_filter}
|
||||||
DELETE r
|
DELETE r
|
||||||
RETURN
|
RETURN
|
||||||
n.name AS source,
|
n.name AS source,
|
||||||
m.name AS target,
|
m.name AS target,
|
||||||
type(r) AS relationship
|
type(r) AS relationship
|
||||||
"""
|
"""
|
||||||
params = {
|
|
||||||
"source_name": source,
|
|
||||||
"dest_name": destination,
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||||
|
user_id = filters["user_id"]
|
||||||
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
for item in to_be_added:
|
for item in to_be_added:
|
||||||
# entities
|
# entities
|
||||||
@@ -346,65 +386,80 @@ class MemoryGraph:
|
|||||||
dest_embedding = self.embedding_model.embed(destination)
|
dest_embedding = self.embedding_model.embed(destination)
|
||||||
|
|
||||||
# search for the nodes with the closest embeddings
|
# search for the nodes with the closest embeddings
|
||||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
|
||||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
|
||||||
|
|
||||||
# TODO: Create a cypher query and common params for all the cases
|
# TODO: Create a cypher query and common params for all the cases
|
||||||
if not destination_node_search_result and source_node_search_result:
|
if not destination_node_search_result and source_node_search_result:
|
||||||
cypher = f"""
|
# Build destination MERGE properties
|
||||||
MATCH (source)
|
merge_props = ["name: $destination_name", "user_id: $user_id"]
|
||||||
WHERE elementId(source) = $source_id
|
if agent_id:
|
||||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
merge_props.append("agent_id: $agent_id")
|
||||||
WITH source
|
merge_props_str = ", ".join(merge_props)
|
||||||
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
|
|
||||||
ON CREATE SET
|
|
||||||
destination.created = timestamp(),
|
|
||||||
destination.mentions = 1
|
|
||||||
{destination_extra_set}
|
|
||||||
ON MATCH SET
|
|
||||||
destination.mentions = coalesce(destination.mentions, 0) + 1
|
|
||||||
WITH source, destination
|
|
||||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
|
|
||||||
WITH source, destination
|
|
||||||
MERGE (source)-[r:{relationship}]->(destination)
|
|
||||||
ON CREATE SET
|
|
||||||
r.created = timestamp(),
|
|
||||||
r.mentions = 1
|
|
||||||
ON MATCH SET
|
|
||||||
r.mentions = coalesce(r.mentions, 0) + 1
|
|
||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
cypher = f"""
|
||||||
|
MATCH (source)
|
||||||
|
WHERE elementId(source) = $source_id
|
||||||
|
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
|
WITH source
|
||||||
|
MERGE (destination {destination_label} {{{merge_props_str}}})
|
||||||
|
ON CREATE SET
|
||||||
|
destination.created = timestamp(),
|
||||||
|
destination.mentions = 1
|
||||||
|
{destination_extra_set}
|
||||||
|
ON MATCH SET
|
||||||
|
destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
|
WITH source, destination
|
||||||
|
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
|
||||||
|
WITH source, destination
|
||||||
|
MERGE (source)-[r:{relationship}]->(destination)
|
||||||
|
ON CREATE SET
|
||||||
|
r.created = timestamp(),
|
||||||
|
r.mentions = 1
|
||||||
|
ON MATCH SET
|
||||||
|
r.mentions = coalesce(r.mentions, 0) + 1
|
||||||
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||||
"destination_name": destination,
|
"destination_name": destination,
|
||||||
"destination_embedding": dest_embedding,
|
"destination_embedding": dest_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif destination_node_search_result and not source_node_search_result:
|
elif destination_node_search_result and not source_node_search_result:
|
||||||
|
# Build source MERGE properties
|
||||||
|
merge_props = ["name: $source_name", "user_id: $user_id"]
|
||||||
|
if agent_id:
|
||||||
|
merge_props.append("agent_id: $agent_id")
|
||||||
|
merge_props_str = ", ".join(merge_props)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (destination)
|
MATCH (destination)
|
||||||
WHERE elementId(destination) = $destination_id
|
WHERE elementId(destination) = $destination_id
|
||||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
WITH destination
|
WITH destination
|
||||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
MERGE (source {source_label} {{{merge_props_str}}})
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
source.created = timestamp(),
|
source.created = timestamp(),
|
||||||
source.mentions = 1
|
source.mentions = 1
|
||||||
{source_extra_set}
|
{source_extra_set}
|
||||||
ON MATCH SET
|
ON MATCH SET
|
||||||
source.mentions = coalesce(source.mentions, 0) + 1
|
source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
MERGE (source)-[r:{relationship}]->(destination)
|
MERGE (source)-[r:{relationship}]->(destination)
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
r.created = timestamp(),
|
r.created = timestamp(),
|
||||||
r.mentions = 1
|
r.mentions = 1
|
||||||
ON MATCH SET
|
ON MATCH SET
|
||||||
r.mentions = coalesce(r.mentions, 0) + 1
|
r.mentions = coalesce(r.mentions, 0) + 1
|
||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||||
@@ -412,53 +467,68 @@ class MemoryGraph:
|
|||||||
"source_embedding": source_embedding,
|
"source_embedding": source_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif source_node_search_result and destination_node_search_result:
|
elif source_node_search_result and destination_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source)
|
MATCH (source)
|
||||||
WHERE elementId(source) = $source_id
|
WHERE elementId(source) = $source_id
|
||||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
WITH source
|
WITH source
|
||||||
MATCH (destination)
|
MATCH (destination)
|
||||||
WHERE elementId(destination) = $destination_id
|
WHERE elementId(destination) = $destination_id
|
||||||
SET destination.mentions = coalesce(destination.mentions) + 1
|
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
MERGE (source)-[r:{relationship}]->(destination)
|
MERGE (source)-[r:{relationship}]->(destination)
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
r.created_at = timestamp(),
|
r.created_at = timestamp(),
|
||||||
r.updated_at = timestamp(),
|
r.updated_at = timestamp(),
|
||||||
r.mentions = 1
|
r.mentions = 1
|
||||||
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
||||||
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
|
"""
|
||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
|
||||||
"""
|
|
||||||
params = {
|
params = {
|
||||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Build dynamic MERGE props for both source and destination
|
||||||
|
source_props = ["name: $source_name", "user_id: $user_id"]
|
||||||
|
dest_props = ["name: $dest_name", "user_id: $user_id"]
|
||||||
|
if agent_id:
|
||||||
|
source_props.append("agent_id: $agent_id")
|
||||||
|
dest_props.append("agent_id: $agent_id")
|
||||||
|
source_props_str = ", ".join(source_props)
|
||||||
|
dest_props_str = ", ".join(dest_props)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
MERGE (source {source_label} {{{source_props_str}}})
|
||||||
ON CREATE SET source.created = timestamp(),
|
ON CREATE SET source.created = timestamp(),
|
||||||
source.mentions = 1
|
source.mentions = 1
|
||||||
{source_extra_set}
|
{source_extra_set}
|
||||||
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
WITH source
|
WITH source
|
||||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||||
WITH source
|
WITH source
|
||||||
MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
|
MERGE (destination {destination_label} {{{dest_props_str}}})
|
||||||
ON CREATE SET destination.created = timestamp(),
|
ON CREATE SET destination.created = timestamp(),
|
||||||
destination.mentions = 1
|
destination.mentions = 1
|
||||||
{destination_extra_set}
|
{destination_extra_set}
|
||||||
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
|
CALL db.create.setNodeVectorProperty(destination, 'embedding', $dest_embedding)
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
MERGE (source)-[rel:{relationship}]->(destination)
|
MERGE (source)-[rel:{relationship}]->(destination)
|
||||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||||
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_name": source,
|
"source_name": source,
|
||||||
"dest_name": destination,
|
"dest_name": destination,
|
||||||
@@ -466,6 +536,8 @@ class MemoryGraph:
|
|||||||
"dest_embedding": dest_embedding,
|
"dest_embedding": dest_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
@@ -477,11 +549,16 @@ class MemoryGraph:
|
|||||||
item["destination"] = item["destination"].lower().replace(" ", "_")
|
item["destination"] = item["destination"].lower().replace(" ", "_")
|
||||||
return entity_list
|
return entity_list
|
||||||
|
|
||||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
def _search_source_node(self, source_embedding, filters, threshold=0.9):
|
||||||
|
agent_filter = ""
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
agent_filter = "AND source_candidate.agent_id = $agent_id"
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source_candidate {self.node_label})
|
MATCH (source_candidate {self.node_label})
|
||||||
WHERE source_candidate.embedding IS NOT NULL
|
WHERE source_candidate.embedding IS NOT NULL
|
||||||
AND source_candidate.user_id = $user_id
|
AND source_candidate.user_id = $user_id
|
||||||
|
{agent_filter}
|
||||||
|
|
||||||
WITH source_candidate,
|
WITH source_candidate,
|
||||||
round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
|
round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
|
||||||
@@ -496,18 +573,26 @@ class MemoryGraph:
|
|||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_embedding": source_embedding,
|
"source_embedding": source_embedding,
|
||||||
"user_id": user_id,
|
"user_id": filters["user_id"],
|
||||||
"threshold": threshold,
|
"threshold": threshold,
|
||||||
}
|
}
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
params["agent_id"] = filters["agent_id"]
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
|
||||||
|
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||||
|
agent_filter = ""
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
agent_filter = "AND destination_candidate.agent_id = $agent_id"
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (destination_candidate {self.node_label})
|
MATCH (destination_candidate {self.node_label})
|
||||||
WHERE destination_candidate.embedding IS NOT NULL
|
WHERE destination_candidate.embedding IS NOT NULL
|
||||||
AND destination_candidate.user_id = $user_id
|
AND destination_candidate.user_id = $user_id
|
||||||
|
{agent_filter}
|
||||||
|
|
||||||
WITH destination_candidate,
|
WITH destination_candidate,
|
||||||
round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
|
round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
|
||||||
@@ -520,11 +605,14 @@ class MemoryGraph:
|
|||||||
|
|
||||||
RETURN elementId(destination_candidate)
|
RETURN elementId(destination_candidate)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"destination_embedding": destination_embedding,
|
"destination_embedding": destination_embedding,
|
||||||
"user_id": user_id,
|
"user_id": filters["user_id"],
|
||||||
"threshold": threshold,
|
"threshold": threshold,
|
||||||
}
|
}
|
||||||
|
if filters.get("agent_id"):
|
||||||
|
params["agent_id"] = filters["agent_id"]
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -118,11 +118,19 @@ class MemoryGraph:
|
|||||||
return search_results
|
return search_results
|
||||||
|
|
||||||
def delete_all(self, filters):
|
def delete_all(self, filters):
|
||||||
cypher = """
|
"""Delete all nodes and relationships for a user or specific agent."""
|
||||||
MATCH (n {user_id: $user_id})
|
if filters.get("agent_id"):
|
||||||
DETACH DELETE n
|
cypher = """
|
||||||
"""
|
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
|
||||||
params = {"user_id": filters["user_id"]}
|
DETACH DELETE n
|
||||||
|
"""
|
||||||
|
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
|
||||||
|
else:
|
||||||
|
cypher = """
|
||||||
|
MATCH (n:Entity {user_id: $user_id})
|
||||||
|
DETACH DELETE n
|
||||||
|
"""
|
||||||
|
params = {"user_id": filters["user_id"]}
|
||||||
self.graph.query(cypher, params=params)
|
self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
def get_all(self, filters, limit=100):
|
def get_all(self, filters, limit=100):
|
||||||
@@ -131,20 +139,31 @@ class MemoryGraph:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
||||||
|
Supports 'user_id' (required) and 'agent_id' (optional).
|
||||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of dictionaries, each containing:
|
list: A list of dictionaries, each containing:
|
||||||
- 'contexts': The base data store response for each memory.
|
- 'source': The source node name.
|
||||||
- 'entities': A list of strings representing the nodes and relationships
|
- 'relationship': The relationship type.
|
||||||
|
- 'target': The target node name.
|
||||||
"""
|
"""
|
||||||
|
# Build query based on whether agent_id is provided
|
||||||
# return all nodes and relationships
|
if filters.get("agent_id"):
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
|
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id})
|
||||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
|
||||||
|
else:
|
||||||
|
query = """
|
||||||
|
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
|
||||||
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
params = {"user_id": filters["user_id"], "limit": limit}
|
||||||
|
|
||||||
|
results = self.graph.query(query, params=params)
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -241,33 +260,65 @@ class MemoryGraph:
|
|||||||
for node in node_list:
|
for node in node_list:
|
||||||
n_embedding = self.embedding_model.embed(node)
|
n_embedding = self.embedding_model.embed(node)
|
||||||
|
|
||||||
cypher_query = """
|
# Build query based on whether agent_id is provided
|
||||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
|
if filters.get("agent_id"):
|
||||||
WHERE n.embedding IS NOT NULL
|
cypher_query = """
|
||||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity)
|
||||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
WHERE n.embedding IS NOT NULL
|
||||||
YIELD node1, node2, similarity
|
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||||
WITH node1, node2, similarity, r
|
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||||
WHERE similarity >= $threshold
|
YIELD node1, node2, similarity
|
||||||
RETURN node1.user_id AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.user_id AS destination, id(node2) AS destination_id, similarity
|
WITH node1, node2, similarity, r
|
||||||
UNION
|
WHERE similarity >= $threshold
|
||||||
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
|
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
|
||||||
WHERE n.embedding IS NOT NULL
|
UNION
|
||||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})<-[r]-(m:Entity)
|
||||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
WHERE n.embedding IS NOT NULL
|
||||||
YIELD node1, node2, similarity
|
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||||
WITH node1, node2, similarity, r
|
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||||
WHERE similarity >= $threshold
|
YIELD node1, node2, similarity
|
||||||
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
WITH node1, node2, similarity, r
|
||||||
ORDER BY similarity DESC
|
WHERE similarity >= $threshold
|
||||||
LIMIT $limit;
|
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
||||||
"""
|
ORDER BY similarity DESC
|
||||||
params = {
|
LIMIT $limit;
|
||||||
"n_embedding": n_embedding,
|
"""
|
||||||
"threshold": self.threshold,
|
params = {
|
||||||
"user_id": filters["user_id"],
|
"n_embedding": n_embedding,
|
||||||
"limit": limit,
|
"threshold": self.threshold,
|
||||||
}
|
"user_id": filters["user_id"],
|
||||||
|
"agent_id": filters["agent_id"],
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cypher_query = """
|
||||||
|
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
|
||||||
|
WHERE n.embedding IS NOT NULL
|
||||||
|
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||||
|
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||||
|
YIELD node1, node2, similarity
|
||||||
|
WITH node1, node2, similarity, r
|
||||||
|
WHERE similarity >= $threshold
|
||||||
|
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
|
||||||
|
UNION
|
||||||
|
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
|
||||||
|
WHERE n.embedding IS NOT NULL
|
||||||
|
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||||
|
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||||
|
YIELD node1, node2, similarity
|
||||||
|
WITH node1, node2, similarity, r
|
||||||
|
WHERE similarity >= $threshold
|
||||||
|
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
||||||
|
ORDER BY similarity DESC
|
||||||
|
LIMIT $limit;
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"n_embedding": n_embedding,
|
||||||
|
"threshold": self.threshold,
|
||||||
|
"user_id": filters["user_id"],
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
|
||||||
ans = self.graph.query(cypher_query, params=params)
|
ans = self.graph.query(cypher_query, params=params)
|
||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
@@ -300,38 +351,54 @@ class MemoryGraph:
|
|||||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||||
return to_be_deleted
|
return to_be_deleted
|
||||||
|
|
||||||
def _delete_entities(self, to_be_deleted, user_id):
|
def _delete_entities(self, to_be_deleted, filters):
|
||||||
"""Delete the entities from the graph."""
|
"""Delete the entities from the graph."""
|
||||||
|
user_id = filters["user_id"]
|
||||||
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_deleted:
|
for item in to_be_deleted:
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
relationship = item["relationship"]
|
relationship = item["relationship"]
|
||||||
|
|
||||||
|
# Build the agent filter for the query
|
||||||
|
agent_filter = ""
|
||||||
|
params = {
|
||||||
|
"source_name": source,
|
||||||
|
"dest_name": destination,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if agent_id:
|
||||||
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
# Delete the specific relationship between nodes
|
# Delete the specific relationship between nodes
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
|
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
|
||||||
-[r:{relationship}]->
|
-[r:{relationship}]->
|
||||||
(m {{name: $dest_name, user_id: $user_id}})
|
(m:Entity {{name: $dest_name, user_id: $user_id}})
|
||||||
|
WHERE 1=1 {agent_filter}
|
||||||
DELETE r
|
DELETE r
|
||||||
RETURN
|
RETURN
|
||||||
n.name AS source,
|
n.name AS source,
|
||||||
m.name AS target,
|
m.name AS target,
|
||||||
type(r) AS relationship
|
type(r) AS relationship
|
||||||
"""
|
"""
|
||||||
params = {
|
|
||||||
"source_name": source,
|
|
||||||
"dest_name": destination,
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# added Entity label to all nodes for vector search to work
|
# added Entity label to all nodes for vector search to work
|
||||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||||
|
user_id = filters["user_id"]
|
||||||
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_added:
|
for item in to_be_added:
|
||||||
# entities
|
# entities
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
@@ -346,18 +413,21 @@ class MemoryGraph:
|
|||||||
source_embedding = self.embedding_model.embed(source)
|
source_embedding = self.embedding_model.embed(source)
|
||||||
dest_embedding = self.embedding_model.embed(destination)
|
dest_embedding = self.embedding_model.embed(destination)
|
||||||
|
|
||||||
# search for the nodes with the closest embeddings; this is basically
|
# search for the nodes with the closest embeddings
|
||||||
# comparison of one embedding to all embeddings in a graph -> vector
|
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
|
||||||
# search with cosine similarity metric
|
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
|
||||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
|
||||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
|
||||||
|
|
||||||
|
# Prepare agent_id for node creation
|
||||||
|
agent_id_clause = ""
|
||||||
|
if agent_id:
|
||||||
|
agent_id_clause = ", agent_id: $agent_id"
|
||||||
|
|
||||||
# TODO: Create a cypher query and common params for all the cases
|
# TODO: Create a cypher query and common params for all the cases
|
||||||
if not destination_node_search_result and source_node_search_result:
|
if not destination_node_search_result and source_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source:Entity)
|
MATCH (source:Entity)
|
||||||
WHERE id(source) = $source_id
|
WHERE id(source) = $source_id
|
||||||
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id}})
|
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
destination.created = timestamp(),
|
destination.created = timestamp(),
|
||||||
destination.embedding = $destination_embedding,
|
destination.embedding = $destination_embedding,
|
||||||
@@ -374,11 +444,14 @@ class MemoryGraph:
|
|||||||
"destination_embedding": dest_embedding,
|
"destination_embedding": dest_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif destination_node_search_result and not source_node_search_result:
|
elif destination_node_search_result and not source_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (destination:Entity)
|
MATCH (destination:Entity)
|
||||||
WHERE id(destination) = $destination_id
|
WHERE id(destination) = $destination_id
|
||||||
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
|
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
source.created = timestamp(),
|
source.created = timestamp(),
|
||||||
source.embedding = $source_embedding,
|
source.embedding = $source_embedding,
|
||||||
@@ -395,6 +468,9 @@ class MemoryGraph:
|
|||||||
"source_embedding": source_embedding,
|
"source_embedding": source_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif source_node_search_result and destination_node_search_result:
|
elif source_node_search_result and destination_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source:Entity)
|
MATCH (source:Entity)
|
||||||
@@ -412,12 +488,15 @@ class MemoryGraph:
|
|||||||
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
|
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
|
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||||
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
|
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
|
||||||
ON MATCH SET n.embedding = $source_embedding
|
ON MATCH SET n.embedding = $source_embedding
|
||||||
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id}})
|
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
|
||||||
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
|
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
|
||||||
ON MATCH SET m.embedding = $dest_embedding
|
ON MATCH SET m.embedding = $dest_embedding
|
||||||
MERGE (n)-[rel:{relationship}]->(m)
|
MERGE (n)-[rel:{relationship}]->(m)
|
||||||
@@ -431,6 +510,9 @@ class MemoryGraph:
|
|||||||
"dest_embedding": dest_embedding,
|
"dest_embedding": dest_embedding,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
if agent_id:
|
||||||
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
@@ -442,37 +524,80 @@ class MemoryGraph:
|
|||||||
item["destination"] = item["destination"].lower().replace(" ", "_")
|
item["destination"] = item["destination"].lower().replace(" ", "_")
|
||||||
return entity_list
|
return entity_list
|
||||||
|
|
||||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
def _search_source_node(self, source_embedding, filters, threshold=0.9):
|
||||||
cypher = """
|
"""Search for source nodes with similar embeddings."""
|
||||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
user_id = filters["user_id"]
|
||||||
YIELD distance, node, similarity
|
agent_id = filters.get("agent_id", None)
|
||||||
WITH node AS source_candidate, similarity
|
|
||||||
WHERE source_candidate.user_id = $user_id AND similarity >= $threshold
|
if agent_id:
|
||||||
RETURN id(source_candidate);
|
cypher = """
|
||||||
"""
|
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||||
|
YIELD distance, node, similarity
|
||||||
params = {
|
WITH node AS source_candidate, similarity
|
||||||
"source_embedding": source_embedding,
|
WHERE source_candidate.user_id = $user_id
|
||||||
"user_id": user_id,
|
AND source_candidate.agent_id = $agent_id
|
||||||
"threshold": threshold,
|
AND similarity >= $threshold
|
||||||
}
|
RETURN id(source_candidate);
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"source_embedding": source_embedding,
|
||||||
|
"user_id": user_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"threshold": threshold,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cypher = """
|
||||||
|
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||||
|
YIELD distance, node, similarity
|
||||||
|
WITH node AS source_candidate, similarity
|
||||||
|
WHERE source_candidate.user_id = $user_id
|
||||||
|
AND similarity >= $threshold
|
||||||
|
RETURN id(source_candidate);
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"source_embedding": source_embedding,
|
||||||
|
"user_id": user_id,
|
||||||
|
"threshold": threshold,
|
||||||
|
}
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||||
cypher = """
|
"""Search for destination nodes with similar embeddings."""
|
||||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
user_id = filters["user_id"]
|
||||||
YIELD distance, node, similarity
|
agent_id = filters.get("agent_id", None)
|
||||||
WITH node AS destination_candidate, similarity
|
|
||||||
WHERE node.user_id = $user_id AND similarity >= $threshold
|
if agent_id:
|
||||||
RETURN id(destination_candidate);
|
cypher = """
|
||||||
"""
|
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||||
params = {
|
YIELD distance, node, similarity
|
||||||
"destination_embedding": destination_embedding,
|
WITH node AS destination_candidate, similarity
|
||||||
"user_id": user_id,
|
WHERE node.user_id = $user_id
|
||||||
"threshold": threshold,
|
AND node.agent_id = $agent_id
|
||||||
}
|
AND similarity >= $threshold
|
||||||
|
RETURN id(destination_candidate);
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"destination_embedding": destination_embedding,
|
||||||
|
"user_id": user_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"threshold": threshold,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cypher = """
|
||||||
|
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||||
|
YIELD distance, node, similarity
|
||||||
|
WITH node AS destination_candidate, similarity
|
||||||
|
WHERE node.user_id = $user_id
|
||||||
|
AND similarity >= $threshold
|
||||||
|
RETURN id(destination_candidate);
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"destination_embedding": destination_embedding,
|
||||||
|
"user_id": user_id,
|
||||||
|
"threshold": threshold,
|
||||||
|
}
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ llms = [
|
|||||||
"ollama>=0.1.0",
|
"ollama>=0.1.0",
|
||||||
"vertexai>=0.1.0",
|
"vertexai>=0.1.0",
|
||||||
"google-generativeai>=0.3.0",
|
"google-generativeai>=0.3.0",
|
||||||
|
"google-genai>=1.0.0",
|
||||||
|
|
||||||
]
|
]
|
||||||
extras = [
|
extras = [
|
||||||
"boto3>=1.34.0",
|
"boto3>=1.34.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user