[Mem0] Update dependencies and make the package lighter (#1708)

Co-authored-by: Dev-Khant <devkhant24@gmail.com>
Author: Deshraj Yadav
Date: 2024-08-14 23:28:07 -07:00
Committed by: GitHub
parent e35786e567
commit a8ba7abb7d
35 changed files with 634 additions and 1594 deletions

View File

@@ -250,7 +250,7 @@ Mem0 supports several language models (LLMs) through integration with various [p
## Use Mem0 Platform
```python
from mem0 import Mem0
from mem0.proxy.main import Mem0
client = Mem0(api_key="m0-xxx")
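# A minimal usage sketch (not part of this diff): the model, messages, and
# user_id values below are illustrative. Per the mem0/proxy/main.py hunk later
# in this commit, chat.completions.create forwards these arguments to
# litellm.completion after injecting relevant memories into the last user message.
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hi, I'm Alex and I love hiking."}],
    user_id="alex",
)
print(response.choices[0].message.content)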

View File

@@ -4,4 +4,3 @@ __version__ = importlib.metadata.version("mem0ai")
from mem0.memory.main import Memory # noqa
from mem0.client.main import MemoryClient # noqa
from mem0.proxy.main import Mem0 #noqa

View File

@@ -7,17 +7,26 @@ from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig
class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data")
memory: str = Field(..., description="The memory deduced from the text data") # TODO After prompt changes from platform, update this
memory: str = Field(
..., description="The memory deduced from the text data"
) # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
metadata: Optional[Dict[str, Any]] = Field(
None, description="Additional metadata for the text data"
)
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
created_at: Optional[str] = Field(
None, description="The timestamp when the memory was created"
)
updated_at: Optional[str] = Field(
None, description="The timestamp when the memory was updated"
)
class MemoryConfig(BaseModel):
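As a quick reference for the reformatted model above, a minimal construction sketch; the field values are illustrative and the import path is assumed from this hunk's context rather than shown in the diff.

```python
from mem0.configs.base import MemoryItem  # import path assumed, not shown in this hunk

item = MemoryItem(
    id="9b3f0a2e",                           # unique identifier for the text data
    memory="Alex enjoys hiking",             # memory deduced from the text data
    metadata={"source": "chat"},             # arbitrary additional metadata
    score=0.87,                              # optional relevance score
    created_at="2024-08-14T23:28:07-07:00",  # optional creation timestamp
)
print(item.model_dump(exclude={"score"}))
```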

View File

@@ -1,6 +1,7 @@
from abc import ABC
from typing import Optional
class BaseEmbedderConfig(ABC):
"""
Config for Embeddings.
@@ -11,12 +12,10 @@ class BaseEmbedderConfig(ABC):
model: Optional[str] = None,
api_key: Optional[str] = None,
embedding_dims: Optional[int] = None,
# Ollama specific
ollama_base_url: Optional[str] = None,
# Huggingface specific
model_kwargs: Optional[dict] = None
model_kwargs: Optional[dict] = None,
):
"""
Initializes a configuration class instance for the Embeddings.

View File

@@ -1,6 +1,7 @@
from abc import ABC
from typing import Optional
class BaseLlmConfig(ABC):
"""
Config for LLMs.
@@ -14,16 +15,14 @@ class BaseLlmConfig(ABC):
max_tokens: int = 3000,
top_p: float = 0,
top_k: int = 1,
# Openrouter specific
models: Optional[list[str]] = None,
route: Optional[str] = "fallback",
openrouter_base_url: Optional[str] = "https://openrouter.ai/api/v1",
site_url: Optional[str] = None,
app_name: Optional[str] = None,
# Ollama specific
ollama_base_url: Optional[str] = None
ollama_base_url: Optional[str] = None,
):
"""
Initializes a configuration class instance for the LLM.

View File

@@ -2,15 +2,20 @@ from typing import Optional, ClassVar, Dict, Any
from pydantic import BaseModel, Field, model_validator
class ChromaDbConfig(BaseModel):
try:
from chromadb.api.client import Client
except ImportError:
raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
raise ImportError(
"Chromadb requires extra dependencies. Install with `pip install chromadb`"
) from None
Client: ClassVar[type] = Client
collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
client: Optional[Client] = Field(
None, description="Existing ChromaDB client instance"
)
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[int] = Field(None, description="Database connection remote port")
@@ -29,7 +34,9 @@ class ChromaDbConfig(BaseModel):
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
)
return values
model_config = {

View File

@@ -2,11 +2,14 @@ from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, model_validator
class PGVectorConfig(BaseModel):
dbname: str = Field("postgres", description="Default name for the database")
collection_name: str = Field("mem0", description="Default name for the collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
user: Optional[str] = Field(None, description="Database user")
password: Optional[str] = Field(None, description="Database password")
host: Optional[str] = Field(None, description="Database host. Default is localhost")
@@ -29,6 +32,7 @@ class PGVectorConfig(BaseModel):
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
)
return values

View File

@@ -1,16 +1,24 @@
from pydantic import BaseModel, Field, model_validator
from typing import Optional, ClassVar, Dict, Any
class QdrantConfig(BaseModel):
from qdrant_client import QdrantClient
QdrantClient: ClassVar[type] = QdrantClient
collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
client: Optional[QdrantClient] = Field(
None, description="Existing Qdrant client instance"
)
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
path: Optional[str] = Field(
"/tmp/qdrant", description="Path for local Qdrant database"
)
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@@ -38,7 +46,9 @@ class QdrantConfig(BaseModel):
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
)
return values
model_config = {

View File

@@ -6,6 +6,7 @@ from openai import AzureOpenAI
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
class AzureOpenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
@@ -30,10 +31,7 @@ class AzureOpenAIEmbedding(EmbeddingBase):
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(
input=[text],
model=self.config.model
)
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)

View File

@@ -3,12 +3,14 @@ from abc import ABC, abstractmethod
from mem0.configs.embeddings.base import BaseEmbedderConfig
class EmbeddingBase(ABC):
"""Initialized a base embedding class
:param config: Embedding configuration option class, defaults to None
:type config: Optional[BaseEmbedderConfig], optional
"""
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
if config is None:
self.config = BaseEmbedderConfig()

View File

@@ -9,8 +9,7 @@ class EmbedderConfig(BaseModel):
default="openai",
)
config: Optional[dict] = Field(
description="Configuration for the specific embedding model",
default={}
description="Configuration for the specific embedding model", default={}
)
@field_validator("config")
@@ -20,4 +19,3 @@ class EmbedderConfig(BaseModel):
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -13,15 +13,11 @@ class HuggingFaceEmbedding(EmbeddingBase):
if self.config.model is None:
self.config.model = "multi-qa-MiniLM-L6-cos-v1"
self.model = SentenceTransformer(
self.config.model,
**self.config.model_kwargs
)
self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
if self.config.embedding_dims is None:
self.config.embedding_dims = self.model.get_sentence_embedding_dimension()
def embed(self, text):
"""
Get the embedding for the given text using Hugging Face.

View File

@@ -6,7 +6,9 @@ from mem0.embeddings.base import EmbeddingBase
try:
from ollama import Client
except ImportError:
raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
raise ImportError(
"Ollama requires extra dependencies. Install with `pip install ollama`"
) from None
class OllamaEmbedding(EmbeddingBase):
@@ -14,9 +16,9 @@ class OllamaEmbedding(EmbeddingBase):
super().__init__(config)
if not self.config.model:
self.config.model="nomic-embed-text"
self.config.model = "nomic-embed-text"
if not self.config.embedding_dims:
self.config.embedding_dims=512
self.config.embedding_dims = 512
self.client = Client(host=self.config.ollama_base_url)
self._ensure_model_exists()

View File

@@ -6,6 +6,7 @@ from openai import OpenAI
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
class OpenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
@@ -28,10 +29,7 @@ class OpenAIEmbedding(EmbeddingBase):
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(
input=[text],
model=self.config.model
)
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)

View File

@@ -5,22 +5,30 @@ from typing import Dict, List, Optional, Any
try:
import boto3
except ImportError:
raise ImportError("AWS Bedrock requires extra dependencies. Install with `pip install boto3`") from None
raise ImportError(
"AWS Bedrock requires extra dependencies. Install with `pip install boto3`"
) from None
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
self.client = boto3.client(
"bedrock-runtime",
region_name=os.environ.get("AWS_REGION"),
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
)
self.model_kwargs = {
"temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
def _format_messages(self, messages: List[Dict[str, str]]) -> str:
@@ -36,8 +44,8 @@ class AWSBedrockLLM(LLMBase):
"""
formatted_messages = []
for message in messages:
role = message['role'].capitalize()
content = message['content']
role = message["role"].capitalize()
content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}")
return "".join(formatted_messages) + "\n\nAssistant:"
@@ -54,42 +62,42 @@ class AWSBedrockLLM(LLMBase):
str or dict: The processed response.
"""
if tools:
processed_response = {
"tool_calls": []
}
processed_response = {"tool_calls": []}
if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]:
if "toolUse" in item:
processed_response["tool_calls"].append({
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"]
})
processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"],
}
)
return processed_response
response_body = json.loads(response['body'].read().decode())
return response_body.get('completion', '')
response_body = json.loads(response["body"].read().decode())
return response_body.get("completion", "")
def _prepare_input(
self,
provider: str,
model: str,
prompt: str,
model_kwargs: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
self,
provider: str,
model: str,
prompt: str,
model_kwargs: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
"""
Prepares the input dictionary for the specified provider's model by mapping and renaming
keys in the input based on the provider's requirements.
Prepares the input dictionary for the specified provider's model by mapping and renaming
keys in the input based on the provider's requirements.
Args:
provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The name or identifier of the model being used.
prompt (str): The text prompt to be processed by the model.
model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
Args:
provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The name or identifier of the model being used.
prompt (str): The text prompt to be processed by the model.
model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
Returns:
Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
Returns:
Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
"""
input_body = {"prompt": prompt, **model_kwargs}
@@ -115,10 +123,14 @@ class AWSBedrockLLM(LLMBase):
"textGenerationConfig": {
"maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
"topP": model_kwargs.get("top_p"),
"temperature": model_kwargs.get("temperature")
}
"temperature": model_kwargs.get("temperature"),
},
}
input_body["textGenerationConfig"] = {
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
}
input_body["textGenerationConfig"] = {k: v for k, v in input_body["textGenerationConfig"].items() if v is not None}
return input_body
@@ -135,26 +147,28 @@ class AWSBedrockLLM(LLMBase):
new_tools = []
for tool in original_tools:
if tool['type'] == 'function':
function = tool['function']
if tool["type"] == "function":
function = tool["function"]
new_tool = {
"toolSpec": {
"name": function['name'],
"description": function['description'],
"name": function["name"],
"description": function["description"],
"inputSchema": {
"json": {
"type": "object",
"properties": {},
"required": function['parameters'].get('required', [])
"required": function["parameters"].get("required", []),
}
}
},
}
}
for prop, details in function['parameters'].get('properties', {}).items():
for prop, details in (
function["parameters"].get("properties", {}).items()
):
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
"type": details.get('type', 'string'),
"description": details.get('description', '')
"type": details.get("type", "string"),
"description": details.get("description", ""),
}
new_tools.append(new_tool)
@@ -181,28 +195,39 @@ class AWSBedrockLLM(LLMBase):
if tools:
# Use converse method when tools are provided
messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
messages = [
{
"role": "user",
"content": [{"text": message["content"]} for message in messages],
}
]
inference_config = {
"temperature": self.model_kwargs["temperature"],
"maxTokens": self.model_kwargs["max_tokens_to_sample"],
"topP": self.model_kwargs["top_p"],
}
tools_config = {"tools": self._convert_tool_format(tools)}
response = self.client.converse(
modelId=self.config.model,
messages=messages,
inferenceConfig=inference_config,
toolConfig=tools_config
toolConfig=tools_config,
)
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
input_body = self._prepare_input(
provider, self.config.model, prompt, **self.model_kwargs
)
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=self.model,
accept='application/json',
contentType='application/json'
accept="application/json",
contentType="application/json",
)
return self._parse_response(response, tools)
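For clarity, a small sketch of what `_format_messages` above produces when no tools are passed; the conversation is made up for illustration.

```python
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
# Mirrors the formatting in _format_messages: "\n\nRole: content" per message,
# with a trailing "\n\nAssistant:" to prompt the model's turn.
formatted = "".join(
    f"\n\n{m['role'].capitalize()}: {m['content']}" for m in messages
) + "\n\nAssistant:"
print(formatted)
# \n\nUser: Hello\n\nAssistant: Hi there\n\nAssistant:
```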

View File

@@ -6,13 +6,14 @@ from openai import AzureOpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class AzureOpenAILLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model="gpt-4o"
self.config.model = "gpt-4o"
self.client = AzureOpenAI()
def _parse_response(self, response, tools):
@@ -29,21 +30,22 @@ class AzureOpenAILLM(LLMBase):
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": []
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments)
})
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
@@ -68,7 +70,7 @@ class AzureOpenAILLM(LLMBase):
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format

View File

@@ -14,8 +14,15 @@ class LlmConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
if provider in (
"openai",
"ollama",
"groq",
"together",
"aws_bedrock",
"litellm",
"azure_openai",
):
return v
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -4,7 +4,9 @@ from typing import Dict, List, Optional
try:
from groq import Groq
except ImportError:
raise ImportError("Groq requires extra dependencies. Install with `pip install groq`") from None
raise ImportError(
"Groq requires extra dependencies. Install with `pip install groq`"
) from None
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
@@ -15,7 +17,7 @@ class GroqLLM(LLMBase):
super().__init__(config)
if not self.config.model:
self.config.model="llama3-70b-8192"
self.config.model = "llama3-70b-8192"
self.client = Groq()
def _parse_response(self, response, tools):
@@ -32,15 +34,17 @@ class GroqLLM(LLMBase):
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": []
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments)
})
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
@@ -70,7 +74,7 @@ class GroqLLM(LLMBase):
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format

View File

@@ -1,7 +1,12 @@
import json
from typing import Dict, List, Optional
import litellm
try:
import litellm
except ImportError:
raise ImportError(
"litellm requires extra dependencies. Install with `pip install litellm`"
) from None
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
@@ -12,7 +17,7 @@ class LiteLLM(LLMBase):
super().__init__(config)
if not self.config.model:
self.config.model="gpt-4o"
self.config.model = "gpt-4o"
def _parse_response(self, response, tools):
"""
@@ -28,15 +33,17 @@ class LiteLLM(LLMBase):
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": []
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments)
})
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
@@ -62,14 +69,16 @@ class LiteLLM(LLMBase):
str: The generated response.
"""
if not litellm.supports_function_calling(self.config.model):
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
raise ValueError(
f"Model '{self.config.model}' in litellm does not support function calling."
)
params = {
"model": self.config.model,
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format

View File

@@ -3,17 +3,20 @@ from typing import Dict, List, Optional
try:
from ollama import Client
except ImportError:
raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
raise ImportError(
"Ollama requires extra dependencies. Install with `pip install ollama`"
) from None
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class OllamaLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model="llama3.1:70b"
self.config.model = "llama3.1:70b"
self.client = Client(host=self.config.ollama_base_url)
self._ensure_model_exists()
@@ -38,20 +41,22 @@ class OllamaLLM(LLMBase):
"""
if tools:
processed_response = {
"content": response['message']['content'],
"tool_calls": []
"content": response["message"]["content"],
"tool_calls": [],
}
if response['message'].get('tool_calls'):
for tool_call in response['message']['tool_calls']:
processed_response["tool_calls"].append({
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"]
})
if response["message"].get("tool_calls"):
for tool_call in response["message"]["tool_calls"]:
processed_response["tool_calls"].append(
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
)
return processed_response
else:
return response['message']['content']
return response["message"]["content"]
def generate_response(
self,
@@ -78,8 +83,8 @@ class OllamaLLM(LLMBase):
"options": {
"temperature": self.config.temperature,
"num_predict": self.config.max_tokens,
"top_p": self.config.top_p
}
"top_p": self.config.top_p,
},
}
if response_format:
params["format"] = response_format

View File

@@ -7,15 +7,19 @@ from openai import OpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class OpenAILLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model="gpt-4o"
self.config.model = "gpt-4o"
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
self.client = OpenAI(
api_key=os.environ.get("OPENROUTER_API_KEY"),
base_url=self.config.openrouter_base_url,
)
else:
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
self.client = OpenAI(api_key=api_key)
@@ -34,15 +38,17 @@ class OpenAILLM(LLMBase):
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": []
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments)
})
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
@@ -72,7 +78,7 @@ class OpenAILLM(LLMBase):
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if os.getenv("OPENROUTER_API_KEY"):
@@ -83,11 +89,11 @@ class OpenAILLM(LLMBase):
params.pop("model")
if self.config.site_url and self.config.app_name:
extra_headers={
"HTTP-Referer": self.config.site_url,
"X-Title": self.config.app_name
}
openrouter_params["extra_headers"] = extra_headers
extra_headers = {
"HTTP-Referer": self.config.site_url,
"X-Title": self.config.app_name,
}
openrouter_params["extra_headers"] = extra_headers
params.update(**openrouter_params)

View File

@@ -4,17 +4,20 @@ from typing import Dict, List, Optional
try:
from together import Together
except ImportError:
raise ImportError("Together requires extra dependencies. Install with `pip install together`") from None
raise ImportError(
"Together requires extra dependencies. Install with `pip install together`"
) from None
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class TogetherLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
self.client = Together()
def _parse_response(self, response, tools):
@@ -31,15 +34,17 @@ class TogetherLLM(LLMBase):
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": []
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments)
})
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
@@ -69,7 +74,7 @@ class TogetherLLM(LLMBase):
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format

View File

@@ -28,8 +28,12 @@ setup_config()
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config
)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
@@ -172,7 +176,11 @@ class Memory(MemoryBase):
if not memory:
return None
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
filters = {
key: memory.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if memory.payload.get(key)
}
# Prepare base memory item
memory_item = MemoryItem(
@@ -184,8 +192,18 @@ class Memory(MemoryBase):
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in excluded_keys
}
if additional_metadata:
memory_item["metadata"] = additional_metadata
@@ -211,7 +229,15 @@ class Memory(MemoryBase):
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [
{
**MemoryItem(
@@ -221,9 +247,22 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
**{
key: mem.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if key in mem.payload
},
**(
{
"metadata": {
k: v
for k, v in mem.payload.items()
if k not in excluded_keys
}
}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories[0]
]
@@ -255,9 +294,19 @@ class Memory(MemoryBase):
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
memories = self.vector_store.search(
query=embeddings, limit=limit, filters=filters
)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [
{
@@ -269,9 +318,22 @@ class Memory(MemoryBase):
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
**{
key: mem.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if key in mem.payload
},
**(
{
"metadata": {
k: v
for k, v in mem.payload.items()
if k not in excluded_keys
}
}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories
]
@@ -289,7 +351,7 @@ class Memory(MemoryBase):
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data)
return {'message': 'Memory updated successfully!'}
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
"""
@@ -300,7 +362,7 @@ class Memory(MemoryBase):
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id)
return {'message': 'Memory deleted successfully!'}
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
@@ -328,7 +390,7 @@ class Memory(MemoryBase):
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
return {'message': 'Memories deleted successfully!'}
return {"message": "Memories deleted successfully!"}
def history(self, memory_id):
"""
@@ -350,14 +412,16 @@ class Memory(MemoryBase):
metadata = metadata or {}
metadata["data"] = data
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
self.vector_store.insert(
vectors=[embeddings],
ids=[memory_id],
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
self.db.add_history(
memory_id, None, data, "ADD", created_at=metadata["created_at"]
)
return memory_id
def _update_memory_tool(self, memory_id, data, metadata=None):
@@ -368,7 +432,9 @@ class Memory(MemoryBase):
new_metadata["data"] = data
new_metadata["hash"] = existing_memory.payload.get("hash")
new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
new_metadata["updated_at"] = datetime.now(
pytz.timezone("US/Pacific")
).isoformat()
if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -384,7 +450,14 @@ class Memory(MemoryBase):
payload=new_metadata,
)
logging.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(memory_id, prev_value, data, "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"])
self.db.add_history(
memory_id,
prev_value,
data,
"UPDATE",
created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"],
)
def _delete_memory_tool(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")

View File

@@ -20,12 +20,12 @@ def setup_config():
def get_user_id():
config_path = os.path.join(mem0_dir, "config.json")
if not os.path.exists(config_path):
return "anonymous_user"
return "anonymous_user"
try:
with open(config_path, "r") as config_file:
config = json.load(config_file)
user_id = config.get("user_id")
return user_id
except:
except Exception:
return "anonymous_user"

View File

@@ -12,7 +12,9 @@ class SQLiteManager:
with self.connection:
cursor = self.connection.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
)
table_exists = cursor.fetchone() is not None
if table_exists:
@@ -22,15 +24,15 @@ class SQLiteManager:
# Define the expected schema
expected_schema = {
'id': 'TEXT',
'memory_id': 'TEXT',
'old_memory': 'TEXT',
'new_memory': 'TEXT',
'new_value': 'TEXT',
'event': 'TEXT',
'created_at': 'DATETIME',
'updated_at': 'DATETIME',
'is_deleted': 'INTEGER'
"id": "TEXT",
"memory_id": "TEXT",
"old_memory": "TEXT",
"new_memory": "TEXT",
"new_value": "TEXT",
"event": "TEXT",
"created_at": "DATETIME",
"updated_at": "DATETIME",
"is_deleted": "INTEGER",
}
# Check if the schemas are the same
@@ -38,7 +40,8 @@ class SQLiteManager:
# Rename the old table
cursor.execute("ALTER TABLE history RENAME TO old_history")
cursor.execute("""
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS history (
id TEXT PRIMARY KEY,
memory_id TEXT,
@@ -50,20 +53,22 @@ class SQLiteManager:
updated_at DATETIME,
is_deleted INTEGER
)
""")
"""
)
# Copy data from the old table to the new table
cursor.execute("""
cursor.execute(
"""
INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
FROM old_history
""")
"""
)
cursor.execute("DROP TABLE old_history")
self.connection.commit()
def _create_history_table(self):
with self.connection:
self.connection.execute(
@@ -82,7 +87,16 @@ class SQLiteManager:
"""
)
def add_history(self, memory_id, old_memory, new_memory, event, created_at = None, updated_at = None, is_deleted=0):
def add_history(
self,
memory_id,
old_memory,
new_memory,
event,
created_at=None,
updated_at=None,
is_deleted=0,
):
with self.connection:
self.connection.execute(
"""

View File

@@ -1,18 +1,26 @@
import httpx
from typing import Optional, List, Union
import threading
import litellm
try:
import litellm
except ImportError:
raise ImportError(
"litellm requires extra dependencies. Install with `pip install litellm`"
) from None
from mem0.memory.telemetry import capture_client_event
from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
class Mem0:
def __init__(
self,
config: Optional[dict] = None,
api_key: Optional[str] = None,
host: Optional[str] = None
):
self,
config: Optional[dict] = None,
api_key: Optional[str] = None,
host: Optional[str] = None,
):
if api_key:
self.mem0_client = MemoryClient(api_key, host)
else:
@@ -77,13 +85,21 @@ class Completions:
raise ValueError("One of user_id, agent_id, run_id must be provided")
if not litellm.supports_function_calling(model):
raise ValueError(f"Model '{model}' does not support function calling. Please use a model that supports function calling.")
raise ValueError(
f"Model '{model}' does not support function calling. Please use a model that supports function calling."
)
prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
self._async_add_to_memory(
messages, user_id, agent_id, run_id, metadata, filters
)
relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit
)
prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories
)
response = litellm.completion(
model=model,
@@ -114,9 +130,9 @@ class Completions:
base_url=base_url,
api_version=api_version,
api_key=api_key,
model_list=model_list
model_list=model_list,
)
capture_client_event("mem0.chat.create", self)
return response
def _prepare_messages(self, messages: List[dict]) -> List[dict]:
@@ -125,7 +141,9 @@ class Completions:
messages[0]["content"] = MEMORY_ANSWER_PROMPT
return messages
def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
def _async_add_to_memory(
self, messages, user_id, agent_id, run_id, metadata, filters
):
def add_task():
self.mem0_client.add(
messages=messages,
@@ -135,11 +153,16 @@ class Completions:
metadata=metadata,
filters=filters,
)
threading.Thread(target=add_task, daemon=True).start()
def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
def _fetch_relevant_memories(
self, messages, user_id, agent_id, run_id, filters, limit
):
# Currently, only pass the last 6 messages to the search API to prevent long query
message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
message_input = [
f"{message['role']}: {message['content']}" for message in messages
][-6:]
# TODO: Make it better by summarizing the past conversation
return self.mem0_client.search(
query="\n".join(message_input),
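A small sketch of the message-window behavior in `_fetch_relevant_memories` above: only the last six `role: content` lines are joined into the search query. The messages here are made up for illustration.

```python
messages = [{"role": "user", "content": f"message {i}"} for i in range(10)]
# Same slicing as the reformatted code above: keep only the last 6 entries.
message_input = [f"{m['role']}: {m['content']}" for m in messages][-6:]
query = "\n".join(message_input)
print(query)  # six lines, from "user: message 4" through "user: message 9"
```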

View File

@@ -3,6 +3,7 @@ import importlib
from mem0.configs.llms.base import BaseLlmConfig
from mem0.configs.embeddings.base import BaseEmbedderConfig
def load_class(class_type):
module_path, class_name = class_type.rsplit(".", 1)
module = importlib.import_module(module_path)
@@ -30,6 +31,7 @@ class LlmFactory:
else:
raise ValueError(f"Unsupported Llm provider: {provider_name}")
class EmbedderFactory:
provider_to_class = {
"openai": "mem0.embeddings.openai.OpenAIEmbedding",
@@ -48,11 +50,12 @@ class EmbedderFactory:
else:
raise ValueError(f"Unsupported Embedder provider: {provider_name}")
class VectorStoreFactory:
provider_to_class = {
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector"
"pgvector": "mem0.vector_stores.pgvector.PGVector",
}
@classmethod

View File

@@ -7,7 +7,9 @@ try:
import chromadb
from chromadb.config import Settings
except ImportError:
raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
raise ImportError(
"Chromadb requires extra dependencies. Install with `pip install chromadb`"
) from None
from mem0.vector_stores.base import VectorStoreBase
@@ -25,7 +27,7 @@ class ChromaDB(VectorStoreBase):
client: Optional[chromadb.Client] = None,
host: Optional[str] = None,
port: Optional[int] = None,
path: Optional[str] = None
path: Optional[str] = None,
):
"""
Initialize the Chromadb vector store.
@@ -68,7 +70,7 @@ class ChromaDB(VectorStoreBase):
Returns:
List[OutputData]: Parsed output data.
"""
keys = ['ids', 'distances', 'metadatas']
keys = ["ids", "distances", "metadatas"]
values = []
for key in keys:
@@ -78,14 +80,24 @@ class ChromaDB(VectorStoreBase):
values.append(value)
ids, distances, metadatas = values
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
max_length = max(
len(v) for v in values if isinstance(v, list) and v is not None
)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=distances[i] if isinstance(distances, list) and distances and i < len(distances) else None,
payload=metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None,
score=(
distances[i]
if isinstance(distances, list) and distances and i < len(distances)
else None
),
payload=(
metadatas[i]
if isinstance(metadatas, list) and metadatas and i < len(metadatas)
else None
),
)
result.append(entry)
@@ -114,7 +126,12 @@ class ChromaDB(VectorStoreBase):
)
return collection
def insert(self, vectors: List[list], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
def insert(
self,
vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None,
):
"""
Insert vectors into a collection.
@@ -125,7 +142,9 @@ class ChromaDB(VectorStoreBase):
"""
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
def search(
self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
@@ -137,7 +156,9 @@ class ChromaDB(VectorStoreBase):
Returns:
List[OutputData]: Search results.
"""
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
results = self.collection.query(
query_embeddings=query, where=filters, n_results=limit
)
final_results = self._parse_output(results)
return final_results
@@ -150,7 +171,12 @@ class ChromaDB(VectorStoreBase):
"""
self.collection.delete(ids=vector_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
):
"""
Update a vector and its payload.
@@ -198,7 +224,9 @@ class ChromaDB(VectorStoreBase):
"""
return self.client.get_collection(name=self.collection_name)
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
def list(
self, filters: Optional[Dict] = None, limit: int = 100
) -> List[OutputData]:
"""
List all vectors in a collection.

View File

@@ -1,31 +1,34 @@
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator
class VectorStoreConfig(BaseModel):
provider: str = Field(
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
default="qdrant",
)
config: Optional[Dict] = Field(
description="Configuration for the specific vector store",
default=None
description="Configuration for the specific vector store", default=None
)
_provider_configs: Dict[str, str] = {
"qdrant": "QdrantConfig",
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig"
"pgvector": "PGVectorConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> 'VectorStoreConfig':
def validate_and_create_config(self) -> "VectorStoreConfig":
provider = self.provider
config = self.config
if provider not in self._provider_configs:
raise ValueError(f"Unsupported vector store provider: {provider}")
module = __import__(f"mem0.configs.vector_stores.{provider}", fromlist=[self._provider_configs[provider]])
module = __import__(
f"mem0.configs.vector_stores.{provider}",
fromlist=[self._provider_configs[provider]],
)
config_class = getattr(module, self._provider_configs[provider])
if config is None:

View File

@@ -1,16 +1,19 @@
import json
from typing import Optional, List, Dict, Any
from typing import Optional, List
from pydantic import BaseModel
try:
import psycopg2
from psycopg2.extras import execute_values
except ImportError:
raise ImportError("PGVector requires extra dependencies. Install with `pip install psycopg2`") from None
raise ImportError(
"PGVector requires extra dependencies. Install with `pip install psycopg2`"
) from None
from mem0.vector_stores.base import VectorStoreBase
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
@@ -19,14 +22,7 @@ class OutputData(BaseModel):
class PGVector(VectorStoreBase):
def __init__(
self,
dbname,
collection_name,
embedding_model_dims,
user,
password,
host,
port
self, dbname, collection_name, embedding_model_dims, user, password, host, port
):
"""
Initialize the PGVector database.
@@ -43,11 +39,7 @@ class PGVector(VectorStoreBase):
self.collection_name = collection_name
self.conn = psycopg2.connect(
dbname=dbname,
user=user,
password=password,
host=host,
port=port
dbname=dbname, user=user, password=password, host=host, port=port
)
self.cur = self.conn.cursor()
@@ -63,16 +55,18 @@ class PGVector(VectorStoreBase):
name (str): Name of the collection.
embedding_model_dims (int, optional): Dimension of the embedding vector.
"""
self.cur.execute(f"""
self.cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.collection_name} (
id UUID PRIMARY KEY,
vector vector({embedding_model_dims}),
payload JSONB
);
""")
"""
)
self.conn.commit()
def insert(self, vectors, payloads = None, ids = None):
def insert(self, vectors, payloads=None, ids=None):
"""
Insert vectors into a collection.
@@ -83,11 +77,18 @@ class PGVector(VectorStoreBase):
"""
json_payloads = [json.dumps(payload) for payload in payloads]
data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
execute_values(self.cur, f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s", data)
data = [
(id, vector, payload)
for id, vector, payload in zip(ids, vectors, json_payloads)
]
execute_values(
self.cur,
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
data,
)
self.conn.commit()
def search(self, query, limit = 5, filters = None):
def search(self, query, limit=5, filters=None):
"""
Search for similar vectors.
@@ -104,21 +105,28 @@ class PGVector(VectorStoreBase):
if filters:
for k, v in filters.items():
filter_conditions.append(f"payload->>%s = %s")
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
filter_clause = (
"WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
self.cur.execute(f"""
self.cur.execute(
f"""
SELECT id, vector <-> %s::vector AS distance, payload
FROM {self.collection_name}
{filter_clause}
ORDER BY distance
LIMIT %s
""", (query, *filter_params, limit))
""",
(query, *filter_params, limit),
)
results = self.cur.fetchall()
return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
return [
OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results
]
def delete(self, vector_id):
"""
@@ -127,10 +135,12 @@ class PGVector(VectorStoreBase):
Args:
vector_id (str): ID of the vector to delete.
"""
self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
self.cur.execute(
f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)
)
self.conn.commit()
def update(self, vector_id, vector = None, payload = None):
def update(self, vector_id, vector=None, payload=None):
"""
Update a vector and its payload.
@@ -140,9 +150,15 @@ class PGVector(VectorStoreBase):
payload (Dict, optional): Updated payload.
"""
if vector:
self.cur.execute(f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s", (vector, vector_id))
self.cur.execute(
f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
(vector, vector_id),
)
if payload:
self.cur.execute(f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s", (psycopg2.extras.Json(payload), vector_id))
self.cur.execute(
f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
(psycopg2.extras.Json(payload), vector_id),
)
self.conn.commit()
def get(self, vector_id) -> OutputData:
@@ -155,7 +171,10 @@ class PGVector(VectorStoreBase):
Returns:
OutputData: Retrieved vector.
"""
self.cur.execute(f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s", (vector_id,))
self.cur.execute(
f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
(vector_id,),
)
result = self.cur.fetchone()
if not result:
return None
@@ -168,11 +187,13 @@ class PGVector(VectorStoreBase):
Returns:
List[str]: List of collection names.
"""
self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
self.cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
)
return [row[0] for row in self.cur.fetchall()]
def delete_col(self):
""" Delete a collection. """
"""Delete a collection."""
self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
self.conn.commit()
@@ -183,22 +204,21 @@ class PGVector(VectorStoreBase):
Returns:
Dict[str, Any]: Collection information.
"""
self.cur.execute(f"""
self.cur.execute(
f"""
SELECT
table_name,
(SELECT COUNT(*) FROM {self.collection_name}) as row_count,
(SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = %s
""", (self.collection_name,))
""",
(self.collection_name,),
)
result = self.cur.fetchone()
return {
"name": result[0],
"count": result[1],
"size": result[2]
}
return {"name": result[0], "count": result[1], "size": result[2]}
def list(self, filters = None, limit = 100):
def list(self, filters=None, limit=100):
"""
List all vectors in a collection.
@@ -214,10 +234,12 @@ class PGVector(VectorStoreBase):
if filters:
for k, v in filters.items():
filter_conditions.append(f"payload->>%s = %s")
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
filter_clause = (
"WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
query = f"""
SELECT id, vector, payload
@@ -235,7 +257,7 @@ class PGVector(VectorStoreBase):
"""
Close the database connection when the object is deleted.
"""
if hasattr(self, 'cur'):
if hasattr(self, "cur"):
self.cur.close()
if hasattr(self, 'conn'):
if hasattr(self, "conn"):
self.conn.close()
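A minimal sketch of how the `payload->>%s = %s` filter clause used in `search` and `list` above is assembled; the filters dictionary is illustrative.

```python
filters = {"user_id": "alex", "run_id": "r1"}
filter_conditions, filter_params = [], []
for k, v in filters.items():
    # One placeholder pair per filter key; values are bound as parameters.
    filter_conditions.append("payload->>%s = %s")
    filter_params.extend([k, str(v)])
filter_clause = (
    "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
print(filter_clause)  # WHERE payload->>%s = %s AND payload->>%s = %s
print(filter_params)  # ['user_id', 'alex', 'run_id', 'r1']
```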

View File

@@ -28,7 +28,7 @@ class Qdrant(VectorStoreBase):
path: str = None,
url: str = None,
api_key: str = None,
on_disk: bool = False
on_disk: bool = False,
):
"""
Initialize the Qdrant vector store.
@@ -66,7 +66,9 @@ class Qdrant(VectorStoreBase):
self.collection_name = collection_name
self.create_col(embedding_model_dims, on_disk)
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
def create_col(
self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
):
"""
Create a new collection.
@@ -79,12 +81,16 @@ class Qdrant(VectorStoreBase):
response = self.list_cols()
for collection in response.collections:
if collection.name == self.collection_name:
logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
logging.debug(
f"Collection {self.collection_name} already exists. Skipping creation."
)
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
vectors_config=VectorParams(
size=vector_size, distance=distance, on_disk=on_disk
),
)
def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -202,7 +208,7 @@ class Qdrant(VectorStoreBase):
return self.client.get_collections()
def delete_col(self):
""" Delete a collection. """
"""Delete a collection."""
self.client.delete_collection(collection_name=self.collection_name)
def col_info(self) -> dict:

poetry.lock (generated, 1222 lines changed)

File diff suppressed because it is too large.

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.0.19"
version = "0.0.20"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [
@@ -22,7 +22,6 @@ openai = "^1.33.0"
posthog = "^3.5.0"
pytz = "^2024.1"
sqlalchemy = "^2.0.31"
litellm = "^1.42.7"
[tool.poetry.group.test.dependencies]
pytest = "^8.2.2"
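With `litellm` removed from the core dependencies above, proxy users now install it separately. A minimal sketch of the guarded import behavior added earlier in this commit, assuming an environment where litellm is not installed:

```python
try:
    from mem0.proxy.main import Mem0
except ImportError as err:
    # Expected message from this commit:
    # "litellm requires extra dependencies. Install with `pip install litellm`"
    print(err)
```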

View File

@@ -2,7 +2,8 @@ import pytest
from unittest.mock import Mock, patch
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0 import Memory, MemoryClient, Mem0
from mem0 import Memory, MemoryClient
from mem0.proxy.main import Mem0
from mem0.proxy.main import Chat, Completions
@pytest.fixture