Code formatting and doc update (#2130)

This commit is contained in:
Dev Khant
2025-01-09 20:48:18 +05:30
committed by GitHub
parent 21854c6a24
commit a8f3ec25b7
10 changed files with 83 additions and 126 deletions

View File

@@ -27,9 +27,9 @@ from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
client = MemoryClient( client = MemoryClient(
"---", api_key=your_api_key,
org_id="---", org_id=your_org_id,
project_id="---" project_id=your_project_id
) )
``` ```

View File

@@ -362,11 +362,7 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
response = self.client.request( response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
"DELETE",
"/v1/batch/",
json={"memories": memories}
)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.batch_delete", self) capture_client_event("client.batch_delete", self)
@@ -383,10 +379,7 @@ class MemoryClient:
Returns: Returns:
Dict containing export request ID and status message Dict containing export request ID and status message
""" """
response = self.client.post( response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
"/v1/exports/",
json={"schema": schema, **self._prepare_params(kwargs)}
)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())}) capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys())})
return response.json() return response.json()
@@ -401,10 +394,7 @@ class MemoryClient:
Returns: Returns:
Dict containing the exported data Dict containing the exported data
""" """
response = self.client.get( response = self.client.get("/v1/exports/", params=self._prepare_params(kwargs))
"/v1/exports/",
params=self._prepare_params(kwargs)
)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())}) capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys())})
return response.json() return response.json()
@@ -498,14 +488,7 @@ class AsyncMemoryClient:
org_id: Optional[str] = None, org_id: Optional[str] = None,
project_id: Optional[str] = None, project_id: Optional[str] = None,
): ):
self.sync_client = MemoryClient( self.sync_client = MemoryClient(api_key, host, organization, project, org_id, project_id)
api_key,
host,
organization,
project,
org_id,
project_id
)
self.async_client = httpx.AsyncClient( self.async_client = httpx.AsyncClient(
base_url=self.sync_client.host, base_url=self.sync_client.host,
headers=self.sync_client.client.headers, headers=self.sync_client.client.headers,

View File

@@ -95,17 +95,14 @@ RELATIONS_TOOL = {
"items": { "items": {
"type": "object", "type": "object",
"properties": { "properties": {
"source": { "source": {"type": "string", "description": "The source entity of the relationship."},
"type": "string",
"description": "The source entity of the relationship."
},
"relationship": { "relationship": {
"type": "string", "type": "string",
"description": "The relationship between the source and destination entities." "description": "The relationship between the source and destination entities.",
}, },
"destination": { "destination": {
"type": "string", "type": "string",
"description": "The destination entity of the relationship." "description": "The destination entity of the relationship.",
}, },
}, },
"required": [ "required": [
@@ -137,25 +134,19 @@ EXTRACT_ENTITIES_TOOL = {
"items": { "items": {
"type": "object", "type": "object",
"properties": { "properties": {
"entity": { "entity": {"type": "string", "description": "The name or identifier of the entity."},
"type": "string", "entity_type": {"type": "string", "description": "The type or category of the entity."},
"description": "The name or identifier of the entity."
},
"entity_type": {
"type": "string",
"description": "The type or category of the entity."
}
}, },
"required": ["entity", "entity_type"], "required": ["entity", "entity_type"],
"additionalProperties": False "additionalProperties": False,
}, },
"description": "An array of entities with their types." "description": "An array of entities with their types.",
} }
}, },
"required": ["entities"], "required": ["entities"],
"additionalProperties": False "additionalProperties": False,
} },
} },
} }
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = { UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
@@ -260,15 +251,15 @@ RELATIONS_STRUCT_TOOL = {
"properties": { "properties": {
"source_entity": { "source_entity": {
"type": "string", "type": "string",
"description": "The source entity of the relationship." "description": "The source entity of the relationship.",
}, },
"relatationship": { "relatationship": {
"type": "string", "type": "string",
"description": "The relationship between the source and destination entities." "description": "The relationship between the source and destination entities.",
}, },
"destination_entity": { "destination_entity": {
"type": "string", "type": "string",
"description": "The destination entity of the relationship." "description": "The destination entity of the relationship.",
}, },
}, },
"required": [ "required": [
@@ -301,25 +292,19 @@ EXTRACT_ENTITIES_STRUCT_TOOL = {
"items": { "items": {
"type": "object", "type": "object",
"properties": { "properties": {
"entity": { "entity": {"type": "string", "description": "The name or identifier of the entity."},
"type": "string", "entity_type": {"type": "string", "description": "The type or category of the entity."},
"description": "The name or identifier of the entity."
},
"entity_type": {
"type": "string",
"description": "The type or category of the entity."
}
}, },
"required": ["entity", "entity_type"], "required": ["entity", "entity_type"],
"additionalProperties": False "additionalProperties": False,
}, },
"description": "An array of entities with their types." "description": "An array of entities with their types.",
} }
}, },
"required": ["entities"], "required": ["entities"],
"additionalProperties": False "additionalProperties": False,
} },
} },
} }
DELETE_MEMORY_STRUCT_TOOL_GRAPH = { DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
@@ -342,7 +327,7 @@ DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
"destination": { "destination": {
"type": "string", "type": "string",
"description": "The identifier of the destination node in the relationship.", "description": "The identifier of the destination node in the relationship.",
} },
}, },
"required": [ "required": [
"source", "source",
@@ -373,7 +358,7 @@ DELETE_MEMORY_TOOL_GRAPH = {
"destination": { "destination": {
"type": "string", "type": "string",
"description": "The identifier of the destination node in the relationship.", "description": "The identifier of the destination node in the relationship.",
} },
}, },
"required": [ "required": [
"source", "source",

View File

@@ -90,5 +90,8 @@ source -- relationship -- destination
Provide a list of deletion instructions, each specifying the relationship to be deleted. Provide a list of deletion instructions, each specifying the relationship to be deleted.
""" """
def get_delete_messages(existing_memories_string, data, user_id): def get_delete_messages(existing_memories_string, data, user_id):
return DELETE_RELATIONS_SYSTEM_PROMPT.replace("USER_ID", user_id), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}" return DELETE_RELATIONS_SYSTEM_PROMPT.replace(
"USER_ID", user_id
), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"

View File

@@ -1,4 +1,3 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional

View File

@@ -1,5 +1,4 @@
import os import os
import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
@@ -39,11 +38,7 @@ class GeminiLLM(LLMBase):
""" """
if tools: if tools:
processed_response = { processed_response = {
"content": ( "content": (content if (content := response.candidates[0].content.parts[0].text) else None),
content
if (content := response.candidates[0].content.parts[0].text)
else None
),
"tool_calls": [], "tool_calls": [],
} }
@@ -51,13 +46,9 @@ class GeminiLLM(LLMBase):
if fn := part.function_call: if fn := part.function_call:
if isinstance(fn, protos.FunctionCall): if isinstance(fn, protos.FunctionCall):
fn_call = type(fn).to_dict(fn) fn_call = type(fn).to_dict(fn)
processed_response["tool_calls"].append( processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
{"name": fn_call["name"], "arguments": fn_call["args"]}
)
continue continue
processed_response["tool_calls"].append( processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
{"name": fn.name, "arguments": fn.args}
)
return processed_response return processed_response
else: else:
@@ -77,9 +68,7 @@ class GeminiLLM(LLMBase):
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
content = ( content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
)
else: else:
content = message["content"] content = message["content"]
@@ -121,9 +110,7 @@ class GeminiLLM(LLMBase):
if tools: if tools:
for tool in tools: for tool in tools:
func = tool["function"].copy() func = tool["function"].copy()
new_tools.append( new_tools.append({"function_declarations": [remove_additional_properties(func)]})
{"function_declarations": [remove_additional_properties(func)]}
)
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later. # TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
# return content_types.to_function_library(new_tools) # return content_types.to_function_library(new_tools)
@@ -168,9 +155,7 @@ class GeminiLLM(LLMBase):
"function_calling_config": { "function_calling_config": {
"mode": tool_choice, "mode": tool_choice,
"allowed_function_names": ( "allowed_function_names": (
[tool["function"]["name"] for tool in tools] [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
if tool_choice == "any"
else None
), ),
} }
} }

View File

@@ -86,7 +86,9 @@ class MemoryGraph:
if not search_output: if not search_output:
return [] return []
search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output] search_outputs_sequence = [
[item["source"], item["relatationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence) bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ") tokenized_query = query.split(" ")
@@ -343,7 +345,7 @@ class MemoryGraph:
""" """
params = { params = {
"source_id": source_node_search_result[0]['elementId(source_candidate)'], "source_id": source_node_search_result[0]["elementId(source_candidate)"],
"destination_name": destination, "destination_name": destination,
"relationship": relationship, "relationship": relationship,
"destination_type": destination_type, "destination_type": destination_type,
@@ -368,7 +370,7 @@ class MemoryGraph:
""" """
params = { params = {
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'], "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
"source_name": source, "source_name": source,
"relationship": relationship, "relationship": relationship,
"source_type": source_type, "source_type": source_type,
@@ -391,8 +393,8 @@ class MemoryGraph:
RETURN source.name AS source, type(r) AS relationship, destination.name AS target RETURN source.name AS source, type(r) AS relationship, destination.name AS target
""" """
params = { params = {
"source_id": source_node_search_result[0]['elementId(source_candidate)'], "source_id": source_node_search_result[0]["elementId(source_candidate)"],
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'], "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
"user_id": user_id, "user_id": user_id,
"relationship": relationship, "relationship": relationship,
} }
@@ -432,7 +434,7 @@ class MemoryGraph:
return entity_list return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9): def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = f""" cypher = """
MATCH (source_candidate) MATCH (source_candidate)
WHERE source_candidate.embedding IS NOT NULL WHERE source_candidate.embedding IS NOT NULL
AND source_candidate.user_id = $user_id AND source_candidate.user_id = $user_id
@@ -465,7 +467,7 @@ class MemoryGraph:
return result return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = f""" cypher = """
MATCH (destination_candidate) MATCH (destination_candidate)
WHERE destination_candidate.embedding IS NOT NULL WHERE destination_candidate.embedding IS NOT NULL
AND destination_candidate.user_id = $user_id AND destination_candidate.user_id = $user_id

View File

@@ -1,5 +1,4 @@
import re import re
import json
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
@@ -19,6 +18,7 @@ def parse_messages(messages):
response += f"assistant: {msg['content']}\n" response += f"assistant: {msg['content']}\n"
return response return response
def format_entities(entities): def format_entities(entities):
if not entities: if not entities:
return "" return ""