[Misc] Lint code and fix code smells (#1871)
This commit is contained in:

@@ -10,7 +10,11 @@ from mem0.memory.setup import setup_config
from mem0.memory.telemetry import capture_client_event

logger = logging.getLogger(__name__)

warnings.filterwarnings('always', category=DeprecationWarning, message="The 'session_id' parameter is deprecated. Use 'run_id' instead.")
warnings.filterwarnings(
"always",
category=DeprecationWarning,
message="The 'session_id' parameter is deprecated. Use 'run_id' instead.",
)

# Setup user config
setup_config()
@@ -82,14 +86,10 @@ class MemoryClient:
response = self.client.get("/v1/memories/", params={"user_id": "test"})
response.raise_for_status()
except httpx.HTTPStatusError:
raise ValueError(
"Invalid API Key. Please get a valid API Key from https://app.mem0.ai"
)
raise ValueError("Invalid API Key. Please get a valid API Key from https://app.mem0.ai")

@api_error_handler
def add(
self, messages: Union[str, List[Dict[str, str]]], **kwargs
) -> Dict[str, Any]:
def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
"""Add a new memory.

Args:
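A note on the API-key check collapsed above: it issues a test GET with the configured key as part of setting up the client, so a bad key surfaces right away as a ValueError instead of failing on the first real request. A minimal caller-side sketch (the key value is obviously a placeholder):

    from mem0 import MemoryClient

    try:
        client = MemoryClient(api_key="not-a-real-key")  # placeholder, invalid key
    except ValueError as exc:
        # "Invalid API Key. Please get a valid API Key from https://app.mem0.ai"
        print(exc)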
@@ -253,9 +253,7 @@ class MemoryClient:
"""Delete all users, agents, or sessions."""
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(
f"/v1/entities/{entity['type']}/{entity['id']}/"
)
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/")
response.raise_for_status()

capture_client_event("client.delete_users", self)
@@ -312,7 +310,7 @@ class MemoryClient:
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")

@@ -335,7 +333,7 @@ class MemoryClient:
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")
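The two hunks above are the same deprecation shim repeated in two client methods: a legacy session_id keyword emits a DeprecationWarning and is then remapped to run_id before the request is built. Roughly, from the caller's side (reusing the client from the earlier sketch; the memory text and ids are placeholders):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deprecated spelling; internally the client does kwargs["run_id"] = kwargs.pop("session_id")
        client.add("I prefer tea over coffee", user_id="alice", session_id="trip-planning")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)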
@@ -17,18 +17,10 @@ class MemoryItem(BaseModel):
) # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
metadata: Optional[Dict[str, Any]] = Field(
None, description="Additional metadata for the text data"
)
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
created_at: Optional[str] = Field(
None, description="The timestamp when the memory was created"
)
updated_at: Optional[str] = Field(
None, description="The timestamp when the memory was updated"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
score: Optional[float] = Field(None, description="The score associated with the text data")
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")


class MemoryConfig(BaseModel):
@@ -60,7 +52,7 @@ class MemoryConfig(BaseModel):
description="Custom prompt for the memory",
default=None,
)


class AzureConfig(BaseModel):
"""
@@ -73,7 +65,10 @@ class AzureConfig(BaseModel):
api_version (str): The version of the Azure API being used.
"""

api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None)
azure_deployment : str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version : str = Field(description="The version of the Azure API being used.", default=None)
api_key: str = Field(
description="The API key used for authenticating with the Azure service.",
default=None,
)
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)
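Because every AzureConfig field above defaults to None, callers either pass the values explicitly or rely on the provider-specific environment variables used later in this diff. A minimal sketch, assuming the four fields shown are the whole model and that it lives next to the other models in this hunk (import path and all values are placeholders):

    from mem0.configs.base import AzureConfig  # import path assumed

    azure_kwargs = {
        "api_key": "<azure-openai-key>",
        "azure_deployment": "my-embedding-deployment",
        "azure_endpoint": "https://my-resource.openai.azure.com",
        "api_version": "2024-02-01",
    }
    config = AzureConfig(**azure_kwargs)  # BaseEmbedderConfig wraps azure_kwargs the same way below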
@@ -60,6 +60,6 @@ class BaseEmbedderConfig(ABC):

# Huggingface specific
self.model_kwargs = model_kwargs or {}

# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
@@ -59,6 +59,7 @@ You should detect the language of the user input and record the facts in the sam
If you do not find anything relevant facts, user memories, and preferences in the below conversation, you can return an empty list corresponding to the "facts" key.
"""


def get_update_memory_messages(retrieved_old_memory_dict, response_content):
return f"""You are a smart memory manager which controls the memory of a system.
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
@@ -13,9 +13,7 @@ class ChromaDbConfig(BaseModel):
Client: ClassVar[type] = Client

collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[Client] = Field(
None, description="Existing ChromaDB client instance"
)
client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[int] = Field(None, description="Database connection remote port")
@@ -1,22 +1,24 @@
from enum import Enum
from typing import Dict, Any
from pydantic import BaseModel, model_validator, Field
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


class MetricType(str, Enum):
"""
Metric Constant for milvus/ zilliz server.
"""

def __str__(self) -> str:
return str(self.value)

L2 = "L2"
IP = "IP"
COSINE = "COSINE"
HAMMING = "HAMMING"
JACCARD = "JACCARD"

IP = "IP"
COSINE = "COSINE"
HAMMING = "HAMMING"
JACCARD = "JACCARD"


class MilvusDBConfig(BaseModel):
url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.")
@@ -38,4 +40,4 @@ class MilvusDBConfig(BaseModel):

model_config = {
"arbitrary_types_allowed": True,
}
}
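A quick usage note on the MetricType enum above: because it subclasses str and overrides __str__, a member can be passed anywhere the Milvus/Zilliz client expects a plain metric-name string. For example, assuming the enum from this hunk is in scope:

    metric = MetricType.COSINE
    assert isinstance(metric, str)
    assert str(metric) == "COSINE"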
@@ -4,12 +4,9 @@ from pydantic import BaseModel, Field, model_validator


class PGVectorConfig(BaseModel):

dbname: str = Field("postgres", description="Default name for the database")
collection_name: str = Field("mem0", description="Default name for the collection")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
user: Optional[str] = Field(None, description="Database user")
password: Optional[str] = Field(None, description="Database password")
host: Optional[str] = Field(None, description="Database host. Default is localhost")
@@ -9,17 +9,11 @@ class QdrantConfig(BaseModel):
QdrantClient: ClassVar[type] = QdrantClient

collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
client: Optional[QdrantClient] = Field(
None, description="Existing Qdrant client instance"
)
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field(
"/tmp/qdrant", description="Path for local Qdrant database"
)
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@@ -35,9 +29,7 @@ class QdrantConfig(BaseModel):
values.get("api_key"),
)
if not path and not (host and port) and not (url and api_key):
raise ValueError(
"Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
)
raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
return values

@model_validator(mode="before")
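The validator above accepts three connection shapes; any one of the following satisfies it (the import path and concrete values are assumptions, not part of this diff):

    from mem0.configs.vector_stores.qdrant import QdrantConfig  # import path assumed

    QdrantConfig(path="/tmp/qdrant")                                              # local, file-backed storage
    QdrantConfig(host="localhost", port=6333)                                     # self-hosted server
    QdrantConfig(url="https://example.cloud.qdrant.io", api_key="<qdrant-key>")   # managed cluster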
@@ -15,14 +15,14 @@ class AzureOpenAIEmbedding(EmbeddingBase):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")

self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
http_client=self.config.http_client,
)

def embed(self, text):
"""
@@ -35,8 +35,4 @@ class AzureOpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
@@ -8,9 +8,7 @@ class EmbedderConfig(BaseModel):
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
default="openai",
)
config: Optional[dict] = Field(
description="Configuration for the specific embedding model", default={}
)
config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

@field_validator("config")
def validate_config(cls, v, values):
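For reference, EmbedderConfig above is the slice of the top-level memory configuration that picks the embedding provider; the dict it validates looks roughly like this (the model name is a placeholder):

    embedder_section = {
        "provider": "openai",
        "config": {"model": "text-embedding-3-small"},
    }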
@@ -9,7 +9,7 @@ try:
from ollama import Client
except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
if user_input.lower() == "y":
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client
@@ -29,8 +29,4 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
@@ -6,6 +6,7 @@ from vertexai.language_models import TextEmbeddingModel
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class VertexAI(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
@@ -34,6 +35,6 @@ class VertexAI(EmbeddingBase):
Returns:
list: The embedding vector.
"""
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality= self.config.embedding_dims)

embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)

return embeddings[0].values
@@ -18,28 +18,16 @@ class Neo4jConfig(BaseModel):
values.get("password"),
)
if not url or not username or not password:
raise ValueError(
"Please provide 'url', 'username' and 'password'."
)
raise ValueError("Please provide 'url', 'username' and 'password'.")
return values


class GraphStoreConfig(BaseModel):
provider: str = Field(
description="Provider of the data store (e.g., 'neo4j')",
default="neo4j"
)
config: Neo4jConfig = Field(
description="Configuration for the specific data store",
default=None
)
llm: Optional[LlmConfig] = Field(
description="LLM configuration for querying the graph store",
default=None
)
provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j")
config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None)
llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
custom_prompt: Optional[str] = Field(
description="Custom prompt to fetch entities from the given text",
default=None
description="Custom prompt to fetch entities from the given text", default=None
)

@field_validator("config")
@@ -49,4 +37,3 @@ class GraphStoreConfig(BaseModel):
return Neo4jConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported graph store provider: {provider}")

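Putting Neo4jConfig and GraphStoreConfig together, the graph-store section of a memory config ends up looking roughly like this; the validator above rejects it unless url, username, and password are all present (connection details are placeholders):

    graph_store_section = {
        "provider": "neo4j",
        "config": {
            "url": "neo4j+s://example.databases.neo4j.io",
            "username": "neo4j",
            "password": "<password>",
        },
    }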
@@ -1,4 +1,3 @@
|
||||
|
||||
UPDATE_MEMORY_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -9,21 +8,21 @@ UPDATE_MEMORY_TOOL_GRAPH = {
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph."
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph."
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||
}
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ADD_MEMORY_TOOL_GRAPH = {
|
||||
@@ -36,29 +35,35 @@ ADD_MEMORY_TOOL_GRAPH = {
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created."
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created."
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
"source_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph."
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
"destination_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph."
|
||||
}
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship", "source_type", "destination_type"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"required": [
|
||||
"source",
|
||||
"destination",
|
||||
"relationship",
|
||||
"source_type",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -71,9 +76,9 @@ NOOP_TOOL = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -94,17 +99,23 @@ ADD_MESSAGE_TOOL = {
|
||||
"source_type": {"type": "string"},
|
||||
"relation": {"type": "string"},
|
||||
"destination_node": {"type": "string"},
|
||||
"destination_type": {"type": "string"}
|
||||
"destination_type": {"type": "string"},
|
||||
},
|
||||
"required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"required": [
|
||||
"source_node",
|
||||
"source_type",
|
||||
"relation",
|
||||
"destination_node",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -118,23 +129,19 @@ SEARCH_TOOL = {
|
||||
"properties": {
|
||||
"nodes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of nodes to search for."
|
||||
"items": {"type": "string"},
|
||||
"description": "List of nodes to search for.",
|
||||
},
|
||||
"relations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of relations to search for."
|
||||
}
|
||||
"items": {"type": "string"},
|
||||
"description": "List of relations to search for.",
|
||||
},
|
||||
},
|
||||
"required": ["nodes", "relations"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
@@ -148,21 +155,21 @@ UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph."
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph."
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||
}
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
@@ -176,29 +183,35 @@ ADD_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created."
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created."
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
"source_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph."
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
"destination_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph."
|
||||
}
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship", "source_type", "destination_type"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"required": [
|
||||
"source",
|
||||
"destination",
|
||||
"relationship",
|
||||
"source_type",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -212,9 +225,9 @@ NOOP_STRUCT_TOOL = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -236,17 +249,23 @@ ADD_MESSAGE_STRUCT_TOOL = {
|
||||
"source_type": {"type": "string"},
|
||||
"relation": {"type": "string"},
|
||||
"destination_node": {"type": "string"},
|
||||
"destination_type": {"type": "string"}
|
||||
"destination_type": {"type": "string"},
|
||||
},
|
||||
"required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
"required": [
|
||||
"source_node",
|
||||
"source_type",
|
||||
"relation",
|
||||
"destination_node",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -261,21 +280,17 @@ SEARCH_STRUCT_TOOL = {
|
||||
"properties": {
|
||||
"nodes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of nodes to search for."
|
||||
"items": {"type": "string"},
|
||||
"description": "List of nodes to search for.",
|
||||
},
|
||||
"relations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of relations to search for."
|
||||
}
|
||||
"items": {"type": "string"},
|
||||
"description": "List of relations to search for.",
|
||||
},
|
||||
},
|
||||
"required": ["nodes", "relations"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
UPDATE_GRAPH_PROMPT = """
|
||||
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
|
||||
|
||||
@@ -55,10 +54,10 @@ Strive for a coherent, easily understandable knowledge graph by maintaining cons
|
||||
Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction."""
|
||||
|
||||
|
||||
|
||||
def get_update_memory_prompt(existing_memories, memory, template):
|
||||
return template.format(existing_memories=existing_memories, memory=memory)
|
||||
|
||||
|
||||
def get_update_memory_messages(existing_memories, memory):
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
@@ -43,8 +43,8 @@ class AnthropicLLM(LLMBase):
|
||||
system_message = ""
|
||||
filtered_messages = []
|
||||
for message in messages:
|
||||
if message['role'] == 'system':
|
||||
system_message = message['content']
|
||||
if message["role"] == "system":
|
||||
system_message = message["content"]
|
||||
else:
|
||||
filtered_messages.append(message)
|
||||
|
||||
@@ -56,7 +56,7 @@ class AnthropicLLM(LLMBase):
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"top_p": self.config.top_p,
|
||||
}
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -125,9 +125,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
},
|
||||
}
|
||||
input_body["textGenerationConfig"] = {
|
||||
k: v
|
||||
for k, v in input_body["textGenerationConfig"].items()
|
||||
if v is not None
|
||||
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
|
||||
}
|
||||
|
||||
return input_body
|
||||
@@ -161,9 +159,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
}
|
||||
}
|
||||
|
||||
for prop, details in (
|
||||
function["parameters"].get("properties", {}).items()
|
||||
):
|
||||
for prop, details in function["parameters"].get("properties", {}).items():
|
||||
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
|
||||
"type": details.get("type", "string"),
|
||||
"description": details.get("description", ""),
|
||||
@@ -216,9 +212,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
# Use invoke_model method when no tools are provided
|
||||
prompt = self._format_messages(messages)
|
||||
provider = self.model.split(".")[0]
|
||||
input_body = self._prepare_input(
|
||||
provider, self.config.model, prompt, **self.model_kwargs
|
||||
)
|
||||
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
|
||||
body = json.dumps(input_body)
|
||||
|
||||
response = self.client.invoke_model(
|
||||
|
||||
@@ -15,20 +15,20 @@ class AzureOpenAILLM(LLMBase):
|
||||
# Model name should match the custom deployment name chosen for it.
|
||||
if not self.config.model:
|
||||
self.config.model = "gpt-4o"
|
||||
|
||||
|
||||
api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
|
||||
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
|
||||
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
|
||||
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
azure_deployment=azure_deployment,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
http_client=self.config.http_client
|
||||
)
|
||||
|
||||
http_client=self.config.http_client,
|
||||
)
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -87,7 +87,7 @@ class AzureOpenAILLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AzureOpenAIStructuredLLM(LLMBase):
|
||||
@@ -15,21 +15,21 @@ class AzureOpenAIStructuredLLM(LLMBase):
|
||||
# Model name should match the custom deployment name chosen for it.
|
||||
if not self.config.model:
|
||||
self.config.model = "gpt-4o-2024-08-06"
|
||||
|
||||
|
||||
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
|
||||
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
|
||||
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
|
||||
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
|
||||
# Can display a warning if API version is of model and api-version
|
||||
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
azure_deployment=azure_deployment,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
http_client=self.config.http_client
|
||||
)
|
||||
|
||||
http_client=self.config.http_client,
|
||||
)
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
|
||||
@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
|
||||
provider: str = Field(
|
||||
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
|
||||
)
|
||||
config: Optional[dict] = Field(
|
||||
description="Configuration for the specific LLM", default={}
|
||||
)
|
||||
provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
|
||||
config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
|
||||
|
||||
@field_validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
@@ -23,7 +19,7 @@ class LlmConfig(BaseModel):
|
||||
"litellm",
|
||||
"azure_openai",
|
||||
"openai_structured",
|
||||
"azure_openai_structured"
|
||||
"azure_openai_structured",
|
||||
):
|
||||
return v
|
||||
else:
|
||||
|
||||
@@ -67,9 +67,7 @@ class LiteLLM(LLMBase):
|
||||
str: The generated response.
|
||||
"""
|
||||
if not litellm.supports_function_calling(self.config.model):
|
||||
raise ValueError(
|
||||
f"Model '{self.config.model}' in litellm does not support function calling."
|
||||
)
|
||||
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
|
||||
|
||||
params = {
|
||||
"model": self.config.model,
|
||||
@@ -80,7 +78,7 @@ class LiteLLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ class OpenAILLM(LLMBase):
|
||||
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
@@ -20,7 +19,6 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -31,8 +29,8 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
Returns:
|
||||
str or dict: The processed response.
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
if tools:
|
||||
processed_response = {
|
||||
"content": response.choices[0].message.content,
|
||||
@@ -52,7 +50,6 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
else:
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
@@ -87,4 +84,4 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
response = self.client.beta.chat.completions.parse(**params)
|
||||
|
||||
return self._parse_response(response, tools)
|
||||
return self._parse_response(response, tools)
|
||||
|
||||
@@ -20,7 +20,7 @@ class TogetherLLM(LLMBase):
|
||||
|
||||
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
|
||||
self.client = Together(api_key=api_key)
|
||||
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -79,7 +79,7 @@ class TogetherLLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -7,11 +7,9 @@ ADD_MEMORY_TOOL = {
|
||||
"description": "Add a memory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {"type": "string", "description": "Data to add to memory"}
|
||||
},
|
||||
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
|
||||
"required": ["data"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -34,7 +32,7 @@ UPDATE_MEMORY_TOOL = {
|
||||
},
|
||||
},
|
||||
"required": ["memory_id", "data"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -53,7 +51,7 @@ DELETE_MEMORY_TOOL = {
|
||||
}
|
||||
},
|
||||
"required": ["memory_id"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3,30 +3,28 @@ import logging
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
ADD_MEMORY_TOOL_GRAPH,
|
||||
ADD_MESSAGE_TOOL,
|
||||
NOOP_TOOL,
|
||||
SEARCH_TOOL,
|
||||
UPDATE_MEMORY_TOOL_GRAPH,
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
NOOP_STRUCT_TOOL,
|
||||
ADD_MESSAGE_STRUCT_TOOL,
|
||||
SEARCH_STRUCT_TOOL
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
||||
from mem0.graphs.tools import (ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_TOOL_GRAPH, ADD_MESSAGE_STRUCT_TOOL,
|
||||
ADD_MESSAGE_TOOL, NOOP_STRUCT_TOOL, NOOP_TOOL,
|
||||
SEARCH_STRUCT_TOOL, SEARCH_TOOL,
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
UPDATE_MEMORY_TOOL_GRAPH)
|
||||
from mem0.graphs.utils import (EXTRACT_ENTITIES_PROMPT,
|
||||
get_update_memory_messages)
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
|
||||
self.embedding_model = EmbedderFactory.create(
|
||||
self.config.embedder.provider, self.config.embedder.config
|
||||
self.graph = Neo4jGraph(
|
||||
self.config.graph_store.config.url,
|
||||
self.config.graph_store.config.username,
|
||||
self.config.graph_store.config.password,
|
||||
)
|
||||
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
|
||||
|
||||
self.llm_provider = "openai_structured"
|
||||
if self.config.llm.provider:
|
||||
@@ -51,15 +49,23 @@ class MemoryGraph:
|
||||
search_output = self._search(data, filters)
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
]
|
||||
|
||||
_tools = [ADD_MESSAGE_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
@@ -67,11 +73,11 @@ class MemoryGraph:
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools = _tools,
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
if extracted_entities['tool_calls']:
|
||||
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
|
||||
if extracted_entities["tool_calls"]:
|
||||
extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
@@ -79,9 +85,13 @@ class MemoryGraph:
|
||||
|
||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||
|
||||
_tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured","openai_structured"]:
|
||||
_tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
|
||||
_tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
NOOP_STRUCT_TOOL,
|
||||
]
|
||||
|
||||
memory_updates = self.llm.generate_response(
|
||||
messages=update_memory_prompt,
|
||||
@@ -90,28 +100,29 @@ class MemoryGraph:
|
||||
|
||||
to_be_added = []
|
||||
|
||||
for item in memory_updates['tool_calls']:
|
||||
if item['name'] == "add_graph_memory":
|
||||
to_be_added.append(item['arguments'])
|
||||
elif item['name'] == "update_graph_memory":
|
||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
|
||||
elif item['name'] == "noop":
|
||||
for item in memory_updates["tool_calls"]:
|
||||
if item["name"] == "add_graph_memory":
|
||||
to_be_added.append(item["arguments"])
|
||||
elif item["name"] == "update_graph_memory":
|
||||
self._update_relationship(
|
||||
item["arguments"]["source"],
|
||||
item["arguments"]["destination"],
|
||||
item["arguments"]["relationship"],
|
||||
filters,
|
||||
)
|
||||
elif item["name"] == "noop":
|
||||
continue
|
||||
|
||||
returned_entities = []
|
||||
|
||||
for item in to_be_added:
|
||||
source = item['source'].lower().replace(" ", "_")
|
||||
source_type = item['source_type'].lower().replace(" ", "_")
|
||||
relation = item['relationship'].lower().replace(" ", "_")
|
||||
destination = item['destination'].lower().replace(" ", "_")
|
||||
destination_type = item['destination_type'].lower().replace(" ", "_")
|
||||
source = item["source"].lower().replace(" ", "_")
|
||||
source_type = item["source_type"].lower().replace(" ", "_")
|
||||
relation = item["relationship"].lower().replace(" ", "_")
|
||||
destination = item["destination"].lower().replace(" ", "_")
|
||||
destination_type = item["destination_type"].lower().replace(" ", "_")
|
||||
|
||||
returned_entities.append({
|
||||
"source" : source,
|
||||
"relationship" : relation,
|
||||
"target" : destination
|
||||
})
|
||||
returned_entities.append({"source": source, "relationship": relation, "target": destination})
|
||||
|
||||
# Create embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
@@ -135,7 +146,7 @@ class MemoryGraph:
|
||||
"dest_name": destination,
|
||||
"source_embedding": source_embedding,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": filters["user_id"]
|
||||
"user_id": filters["user_id"],
|
||||
}
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
@@ -150,19 +161,22 @@ class MemoryGraph:
|
||||
_tools = [SEARCH_STRUCT_TOOL]
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
tools = _tools
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
node_list = []
|
||||
relation_list = []
|
||||
|
||||
for item in search_results['tool_calls']:
|
||||
if item['name'] == "search":
|
||||
for item in search_results["tool_calls"]:
|
||||
if item["name"] == "search":
|
||||
try:
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
node_list.extend(item["arguments"]["nodes"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
|
||||
@@ -201,13 +215,16 @@ class MemoryGraph:
|
||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
"""
|
||||
params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
}
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
def search(self, query, filters):
|
||||
"""
|
||||
Search for memories and related graph data.
|
||||
@@ -235,17 +252,12 @@ class MemoryGraph:
|
||||
|
||||
search_results = []
|
||||
for item in reranked_results:
|
||||
search_results.append({
|
||||
"source": item[0],
|
||||
"relationship": item[1],
|
||||
"target": item[2]
|
||||
})
|
||||
search_results.append({"source": item[0], "relationship": item[1], "target": item[2]})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher = """
|
||||
MATCH (n {user_id: $user_id})
|
||||
@@ -254,7 +266,6 @@ class MemoryGraph:
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def get_all(self, filters):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
@@ -276,17 +287,18 @@ class MemoryGraph:
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
final_results.append({
|
||||
"source": result['source'],
|
||||
"relationship": result['relationship'],
|
||||
"target": result['target']
|
||||
})
|
||||
final_results.append(
|
||||
{
|
||||
"source": result["source"],
|
||||
"relationship": result["relationship"],
|
||||
"target": result["target"],
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
|
||||
def _update_relationship(self, source, target, relationship, filters):
|
||||
"""
|
||||
Update or create a relationship between two nodes in the graph.
|
||||
@@ -309,14 +321,20 @@ class MemoryGraph:
|
||||
MERGE (n1 {name: $source, user_id: $user_id})
|
||||
MERGE (n2 {name: $target, user_id: $user_id})
|
||||
"""
|
||||
self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
self.graph.query(
|
||||
check_and_create_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
# Delete any existing relationship between the nodes
|
||||
delete_query = """
|
||||
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
|
||||
DELETE r
|
||||
"""
|
||||
self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
self.graph.query(
|
||||
delete_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
# Create the new relationship
|
||||
create_query = f"""
|
||||
@@ -324,7 +342,10 @@ class MemoryGraph:
|
||||
CREATE (n1)-[r:{relationship}]->(n2)
|
||||
RETURN n1, r, n2
|
||||
"""
|
||||
result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
result = self.graph.query(
|
||||
create_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise Exception(f"Failed to update or create relationship between {source} and {target}")
|
||||
|
||||
@@ -10,14 +10,14 @@ from typing import Any, Dict
|
||||
import pytz
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
@@ -30,9 +30,7 @@ class Memory(MemoryBase):
|
||||
self.config = config
|
||||
|
||||
self.custom_prompt = self.config.custom_prompt
|
||||
self.embedding_model = EmbedderFactory.create(
|
||||
self.config.embedder.provider, self.config.embedder.config
|
||||
)
|
||||
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
|
||||
self.vector_store = VectorStoreFactory.create(
|
||||
self.config.vector_store.provider, self.config.vector_store.config
|
||||
)
|
||||
@@ -45,12 +43,12 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1" and self.config.graph_store.config:
|
||||
from mem0.memory.graph_memory import MemoryGraph
|
||||
|
||||
self.graph = MemoryGraph(self.config)
|
||||
self.enable_graph = True
|
||||
|
||||
capture_event("mem0.init", self)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config_dict: Dict[str, Any]):
|
||||
try:
|
||||
@@ -60,7 +58,6 @@ class Memory(MemoryBase):
|
||||
raise
|
||||
return cls(config)
|
||||
|
||||
|
||||
def add(
|
||||
self,
|
||||
messages,
|
||||
@@ -98,9 +95,7 @@ class Memory(MemoryBase):
|
||||
filters["run_id"] = metadata["run_id"] = run_id
|
||||
|
||||
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
|
||||
raise ValueError(
|
||||
"One of the filters: user_id, agent_id or run_id is required!"
|
||||
)
|
||||
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -116,8 +111,8 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
return {
|
||||
"results" : vector_store_result,
|
||||
"relations" : graph_result,
|
||||
"results": vector_store_result,
|
||||
"relations": graph_result,
|
||||
}
|
||||
else:
|
||||
warnings.warn(
|
||||
@@ -125,29 +120,29 @@ class Memory(MemoryBase):
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
stacklevel=2,
|
||||
)
|
||||
return {"message": "ok"}
|
||||
|
||||
|
||||
def _add_to_vector_store(self, messages, metadata, filters):
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
if self.custom_prompt:
|
||||
system_prompt=self.custom_prompt
|
||||
user_prompt=f"Input: {parsed_messages}"
|
||||
system_prompt = self.custom_prompt
|
||||
user_prompt = f"Input: {parsed_messages}"
|
||||
else:
|
||||
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
|
||||
|
||||
response = self.llm.generate_response(
|
||||
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
try:
|
||||
new_retrieved_facts = json.loads(response)[
|
||||
"facts"
|
||||
]
|
||||
new_retrieved_facts = json.loads(response)["facts"]
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_retrieved_facts: {e}")
|
||||
new_retrieved_facts = []
|
||||
@@ -178,24 +173,30 @@ class Memory(MemoryBase):
|
||||
logging.info(resp)
|
||||
try:
|
||||
if resp["event"] == "ADD":
|
||||
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
})
|
||||
                    _ = self._create_memory(data=resp["text"], metadata=metadata)
                    returned_memories.append(
                        {
                            "memory": resp["text"],
                            "event": resp["event"],
                        }
                    )
                elif resp["event"] == "UPDATE":
                    self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
                    returned_memories.append({
                        "memory" : resp["text"],
                        "event" : resp["event"],
                        "previous_memory" : resp["old_memory"],
                    })
                    returned_memories.append(
                        {
                            "memory": resp["text"],
                            "event": resp["event"],
                            "previous_memory": resp["old_memory"],
                        }
                    )
                elif resp["event"] == "DELETE":
                    self._delete_memory(memory_id=resp["id"])
                    returned_memories.append({
                        "memory" : resp["text"],
                        "event" : resp["event"],
                    })
                    returned_memories.append(
                        {
                            "memory": resp["text"],
                            "event": resp["event"],
                        }
                    )
                elif resp["event"] == "NONE":
                    logging.info("NOOP for Memory.")
            except Exception as e:
@@ -206,7 +207,6 @@ class Memory(MemoryBase):
        capture_event("mem0.add", self)

        return returned_memories


    def _add_to_graph(self, messages, filters):
        added_entities = []
@@ -220,7 +220,6 @@ class Memory(MemoryBase):

        return added_entities


    def get(self, memory_id):
        """
        Retrieve a memory by ID.
@@ -236,11 +235,7 @@ class Memory(MemoryBase):
        if not memory:
            return None

        filters = {
            key: memory.payload[key]
            for key in ["user_id", "agent_id", "run_id"]
            if memory.payload.get(key)
        }
        filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}

        # Prepare base memory item
        memory_item = MemoryItem(
@@ -261,9 +256,7 @@ class Memory(MemoryBase):
            "created_at",
            "updated_at",
        }
        additional_metadata = {
            k: v for k, v in memory.payload.items() if k not in excluded_keys
        }
        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
        if additional_metadata:
            memory_item["metadata"] = additional_metadata

@@ -271,7 +264,6 @@ class Memory(MemoryBase):

        return result


    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories.
@@ -288,10 +280,12 @@ class Memory(MemoryBase):
            filters["run_id"] = run_id

        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})


        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
            future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
            future_graph_entities = (
                executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
            )

            all_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -307,15 +301,22 @@ class Memory(MemoryBase):
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2
                stacklevel=2,
            )
            return all_memories


    def _get_all_from_vector_store(self, filters, limit):
        memories = self.vector_store.list(filters=filters, limit=limit)

        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
        excluded_keys = {
            "user_id",
            "agent_id",
            "run_id",
            "hash",
            "data",
            "created_at",
            "updated_at",
        }
        all_memories = [
            {
                **MemoryItem(
@@ -325,19 +326,9 @@ class Memory(MemoryBase):
                    created_at=mem.payload.get("created_at"),
                    updated_at=mem.payload.get("updated_at"),
                ).model_dump(exclude={"score"}),
                **{
                    key: mem.payload[key]
                    for key in ["user_id", "agent_id", "run_id"]
                    if key in mem.payload
                },
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **(
                    {
                        "metadata": {
                            k: v
                            for k, v in mem.payload.items()
                            if k not in excluded_keys
                        }
                    }
                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                    if any(k for k in mem.payload if k not in excluded_keys)
                    else {}
                ),
@@ -346,10 +337,7 @@ class Memory(MemoryBase):
        ]
        return all_memories


    def search(
        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
    ):
    def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
        """
        Search for memories.

@@ -373,15 +361,21 @@ class Memory(MemoryBase):
            filters["run_id"] = run_id

        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError(
                "One of the filters: user_id, agent_id or run_id is required!"
            )
            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
        capture_event(
            "mem0.search",
            self,
            {"filters": len(filters), "limit": limit, "version": self.version},
        )

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_memories = executor.submit(self._search_vector_store, query, filters, limit)
            future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
            future_graph_entities = (
                executor.submit(self.graph.search, query, filters)
                if self.version == "v1.1" and self.enable_graph
                else None
            )

            original_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -390,23 +384,20 @@ class Memory(MemoryBase):
            if self.enable_graph:
                return {"results": original_memories, "relations": graph_entities}
            else:
                return {"results" : original_memories}
                return {"results": original_memories}
        else:
            warnings.warn(
                "The current get_all API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2
                stacklevel=2,
            )
            return original_memories


    def _search_vector_store(self, query, filters, limit):
        embeddings = self.embedding_model.embed(query)
        memories = self.vector_store.search(
            query=embeddings, limit=limit, filters=filters
        )
        memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)

        excluded_keys = {
            "user_id",
@@ -428,19 +419,9 @@ class Memory(MemoryBase):
                    updated_at=mem.payload.get("updated_at"),
                    score=mem.score,
                ).model_dump(),
                **{
                    key: mem.payload[key]
                    for key in ["user_id", "agent_id", "run_id"]
                    if key in mem.payload
                },
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **(
                    {
                        "metadata": {
                            k: v
                            for k, v in mem.payload.items()
                            if k not in excluded_keys
                        }
                    }
                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                    if any(k for k in mem.payload if k not in excluded_keys)
                    else {}
                ),
@@ -450,7 +431,6 @@ class Memory(MemoryBase):

        return original_memories


    def update(self, memory_id, data):
        """
        Update a memory by ID.
@@ -466,7 +446,6 @@ class Memory(MemoryBase):
        self._update_memory(memory_id, data)
        return {"message": "Memory updated successfully!"}


    def delete(self, memory_id):
        """
        Delete a memory by ID.
@@ -478,7 +457,6 @@ class Memory(MemoryBase):
        self._delete_memory(memory_id)
        return {"message": "Memory deleted successfully!"}


    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories.
@@ -511,8 +489,7 @@ class Memory(MemoryBase):
        if self.version == "v1.1" and self.enable_graph:
            self.graph.delete_all(filters)

        return {'message': 'Memories deleted successfully!'}

        return {"message": "Memories deleted successfully!"}

    def history(self, memory_id):
        """
@@ -527,7 +504,6 @@ class Memory(MemoryBase):
        capture_event("mem0.history", self, {"memory_id": memory_id})
        return self.db.get_history(memory_id)


    def _create_memory(self, data, metadata=None):
        logging.info(f"Creating memory with {data=}")
        embeddings = self.embedding_model.embed(data)
@@ -542,12 +518,9 @@ class Memory(MemoryBase):
            ids=[memory_id],
            payloads=[metadata],
        )
        self.db.add_history(
            memory_id, None, data, "ADD", created_at=metadata["created_at"]
        )
        self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
        return memory_id


    def _update_memory(self, memory_id, data, metadata=None):
        logger.info(f"Updating memory with {data=}")
        existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -557,9 +530,7 @@ class Memory(MemoryBase):
        new_metadata["data"] = data
        new_metadata["hash"] = existing_memory.payload.get("hash")
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
        new_metadata["updated_at"] = datetime.now(
            pytz.timezone("US/Pacific")
        ).isoformat()
        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

        if "user_id" in existing_memory.payload:
            new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -584,7 +555,6 @@ class Memory(MemoryBase):
            updated_at=new_metadata["updated_at"],
        )


    def _delete_memory(self, memory_id):
        logging.info(f"Deleting memory with {memory_id=}")
        existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -592,7 +562,6 @@ class Memory(MemoryBase):
        self.vector_store.delete(vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)


    def reset(self):
        """
        Reset the memory store.
@@ -602,6 +571,5 @@ class Memory(MemoryBase):
        self.db.reset()
        capture_event("mem0.reset", self)


    def chat(self, query):
        raise NotImplementedError("Chat function not implemented yet.")

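Note (not part of this diff): for reviewers unfamiliar with the class touched above, a minimal usage sketch of the OSS Memory API whose methods are reformatted here. It assumes a default Memory() configuration is available and shows the v1.1 vs. legacy output formats discussed in the deprecation warnings above.

# Hypothetical usage sketch; not included in this commit.
from mem0 import Memory

m = Memory()  # assumes a default local configuration is set up

# add() extracts facts and reports the ADD/UPDATE/DELETE events built above
m.add("I prefer vegetarian food", user_id="alice")

# get_all() and search() require at least one of user_id / agent_id / run_id
all_memories = m.get_all(user_id="alice", limit=10)
hits = m.search("what food does the user like?", user_id="alice", limit=3)

# v1.1 returns {"results": [...]}, the legacy format returns a plain list
results = hits["results"] if isinstance(hits, dict) else hits
for hit in results:
    print(hit["memory"], hit.get("score"))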
@@ -12,9 +12,7 @@ class SQLiteManager:
        with self.connection:
            cursor = self.connection.cursor()

            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
            )
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
            table_exists = cursor.fetchone() is not None

            if table_exists:
@@ -62,7 +60,7 @@ class SQLiteManager:
                    INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
                    SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
                    FROM old_history
                """
                """ # noqa: E501
                )

                cursor.execute("DROP TABLE old_history")

@@ -1,7 +1,7 @@
import logging
import os
import platform
import sys
import os

from posthog import Posthog

@@ -15,8 +15,9 @@ if isinstance(MEM0_TELEMETRY, str):
if not isinstance(MEM0_TELEMETRY, bool):
    raise ValueError("MEM0_TELEMETRY must be a boolean value.")

logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)
logging.getLogger("posthog").setLevel(logging.CRITICAL + 1)
logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)


class AnonymousTelemetry:
    def __init__(self, project_api_key, host):
@@ -24,9 +25,8 @@ class AnonymousTelemetry:
        # Call setup config to ensure that the user_id is generated
        setup_config()
        self.user_id = get_user_id()
        # Optional
        if not MEM0_TELEMETRY:
            self.posthog.disabled = True
        if not MEM0_TELEMETRY:
            self.posthog.disabled = True

    def capture_event(self, event_name, properties=None):
        if properties is None:
@@ -40,9 +40,7 @@ class AnonymousTelemetry:
            "machine": platform.machine(),
            **properties,
        }
        self.posthog.capture(
            distinct_id=self.user_id, event=event_name, properties=properties
        )
        self.posthog.capture(distinct_id=self.user_id, event=event_name, properties=properties)

    def identify_user(self, user_id, properties=None):
        if properties is None:
@@ -65,6 +63,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
        "collection": memory_instance.collection_name,
        "vector_size": memory_instance.embedding_model.config.embedding_dims,
        "history_store": "sqlite",
        "graph_store": f"{memory_instance.graph.__class__.__module__}.{memory_instance.graph.__class__.__name__}" if memory_instance.config.graph_store.config else None,
        "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
        "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
        "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
@@ -76,7 +75,6 @@ def capture_event(event_name, memory_instance, additional_data=None):
    telemetry.capture_event(event_name, event_data)



def capture_client_event(event_name, instance, additional_data=None):
    event_data = {
        "function": f"{instance.__class__.__module__}.{instance.__class__.__name__}",

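Note (not part of this diff): the module above gates AnonymousTelemetry on MEM0_TELEMETRY. A minimal opt-out sketch, assuming the string value is coerced to a boolean as the isinstance check above implies.

# Hypothetical opt-out sketch based on the MEM0_TELEMETRY handling above.
import os

# Set the flag before importing mem0 so the module-level check above sees it
# and disables the posthog client (self.posthog.disabled = True).
os.environ["MEM0_TELEMETRY"] = "false"

from mem0 import Memory  # imported after the flag is set

m = Memory()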
@@ -4,13 +4,14 @@ from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
def get_fact_retrieval_messages(message):
    return FACT_RETRIEVAL_PROMPT, f"Input: {message}"


def parse_messages(messages):
    response = ""
    for msg in messages:
        if msg["role"] == "system":
            response += f"system: {msg['content']}\n"
        if msg["role"] == "user":
            response += f"user: {msg['content']}\n"
        if msg["role"] == "assistant":
            response += f"assistant: {msg['content']}\n"
    return response
    response = ""
    for msg in messages:
        if msg["role"] == "system":
            response += f"system: {msg['content']}\n"
        if msg["role"] == "user":
            response += f"user: {msg['content']}\n"
        if msg["role"] == "assistant":
            response += f"assistant: {msg['content']}\n"
    return response

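Note (not part of this diff): the change to parse_messages above is whitespace-only, so its output is unchanged. A small sketch of what it produces, assuming the helper lives at mem0.memory.utils and using made-up messages.

# Hypothetical usage sketch for the helper reindented above.
from mem0.memory.utils import parse_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "I live in Berlin."},
    {"role": "assistant", "content": "Noted!"},
]

# Produces one "role: content" line per message:
# system: You are a helpful assistant.
# user: I live in Berlin.
# assistant: Noted!
print(parse_messages(messages))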
@@ -10,7 +10,7 @@ try:
    import litellm
except ImportError:
    user_input = input("The 'litellm' library is required. Install it now? [y/N]: ")
    if user_input.lower() == 'y':
    if user_input.lower() == "y":
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
            import litellm
@@ -105,16 +105,10 @@ class Completions:

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            self._async_add_to_memory(
                messages, user_id, agent_id, run_id, metadata, filters
            )
            relevant_memories = self._fetch_relevant_memories(
                messages, user_id, agent_id, run_id, filters, limit
            )
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
            prepared_messages[-1]["content"] = self._format_query_with_memories(
                messages, relevant_memories
            )
            prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)

        response = litellm.completion(
            model=model,
@@ -156,9 +150,7 @@ class Completions:
        messages[0]["content"] = MEMORY_ANSWER_PROMPT
        return messages

    def _async_add_to_memory(
        self, messages, user_id, agent_id, run_id, metadata, filters
    ):
    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        def add_task():
            logger.debug("Adding to memory asynchronously")
            self.mem0_client.add(
@@ -172,13 +164,9 @@ class Completions:

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(
        self, messages, user_id, agent_id, run_id, filters, limit
    ):
    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        # Currently, only pass the last 6 messages to the search API to prevent long query
        message_input = [
            f"{message['role']}: {message['content']}" for message in messages
        ][-6:]
        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),

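Note (not part of this diff): the _fetch_relevant_memories one-liner above keeps only the last six turns before querying. A standalone sketch of that slicing, illustration only.

# Same expression as in _fetch_relevant_memories above, on made-up messages.
messages = [{"role": "user", "content": f"turn {i}"} for i in range(10)]

# Flatten to "role: content" strings, then keep only the trailing six entries
# so the search query stays short.
message_input = [f"{m['role']}: {m['content']}" for m in messages][-6:]

assert len(message_input) == 6
print("\n".join(message_input))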
@@ -21,7 +21,7 @@ class LlmFactory:
        "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
        "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
        "anthropic": "mem0.llms.anthropic.AnthropicLLM",
        "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM"
        "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM",
    }

    @classmethod
@@ -59,7 +59,7 @@ class VectorStoreFactory:
        "qdrant": "mem0.vector_stores.qdrant.Qdrant",
        "chroma": "mem0.vector_stores.chroma.ChromaDB",
        "pgvector": "mem0.vector_stores.pgvector.PGVector",
        "milvus": "mem0.vector_stores.milvus.MilvusDB"
        "milvus": "mem0.vector_stores.milvus.MilvusDB",
    }

    @classmethod

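Note (not part of this diff): the factory mappings above associate provider names with dotted class paths. A generic sketch of resolving such a path with importlib; load_class is a hypothetical helper for illustration, not necessarily the exact loader LlmFactory/VectorStoreFactory use.

import importlib

def load_class(dotted_path: str):
    # Split "package.module.ClassName" into a module path and an attribute name,
    # import the module, and return the class object.
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Example with one of the entries that gained a trailing comma above.
MilvusDB = load_class("mem0.vector_stores.milvus.MilvusDB")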
@@ -80,24 +80,14 @@ class ChromaDB(VectorStoreBase):
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(
            len(v) for v in values if isinstance(v, list) and v is not None
        )
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(
                    distances[i]
                    if isinstance(distances, list) and distances and i < len(distances)
                    else None
                ),
                payload=(
                    metadatas[i]
                    if isinstance(metadatas, list) and metadatas and i < len(metadatas)
                    else None
                ),
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

@@ -143,9 +133,7 @@ class ChromaDB(VectorStoreBase):
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)

    def search(
        self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
    def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors.

@@ -157,9 +145,7 @@ class ChromaDB(VectorStoreBase):
        Returns:
            List[OutputData]: Search results.
        """
        results = self.collection.query(
            query_embeddings=query, where=filters, n_results=limit
        )
        results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
        final_results = self._parse_output(results)
        return final_results

@@ -225,9 +211,7 @@ class ChromaDB(VectorStoreBase):
        """
        return self.client.get_collection(name=self.collection_name)

    def list(
        self, filters: Optional[Dict] = None, limit: int = 100
    ) -> List[OutputData]:
    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in a collection.


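Note (not part of this diff): the guards reflowed above protect _parse_output against query results whose ids/distances/metadatas are missing or of different lengths. A standalone sketch of the same pattern on plain lists, illustration only.

# Same guard pattern as in ChromaDB._parse_output above, with made-up values.
ids = ["m1", "m2"]
distances = [0.12]   # deliberately shorter than ids
metadatas = None     # may be absent entirely

values = [ids, distances, metadatas]
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

rows = []
for i in range(max_length):
    rows.append({
        "id": ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
        "score": distances[i] if isinstance(distances, list) and distances and i < len(distances) else None,
        "payload": metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None,
    })

print(rows)  # the second row falls back to None for score and payload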
@@ -8,15 +8,13 @@ class VectorStoreConfig(BaseModel):
        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
        default="qdrant",
    )
    config: Optional[Dict] = Field(
        description="Configuration for the specific vector store", default=None
    )
    config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)

    _provider_configs: Dict[str, str] = {
        "qdrant": "QdrantConfig",
        "chroma": "ChromaDbConfig",
        "pgvector": "PGVectorConfig",
        "milvus" : "MilvusDBConfig"
        "milvus": "MilvusDBConfig",
    }

    @model_validator(mode="after")

@@ -1,15 +1,17 @@
import logging
from typing import Dict, Optional

from pydantic import BaseModel
from typing import Optional, Dict
from mem0.vector_stores.base import VectorStoreBase

from mem0.configs.vector_stores.milvus import MetricType
from mem0.vector_stores.base import VectorStoreBase

try:
    import pymilvus
    import pymilvus  # noqa: F401
except ImportError:
    raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

logger = logging.getLogger(__name__)

@@ -20,9 +22,15 @@ class OutputData(BaseModel):
    payload: Optional[Dict] # metadata



class MilvusDB(VectorStoreBase):
    def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
    def __init__(
        self,
        url: str,
        token: str,
        collection_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the MilvusDB database.

        Args:
@@ -32,22 +40,21 @@ class MilvusDB(VectorStoreBase):
            embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
            metric_type (MetricType): Metric type for similarity search (defaults to L2).
        """

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        self.client = MilvusClient(uri=url,token=token)

        self.client = MilvusClient(uri=url, token=token)
        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type
            metric_type=self.metric_type,
        )



    def create_col(
        self, collection_name : str, vector_size : str, metric_type : MetricType = MetricType.COSINE
        self,
        collection_name: str,
        vector_size: str,
        metric_type: MetricType = MetricType.COSINE,
    ) -> None:
        """Create a new collection with index_type AUTOINDEX.

@@ -65,7 +72,7 @@ class MilvusDB(VectorStoreBase):
            FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]


        schema = CollectionSchema(fields, enable_dynamic_field=True)

        index = self.client.prepare_index_params(
@@ -73,12 +80,10 @@ class MilvusDB(VectorStoreBase):
            metric_type=metric_type,
            index_type="AUTOINDEX",
            index_name="vector_index",
            params={ "nlist": 128 }
            params={"nlist": 128},
        )

        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)



    def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
        """Insert vectors into a collection.

@@ -91,9 +96,8 @@ class MilvusDB(VectorStoreBase):
            data = {"id": idx, "vectors": embedding, "metadata": metadata}
            self.client.insert(collection_name=self.collection_name, data=data, **kwargs)


    def _create_filter(self, filters: dict):
        """Prepare filters for efficient query.
        """Prepare filters for efficient query.

        Args:
            filters (dict): filters [user_id, agent_id, run_id]
@@ -109,8 +113,7 @@ class MilvusDB(VectorStoreBase):
                operands.append(f'(metadata["{key}"] == {value})')

        return " and ".join(operands)



    def _parse_output(self, data: list):
        """
        Parse the output data.
@@ -125,16 +128,15 @@ class MilvusDB(VectorStoreBase):

        for value in data:
            uid, score, metadata = (
                value.get("id"),
                value.get("distance"),
                value.get("entity",{}).get("metadata")
                value.get("id"),
                value.get("distance"),
                value.get("entity", {}).get("metadata"),
            )


            memory_obj = OutputData(id=uid, score=score, payload=metadata)
            memory.append(memory_obj)

        return memory


    def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
        """
@@ -150,14 +152,15 @@ class MilvusDB(VectorStoreBase):
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[query], limit=limit, filter=query_filter,
            output_fields=["*"]
            collection_name=self.collection_name,
            data=[query],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        result = self._parse_output(data=hits[0])

        return result


    def delete(self, vector_id):
        """
        Delete a vector by ID.
@@ -166,7 +169,6 @@ class MilvusDB(VectorStoreBase):
            vector_id (str): ID of the vector to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)


    def update(self, vector_id=None, vector=None, payload=None):
        """
@@ -177,7 +179,7 @@ class MilvusDB(VectorStoreBase):
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        schema = {"id" : vector_id, "vectors": vector, "metadata" : payload}
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
@@ -191,7 +193,11 @@ class MilvusDB(VectorStoreBase):
            OutputData: Retrieved vector.
        """
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))
        output = OutputData(
            id=result[0].get("id", None),
            score=None,
            payload=result[0].get("metadata", None),
        )
        return output

    def list_cols(self):
@@ -228,12 +234,9 @@ class MilvusDB(VectorStoreBase):
            List[OutputData]: List of vectors.
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(
            collection_name=self.collection_name,
            filter=query_filter,
            limit=limit)
        result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
        memories = []
        for data in result:
            obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
            memories.append(obj)
        return [memories]
        return [memories]

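Note (not part of this diff): _create_filter above joins per-key equality checks into a Milvus boolean expression. A standalone sketch of the string it builds; build_milvus_filter is a hypothetical helper, and the !r quoting of values is my own choice, since any quoting done by the elided lines of the real method is not shown here.

def build_milvus_filter(filters: dict) -> str:
    # Mirrors the " and ".join(...) in _create_filter above.
    operands = [f'(metadata["{key}"] == {value!r})' for key, value in filters.items()]
    return " and ".join(operands)

print(build_milvus_filter({"user_id": "alice", "run_id": "r1"}))
# (metadata["user_id"] == 'alice') and (metadata["run_id"] == 'r1')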
@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
@@ -22,7 +23,15 @@ class OutputData(BaseModel):

class PGVector(VectorStoreBase):
    def __init__(
        self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
        self,
        dbname,
        collection_name,
        embedding_model_dims,
        user,
        password,
        host,
        port,
        diskann,
    ):
        """
        Initialize the PGVector database.
@@ -40,9 +49,7 @@ class PGVector(VectorStoreBase):
        self.collection_name = collection_name
        self.use_diskann = diskann

        self.conn = psycopg2.connect(
            dbname=dbname, user=user, password=password, host=host, port=port
        )
        self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
        self.cur = self.conn.cursor()

        collections = self.list_cols()
@@ -73,7 +80,8 @@ class PGVector(VectorStoreBase):
        self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
        if self.cur.fetchone():
            # Create DiskANN index if extension is installed for faster search
            self.cur.execute(f"""
            self.cur.execute(
                f"""
                CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
                ON {self.collection_name}
                USING diskann (vector);
@@ -94,10 +102,7 @@ class PGVector(VectorStoreBase):
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        json_payloads = [json.dumps(payload) for payload in payloads]

        data = [
            (id, vector, payload)
            for id, vector, payload in zip(ids, vectors, json_payloads)
        ]
        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
        execute_values(
            self.cur,
            f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
@@ -125,9 +130,7 @@ class PGVector(VectorStoreBase):
            filter_conditions.append("payload->>%s = %s")
            filter_params.extend([k, str(v)])

        filter_clause = (
            "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
        )
        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        self.cur.execute(
            f"""
@@ -137,13 +140,11 @@ class PGVector(VectorStoreBase):
            ORDER BY distance
            LIMIT %s
            """,
            (query, *filter_params, limit),
            (query, *filter_params, limit),
        )

        results = self.cur.fetchall()
        return [
            OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results
        ]
        return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]

    def delete(self, vector_id):
        """
@@ -152,9 +153,7 @@ class PGVector(VectorStoreBase):
        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.cur.execute(
            f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)
        )
        self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
        self.conn.commit()

    def update(self, vector_id, vector=None, payload=None):
@@ -204,9 +203,7 @@ class PGVector(VectorStoreBase):
        Returns:
            List[str]: List of collection names.
        """
        self.cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
        return [row[0] for row in self.cur.fetchall()]

    def delete_col(self):
@@ -254,9 +251,7 @@ class PGVector(VectorStoreBase):
            filter_conditions.append("payload->>%s = %s")
            filter_params.extend([k, str(v)])

        filter_clause = (
            "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
        )
        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        query = f"""
            SELECT id, vector, payload

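Note (not part of this diff): the filter_clause one-liners above build a parameterized WHERE clause over the JSON payload column. A standalone sketch of the clause and parameter list it produces, illustration only with made-up filters.

# Same clause-building pattern as in PGVector.search/list above.
filters = {"user_id": "alice", "agent_id": "travel-bot"}

filter_conditions = []
filter_params = []
for k, v in filters.items():
    # payload->>%s extracts a JSON field as text; key and value are both bound parameters
    filter_conditions.append("payload->>%s = %s")
    filter_params.extend([k, str(v)])

filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

print(filter_clause)  # WHERE payload->>%s = %s AND payload->>%s = %s
print(filter_params)  # ['user_id', 'alice', 'agent_id', 'travel-bot']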
@@ -3,16 +3,9 @@ import os
import shutil

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointIdsList,
    PointStruct,
    Range,
    VectorParams,
)
from qdrant_client.models import (Distance, FieldCondition, Filter, MatchValue,
                                  PointIdsList, PointStruct, Range,
                                  VectorParams)

from mem0.vector_stores.base import VectorStoreBase

@@ -68,9 +61,7 @@ class Qdrant(VectorStoreBase):
        self.collection_name = collection_name
        self.create_col(embedding_model_dims, on_disk)

    def create_col(
        self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
    ):
    def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
        """
        Create a new collection.

@@ -83,16 +74,12 @@ class Qdrant(VectorStoreBase):
        response = self.list_cols()
        for collection in response.collections:
            if collection.name == self.collection_name:
                logging.debug(
                    f"Collection {self.collection_name} already exists. Skipping creation."
                )
                logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
                return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=vector_size, distance=distance, on_disk=on_disk
            ),
            vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
        )

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -128,15 +115,9 @@ class Qdrant(VectorStoreBase):
        conditions = []
        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                conditions.append(
                    FieldCondition(
                        key=key, range=Range(gte=value["gte"], lte=value["lte"])
                    )
                )
                conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
            else:
                conditions.append(
                    FieldCondition(key=key, match=MatchValue(value=value))
                )
                conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
        return Filter(must=conditions) if conditions else None

    def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
@@ -196,9 +177,7 @@ class Qdrant(VectorStoreBase):
        Returns:
            dict: Retrieved vector.
        """
        result = self.client.retrieve(
            collection_name=self.collection_name, ids=[vector_id], with_payload=True
        )
        result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
        return result[0] if result else None

    def list_cols(self) -> list:

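Note (not part of this diff): the reflowed _create_filter above turns a plain dict into a qdrant Filter with exact-match and range conditions. A standalone sketch of the same logic; build_qdrant_filter is a hypothetical helper and assumes qdrant-client is installed.

from qdrant_client.models import FieldCondition, Filter, MatchValue, Range

def build_qdrant_filter(filters: dict):
    # Same shape as Qdrant._create_filter above: {"gte": ..., "lte": ...} dicts
    # become range conditions, everything else an exact match.
    conditions = []
    for key, value in filters.items():
        if isinstance(value, dict) and "gte" in value and "lte" in value:
            conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
        else:
            conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
    return Filter(must=conditions) if conditions else None

f = build_qdrant_filter({"user_id": "alice", "created_at": {"gte": 0, "lte": 1700000000}})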