Code formatting and doc update (#2130)
This commit is contained in:
@@ -27,9 +27,9 @@ from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
client = MemoryClient(
|
||||
"---",
|
||||
org_id="---",
|
||||
project_id="---"
|
||||
api_key=your_api_key,
|
||||
org_id=your_org_id,
|
||||
project_id=your_project_id
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -362,11 +362,7 @@ class MemoryClient:
|
||||
Raises:
|
||||
APIError: If the API request fails.
|
||||
"""
|
||||
response = self.client.request(
|
||||
"DELETE",
|
||||
"/v1/batch/",
|
||||
json={"memories": memories}
|
||||
)
|
||||
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event("client.batch_delete", self)
|
||||
@@ -383,10 +379,7 @@ class MemoryClient:
|
||||
Returns:
|
||||
Dict containing export request ID and status message
|
||||
"""
|
||||
response = self.client.post(
|
||||
"/v1/exports/",
|
||||
json={"schema": schema, **self._prepare_params(kwargs)}
|
||||
)
|
||||
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
|
||||
response.raise_for_status()
|
||||
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())})
|
||||
return response.json()
|
||||
@@ -401,10 +394,7 @@ class MemoryClient:
|
||||
Returns:
|
||||
Dict containing the exported data
|
||||
"""
|
||||
response = self.client.get(
|
||||
"/v1/exports/",
|
||||
params=self._prepare_params(kwargs)
|
||||
)
|
||||
response = self.client.get("/v1/exports/", params=self._prepare_params(kwargs))
|
||||
response.raise_for_status()
|
||||
capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())})
|
||||
return response.json()
|
||||
@@ -498,14 +488,7 @@ class AsyncMemoryClient:
|
||||
org_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
):
|
||||
self.sync_client = MemoryClient(
|
||||
api_key,
|
||||
host,
|
||||
organization,
|
||||
project,
|
||||
org_id,
|
||||
project_id
|
||||
)
|
||||
self.sync_client = MemoryClient(api_key, host, organization, project, org_id, project_id)
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.sync_client.host,
|
||||
headers=self.sync_client.client.headers,
|
||||
|
||||
@@ -95,17 +95,14 @@ RELATIONS_TOOL = {
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The source entity of the relationship."
|
||||
},
|
||||
"source": {"type": "string", "description": "The source entity of the relationship."},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The relationship between the source and destination entities."
|
||||
"description": "The relationship between the source and destination entities.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The destination entity of the relationship."
|
||||
"description": "The destination entity of the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
@@ -137,25 +134,19 @@ EXTRACT_ENTITIES_TOOL = {
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {
|
||||
"type": "string",
|
||||
"description": "The name or identifier of the entity."
|
||||
},
|
||||
"entity_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the entity."
|
||||
}
|
||||
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||
},
|
||||
"required": ["entity", "entity_type"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"description": "An array of entities with their types."
|
||||
"description": "An array of entities with their types.",
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
@@ -260,15 +251,15 @@ RELATIONS_STRUCT_TOOL = {
|
||||
"properties": {
|
||||
"source_entity": {
|
||||
"type": "string",
|
||||
"description": "The source entity of the relationship."
|
||||
"description": "The source entity of the relationship.",
|
||||
},
|
||||
"relatationship": {
|
||||
"type": "string",
|
||||
"description": "The relationship between the source and destination entities."
|
||||
"description": "The relationship between the source and destination entities.",
|
||||
},
|
||||
"destination_entity": {
|
||||
"type": "string",
|
||||
"description": "The destination entity of the relationship."
|
||||
"description": "The destination entity of the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
@@ -301,25 +292,19 @@ EXTRACT_ENTITIES_STRUCT_TOOL = {
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {
|
||||
"type": "string",
|
||||
"description": "The name or identifier of the entity."
|
||||
},
|
||||
"entity_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the entity."
|
||||
}
|
||||
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||
},
|
||||
"required": ["entity", "entity_type"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"description": "An array of entities with their types."
|
||||
"description": "An array of entities with their types.",
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
@@ -342,7 +327,7 @@ DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship.",
|
||||
}
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
@@ -373,7 +358,7 @@ DELETE_MEMORY_TOOL_GRAPH = {
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship.",
|
||||
}
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
|
||||
@@ -90,5 +90,8 @@ source -- relationship -- destination
|
||||
Provide a list of deletion instructions, each specifying the relationship to be deleted.
|
||||
"""
|
||||
|
||||
|
||||
def get_delete_messages(existing_memories_string, data, user_id):
|
||||
return DELETE_RELATIONS_SYSTEM_PROMPT.replace("USER_ID", user_id), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"
|
||||
return DELETE_RELATIONS_SYSTEM_PROMPT.replace(
|
||||
"USER_ID", user_id
|
||||
), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
@@ -39,11 +38,7 @@ class GeminiLLM(LLMBase):
|
||||
"""
|
||||
if tools:
|
||||
processed_response = {
|
||||
"content": (
|
||||
content
|
||||
if (content := response.candidates[0].content.parts[0].text)
|
||||
else None
|
||||
),
|
||||
"content": (content if (content := response.candidates[0].content.parts[0].text) else None),
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
@@ -51,13 +46,9 @@ class GeminiLLM(LLMBase):
|
||||
if fn := part.function_call:
|
||||
if isinstance(fn, protos.FunctionCall):
|
||||
fn_call = type(fn).to_dict(fn)
|
||||
processed_response["tool_calls"].append(
|
||||
{"name": fn_call["name"], "arguments": fn_call["args"]}
|
||||
)
|
||||
processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
|
||||
continue
|
||||
processed_response["tool_calls"].append(
|
||||
{"name": fn.name, "arguments": fn.args}
|
||||
)
|
||||
processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
|
||||
|
||||
return processed_response
|
||||
else:
|
||||
@@ -77,9 +68,7 @@ class GeminiLLM(LLMBase):
|
||||
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
content = (
|
||||
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||
)
|
||||
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||
|
||||
else:
|
||||
content = message["content"]
|
||||
@@ -121,9 +110,7 @@ class GeminiLLM(LLMBase):
|
||||
if tools:
|
||||
for tool in tools:
|
||||
func = tool["function"].copy()
|
||||
new_tools.append(
|
||||
{"function_declarations": [remove_additional_properties(func)]}
|
||||
)
|
||||
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
|
||||
|
||||
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
|
||||
# return content_types.to_function_library(new_tools)
|
||||
@@ -168,9 +155,7 @@ class GeminiLLM(LLMBase):
|
||||
"function_calling_config": {
|
||||
"mode": tool_choice,
|
||||
"allowed_function_names": (
|
||||
[tool["function"]["name"] for tool in tools]
|
||||
if tool_choice == "any"
|
||||
else None
|
||||
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +86,9 @@ class MemoryGraph:
|
||||
if not search_output:
|
||||
return []
|
||||
|
||||
search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output]
|
||||
search_outputs_sequence = [
|
||||
[item["source"], item["relatationship"], item["destination"]] for item in search_output
|
||||
]
|
||||
bm25 = BM25Okapi(search_outputs_sequence)
|
||||
|
||||
tokenized_query = query.split(" ")
|
||||
@@ -343,7 +345,7 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_name": destination,
|
||||
"relationship": relationship,
|
||||
"destination_type": destination_type,
|
||||
@@ -368,7 +370,7 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
"source_name": source,
|
||||
"relationship": relationship,
|
||||
"source_type": source_type,
|
||||
@@ -391,8 +393,8 @@ class MemoryGraph:
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
"relationship": relationship,
|
||||
}
|
||||
@@ -432,7 +434,7 @@ class MemoryGraph:
|
||||
return entity_list
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
cypher = f"""
|
||||
cypher = """
|
||||
MATCH (source_candidate)
|
||||
WHERE source_candidate.embedding IS NOT NULL
|
||||
AND source_candidate.user_id = $user_id
|
||||
@@ -465,7 +467,7 @@ class MemoryGraph:
|
||||
return result
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
cypher = f"""
|
||||
cypher = """
|
||||
MATCH (destination_candidate)
|
||||
WHERE destination_candidate.embedding IS NOT NULL
|
||||
AND destination_candidate.user_id = $user_id
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
import json
|
||||
|
||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||
|
||||
@@ -19,6 +18,7 @@ def parse_messages(messages):
|
||||
response += f"assistant: {msg['content']}\n"
|
||||
return response
|
||||
|
||||
|
||||
def format_entities(entities):
|
||||
if not entities:
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user