[Mem0] Update dependencies and make the package lighter (#1708)

Author: Deshraj Yadav
Date: 2024-08-14 23:28:07 -07:00
Committed by: GitHub
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
Parent: e35786e567
Commit: a8ba7abb7d

35 changed files with 634 additions and 1594 deletions

View File

@@ -250,7 +250,7 @@ Mem0 supports several language models (LLMs) through integration with various [p
 ## Use Mem0 Platform
 ```python
-from mem0 import Mem0
+from mem0.proxy.main import Mem0
 client = Mem0(api_key="m0-xxx")
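The new import path reflects the proxy module becoming an optional entry point. For orientation, a hedged sketch of how the proxy client is typically used — it mirrors an OpenAI-style `chat.completions.create` interface (the model name, messages, and `user_id` below are illustrative assumptions):

```python
from mem0.proxy.main import Mem0

client = Mem0(api_key="m0-xxx")

# The proxy stores the conversation as memories and injects relevant
# past memories into the prompt before calling the underlying LLM.
response = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "I'm vegetarian. Suggest dinner."}],
    user_id="alice",      # scopes which memories are stored and retrieved
)
print(response.choices[0].message.content)
```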

View File

@@ -4,4 +4,3 @@ __version__ = importlib.metadata.version("mem0ai")
 from mem0.memory.main import Memory  # noqa
 from mem0.client.main import MemoryClient  # noqa
-from mem0.proxy.main import Mem0 #noqa

View File

@@ -7,17 +7,26 @@ from mem0.vector_stores.configs import VectorStoreConfig
 from mem0.llms.configs import LlmConfig
 from mem0.embeddings.configs import EmbedderConfig

+
 class MemoryItem(BaseModel):
     id: str = Field(..., description="The unique identifier for the text data")
-    memory: str = Field(..., description="The memory deduced from the text data")  # TODO After prompt changes from platform, update this
+    memory: str = Field(
+        ..., description="The memory deduced from the text data"
+    )  # TODO After prompt changes from platform, update this
     hash: Optional[str] = Field(None, description="The hash of the memory")
     # The metadata value can be anything and not just string. Fix it
-    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
+    metadata: Optional[Dict[str, Any]] = Field(
+        None, description="Additional metadata for the text data"
+    )
     score: Optional[float] = Field(
         None, description="The score associated with the text data"
     )
-    created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
-    updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
+    created_at: Optional[str] = Field(
+        None, description="The timestamp when the memory was created"
+    )
+    updated_at: Optional[str] = Field(
+        None, description="The timestamp when the memory was updated"
+    )


 class MemoryConfig(BaseModel):
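As a quick illustration of the model above, a sketch of constructing a `MemoryItem` directly (assuming the module path `mem0.configs.base`; all field values are made up):

```python
from mem0.configs.base import MemoryItem

item = MemoryItem(
    id="9b1c0f2e",                            # unique identifier
    memory="User prefers vegetarian food",    # memory deduced from the text
    hash="d41d8cd98f00b204e9800998ecf8427e",  # md5 of the raw data
    metadata={"source": "chat"},              # arbitrary extra payload
    score=0.87,                               # present only on search results
)
print(item.model_dump(exclude={"score"}))
```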

View File

@@ -1,6 +1,7 @@
 from abc import ABC
 from typing import Optional

+
 class BaseEmbedderConfig(ABC):
     """
     Config for Embeddings.
@@ -11,12 +12,10 @@ class BaseEmbedderConfig(ABC):
         model: Optional[str] = None,
         api_key: Optional[str] = None,
         embedding_dims: Optional[int] = None,
         # Ollama specific
         ollama_base_url: Optional[str] = None,
         # Huggingface specific
-        model_kwargs: Optional[dict] = None
+        model_kwargs: Optional[dict] = None,
     ):
         """
         Initializes a configuration class instance for the Embeddings.

View File

@@ -1,6 +1,7 @@
 from abc import ABC
 from typing import Optional

+
 class BaseLlmConfig(ABC):
     """
     Config for LLMs.
@@ -14,16 +15,14 @@ class BaseLlmConfig(ABC):
         max_tokens: int = 3000,
         top_p: float = 0,
         top_k: int = 1,
         # Openrouter specific
         models: Optional[list[str]] = None,
         route: Optional[str] = "fallback",
         openrouter_base_url: Optional[str] = "https://openrouter.ai/api/v1",
         site_url: Optional[str] = None,
         app_name: Optional[str] = None,
         # Ollama specific
-        ollama_base_url: Optional[str] = None
+        ollama_base_url: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
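To make these knobs concrete, a sketch of instantiating the config for a local Ollama model (values are illustrative; providers read only the fields they care about, so the OpenRouter-specific knobs are simply ignored here):

```python
from mem0.configs.llms.base import BaseLlmConfig

config = BaseLlmConfig(
    model="llama3.1:70b",                      # illustrative model name
    temperature=0.1,
    max_tokens=2000,
    top_p=0.9,
    ollama_base_url="http://localhost:11434",  # Ollama-specific field
)
```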

View File

@@ -2,15 +2,20 @@ from typing import Optional, ClassVar, Dict, Any
 from pydantic import BaseModel, Field, model_validator

+
 class ChromaDbConfig(BaseModel):
     try:
         from chromadb.api.client import Client
     except ImportError:
-        raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
+        raise ImportError(
+            "Chromadb requires extra dependencies. Install with `pip install chromadb`"
+        ) from None

     Client: ClassVar[type] = Client

     collection_name: str = Field("mem0", description="Default name for the collection")
-    client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
+    client: Optional[Client] = Field(
+        None, description="Existing ChromaDB client instance"
+    )
     path: Optional[str] = Field(None, description="Path to the database directory")
     host: Optional[str] = Field(None, description="Database connection remote host")
     port: Optional[int] = Field(None, description="Database connection remote port")
@@ -29,7 +34,9 @@ class ChromaDbConfig(BaseModel):
         input_fields = set(values.keys())
         extra_fields = input_fields - allowed_fields
         if extra_fields:
-            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
         return values

     model_config = {
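The extra-fields validator repeated across these config classes simply whitelists declared fields. A standalone Pydantic sketch of the behavior (not mem0's exact class):

```python
from pydantic import BaseModel, Field, model_validator


class StrictConfig(BaseModel):
    collection_name: str = Field("mem0", description="Default collection name")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values):
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values


StrictConfig(collection_name="mem0")   # ok
StrictConfig(collection_nam="typo")    # raises ValueError naming the bad key
```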

View File

@@ -2,11 +2,14 @@ from typing import Optional, Dict, Any
 from pydantic import BaseModel, Field, model_validator

+
 class PGVectorConfig(BaseModel):
     dbname: str = Field("postgres", description="Default name for the database")
     collection_name: str = Field("mem0", description="Default name for the collection")
-    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    embedding_model_dims: Optional[int] = Field(
+        1536, description="Dimensions of the embedding model"
+    )
     user: Optional[str] = Field(None, description="Database user")
     password: Optional[str] = Field(None, description="Database password")
     host: Optional[str] = Field(None, description="Database host. Default is localhost")
@@ -29,6 +32,7 @@ class PGVectorConfig(BaseModel):
         input_fields = set(values.keys())
         extra_fields = input_fields - allowed_fields
         if extra_fields:
-            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
         return values

View File

@@ -1,16 +1,24 @@
 from pydantic import BaseModel, Field, model_validator
 from typing import Optional, ClassVar, Dict, Any

+
 class QdrantConfig(BaseModel):
     from qdrant_client import QdrantClient

     QdrantClient: ClassVar[type] = QdrantClient

     collection_name: str = Field("mem0", description="Name of the collection")
-    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
-    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
+    embedding_model_dims: Optional[int] = Field(
+        1536, description="Dimensions of the embedding model"
+    )
+    client: Optional[QdrantClient] = Field(
+        None, description="Existing Qdrant client instance"
+    )
     host: Optional[str] = Field(None, description="Host address for Qdrant server")
     port: Optional[int] = Field(None, description="Port for Qdrant server")
-    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
+    path: Optional[str] = Field(
+        "/tmp/qdrant", description="Path for local Qdrant database"
+    )
     url: Optional[str] = Field(None, description="Full URL for Qdrant server")
     api_key: Optional[str] = Field(None, description="API key for Qdrant server")
     on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@@ -38,7 +46,9 @@ class QdrantConfig(BaseModel):
         input_fields = set(values.keys())
         extra_fields = input_fields - allowed_fields
         if extra_fields:
-            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
         return values

     model_config = {

View File

@@ -6,6 +6,7 @@ from openai import AzureOpenAI
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase

+
 class AzureOpenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
@@ -30,10 +31,7 @@ class AzureOpenAIEmbedding(EmbeddingBase):
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return ( return (
self.client.embeddings.create( self.client.embeddings.create(input=[text], model=self.config.model)
input=[text],
model=self.config.model
)
.data[0] .data[0]
.embedding .embedding
) )
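The chained expression above is equivalent to the more explicit form below; a sketch for readability (attribute names follow the OpenAI Python SDK):

```python
def embed(client, model: str, text: str) -> list[float]:
    # Same behavior as the chained call: one request, first embedding back.
    text = text.replace("\n", " ")
    response = client.embeddings.create(input=[text], model=model)
    return response.data[0].embedding
```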

View File

@@ -3,12 +3,14 @@ from abc import ABC, abstractmethod
 from mem0.configs.embeddings.base import BaseEmbedderConfig

+
 class EmbeddingBase(ABC):
     """Initialized a base embedding class

     :param config: Embedding configuration option class, defaults to None
     :type config: Optional[BaseEmbedderConfig], optional
     """

     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         if config is None:
             self.config = BaseEmbedderConfig()

View File

@@ -9,8 +9,7 @@ class EmbedderConfig(BaseModel):
default="openai", default="openai",
) )
config: Optional[dict] = Field( config: Optional[dict] = Field(
description="Configuration for the specific embedding model", description="Configuration for the specific embedding model", default={}
default={}
) )
@field_validator("config") @field_validator("config")
@@ -20,4 +19,3 @@ class EmbedderConfig(BaseModel):
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")
-

View File

@@ -13,15 +13,11 @@ class HuggingFaceEmbedding(EmbeddingBase):
         if self.config.model is None:
             self.config.model = "multi-qa-MiniLM-L6-cos-v1"

-        self.model = SentenceTransformer(
-            self.config.model,
-            **self.config.model_kwargs
-        )
+        self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)

         if self.config.embedding_dims is None:
             self.config.embedding_dims = self.model.get_sentence_embedding_dimension()

     def embed(self, text):
         """
         Get the embedding for the given text using Hugging Face.

View File

@@ -6,7 +6,9 @@ from mem0.embeddings.base import EmbeddingBase
 try:
     from ollama import Client
 except ImportError:
-    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
+    raise ImportError(
+        "Ollama requires extra dependencies. Install with `pip install ollama`"
+    ) from None


 class OllamaEmbedding(EmbeddingBase):
@@ -14,9 +16,9 @@ class OllamaEmbedding(EmbeddingBase):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="nomic-embed-text"
+            self.config.model = "nomic-embed-text"
         if not self.config.embedding_dims:
-            self.config.embedding_dims=512
+            self.config.embedding_dims = 512

         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()

View File

@@ -6,6 +6,7 @@ from openai import OpenAI
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase

+
 class OpenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
@@ -28,10 +29,7 @@ class OpenAIEmbedding(EmbeddingBase):
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return ( return (
self.client.embeddings.create( self.client.embeddings.create(input=[text], model=self.config.model)
input=[text],
model=self.config.model
)
.data[0] .data[0]
.embedding .embedding
) )

View File

@@ -5,22 +5,30 @@ from typing import Dict, List, Optional, Any
 try:
     import boto3
 except ImportError:
-    raise ImportError("AWS Bedrock requires extra dependencies. Install with `pip install boto3`") from None
+    raise ImportError(
+        "AWS Bedrock requires extra dependencies. Install with `pip install boto3`"
+    ) from None

 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig


 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
+            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"

-        self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION"),
+            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
+            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
+        )
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }

     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
@@ -36,8 +44,8 @@ class AWSBedrockLLM(LLMBase):
""" """
formatted_messages = [] formatted_messages = []
for message in messages: for message in messages:
role = message['role'].capitalize() role = message["role"].capitalize()
content = message['content'] content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}") formatted_messages.append(f"\n\n{role}: {content}")
return "".join(formatted_messages) + "\n\nAssistant:" return "".join(formatted_messages) + "\n\nAssistant:"
@@ -54,42 +62,42 @@ class AWSBedrockLLM(LLMBase):
             str or dict: The processed response.
         """
         if tools:
-            processed_response = {
-                "tool_calls": []
-            }
+            processed_response = {"tool_calls": []}

             if response["output"]["message"]["content"]:
                 for item in response["output"]["message"]["content"]:
                     if "toolUse" in item:
-                        processed_response["tool_calls"].append({
-                            "name": item["toolUse"]["name"],
-                            "arguments": item["toolUse"]["input"]
-                        })
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )

             return processed_response

-        response_body = json.loads(response['body'].read().decode())
-        return response_body.get('completion', '')
+        response_body = json.loads(response["body"].read().decode())
+        return response_body.get("completion", "")

     def _prepare_input(
         self,
         provider: str,
         model: str,
         prompt: str,
         model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         """
         Prepares the input dictionary for the specified provider's model by mapping and renaming
         keys in the input based on the provider's requirements.

         Args:
             provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
             model (str): The name or identifier of the model being used.
             prompt (str): The text prompt to be processed by the model.
             model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.

         Returns:
             Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
         input_body = {"prompt": prompt, **model_kwargs}
@@ -115,10 +123,14 @@ class AWSBedrockLLM(LLMBase):
"textGenerationConfig": { "textGenerationConfig": {
"maxTokenCount": model_kwargs.get("max_tokens_to_sample"), "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
"topP": model_kwargs.get("top_p"), "topP": model_kwargs.get("top_p"),
"temperature": model_kwargs.get("temperature") "temperature": model_kwargs.get("temperature"),
} },
}
input_body["textGenerationConfig"] = {
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
} }
input_body["textGenerationConfig"] = {k: v for k, v in input_body["textGenerationConfig"].items() if v is not None}
return input_body return input_body
@@ -135,26 +147,28 @@ class AWSBedrockLLM(LLMBase):
         new_tools = []

         for tool in original_tools:
-            if tool['type'] == 'function':
-                function = tool['function']
+            if tool["type"] == "function":
+                function = tool["function"]
                 new_tool = {
                     "toolSpec": {
-                        "name": function['name'],
-                        "description": function['description'],
+                        "name": function["name"],
+                        "description": function["description"],
                         "inputSchema": {
                             "json": {
                                 "type": "object",
                                 "properties": {},
-                                "required": function['parameters'].get('required', [])
+                                "required": function["parameters"].get("required", []),
                             }
-                        }
+                        },
                     }
                 }

-                for prop, details in function['parameters'].get('properties', {}).items():
+                for prop, details in (
+                    function["parameters"].get("properties", {}).items()
+                ):
                     new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
-                        "type": details.get('type', 'string'),
-                        "description": details.get('description', '')
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
                     }

                 new_tools.append(new_tool)
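For orientation, a sketch of what this conversion does to a single OpenAI-style tool definition (the tool itself is illustrative):

```python
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name"}},
            "required": ["city"],
        },
    },
}
# _convert_tool_format([openai_tool]) yields the Bedrock shape:
# [{"toolSpec": {
#     "name": "get_weather",
#     "description": "Get the weather for a city",
#     "inputSchema": {"json": {
#         "type": "object",
#         "properties": {"city": {"type": "string", "description": "City name"}},
#         "required": ["city"]}}}}]
```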
@@ -181,28 +195,39 @@ class AWSBedrockLLM(LLMBase):
         if tools:
             # Use converse method when tools are provided
-            messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
-            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
             tools_config = {"tools": self._convert_tool_format(tools)}

             response = self.client.converse(
                 modelId=self.config.model,
                 messages=messages,
                 inferenceConfig=inference_config,
-                toolConfig=tools_config
+                toolConfig=tools_config,
             )
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            input_body = self._prepare_input(
+                provider, self.config.model, prompt, **self.model_kwargs
+            )
             body = json.dumps(input_body)

             response = self.client.invoke_model(
                 body=body,
                 modelId=self.model,
-                accept='application/json',
-                contentType='application/json'
+                accept="application/json",
+                contentType="application/json",
             )

         return self._parse_response(response, tools)

View File

@@ -6,13 +6,14 @@ from openai import AzureOpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig


 class AzureOpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         # Model name should match the custom deployment name chosen for it.
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"

         self.client = AzureOpenAI()

     def _parse_response(self, response, tools):
@@ -29,21 +30,22 @@ class AzureOpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }

             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )

             return processed_response
         else:
             return response.choices[0].message.content

     def generate_response(
         self,
         messages: List[Dict[str, str]],
@@ -68,7 +70,7 @@ class AzureOpenAILLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format

View File

@@ -14,8 +14,15 @@ class LlmConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
+        if provider in (
+            "openai",
+            "ollama",
+            "groq",
+            "together",
+            "aws_bedrock",
+            "litellm",
+            "azure_openai",
+        ):
             return v
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -4,7 +4,9 @@ from typing import Dict, List, Optional
 try:
     from groq import Groq
 except ImportError:
-    raise ImportError("Groq requires extra dependencies. Install with `pip install groq`") from None
+    raise ImportError(
+        "Groq requires extra dependencies. Install with `pip install groq`"
+    ) from None

 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -15,7 +17,7 @@ class GroqLLM(LLMBase):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="llama3-70b-8192"
+            self.config.model = "llama3-70b-8192"

         self.client = Groq()

     def _parse_response(self, response, tools):
@@ -32,15 +34,17 @@ class GroqLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }

             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )

             return processed_response
         else:
@@ -70,7 +74,7 @@ class GroqLLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format

View File

@@ -1,7 +1,12 @@
 import json
 from typing import Dict, List, Optional

-import litellm
+try:
+    import litellm
+except ImportError:
+    raise ImportError(
+        "litellm requires extra dependencies. Install with `pip install litellm`"
+    ) from None

 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -12,7 +17,7 @@ class LiteLLM(LLMBase):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"

     def _parse_response(self, response, tools):
         """
@@ -28,15 +33,17 @@ class LiteLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }

             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )

             return processed_response
         else:
@@ -62,14 +69,16 @@ class LiteLLM(LLMBase):
             str: The generated response.
         """
         if not litellm.supports_function_calling(self.config.model):
-            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
+            raise ValueError(
+                f"Model '{self.config.model}' in litellm does not support function calling."
+            )

         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format

View File

@@ -3,17 +3,20 @@ from typing import Dict, List, Optional
 try:
     from ollama import Client
 except ImportError:
-    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
+    raise ImportError(
+        "Ollama requires extra dependencies. Install with `pip install ollama`"
+    ) from None

 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig


 class OllamaLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="llama3.1:70b"
+            self.config.model = "llama3.1:70b"

         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()
@@ -38,20 +41,22 @@ class OllamaLLM(LLMBase):
""" """
if tools: if tools:
processed_response = { processed_response = {
"content": response['message']['content'], "content": response["message"]["content"],
"tool_calls": [] "tool_calls": [],
} }
if response['message'].get('tool_calls'): if response["message"].get("tool_calls"):
for tool_call in response['message']['tool_calls']: for tool_call in response["message"]["tool_calls"]:
processed_response["tool_calls"].append({ processed_response["tool_calls"].append(
"name": tool_call["function"]["name"], {
"arguments": tool_call["function"]["arguments"] "name": tool_call["function"]["name"],
}) "arguments": tool_call["function"]["arguments"],
}
)
return processed_response return processed_response
else: else:
return response['message']['content'] return response["message"]["content"]
def generate_response( def generate_response(
self, self,
@@ -78,8 +83,8 @@ class OllamaLLM(LLMBase):
"options": { "options": {
"temperature": self.config.temperature, "temperature": self.config.temperature,
"num_predict": self.config.max_tokens, "num_predict": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} },
} }
if response_format: if response_format:
params["format"] = response_format params["format"] = response_format

View File

@@ -7,15 +7,19 @@ from openai import OpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig


 class OpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"

         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
-            self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
+            self.client = OpenAI(
+                api_key=os.environ.get("OPENROUTER_API_KEY"),
+                base_url=self.config.openrouter_base_url,
+            )
         else:
             api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
             self.client = OpenAI(api_key=api_key)
@@ -34,15 +38,17 @@ class OpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }

             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )

             return processed_response
         else:
@@ -72,7 +78,7 @@ class OpenAILLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if os.getenv("OPENROUTER_API_KEY"): if os.getenv("OPENROUTER_API_KEY"):
@@ -83,11 +89,11 @@ class OpenAILLM(LLMBase):
params.pop("model") params.pop("model")
if self.config.site_url and self.config.app_name: if self.config.site_url and self.config.app_name:
extra_headers={ extra_headers = {
"HTTP-Referer": self.config.site_url, "HTTP-Referer": self.config.site_url,
"X-Title": self.config.app_name "X-Title": self.config.app_name,
} }
openrouter_params["extra_headers"] = extra_headers openrouter_params["extra_headers"] = extra_headers
params.update(**openrouter_params) params.update(**openrouter_params)

View File

@@ -4,17 +4,20 @@ from typing import Dict, List, Optional
 try:
     from together import Together
 except ImportError:
-    raise ImportError("Together requires extra dependencies. Install with `pip install together`") from None
+    raise ImportError(
+        "Together requires extra dependencies. Install with `pip install together`"
+    ) from None

 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig


 class TogetherLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
+            self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

         self.client = Together()

     def _parse_response(self, response, tools):
@@ -31,15 +34,17 @@ class TogetherLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }

             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )

             return processed_response
         else:
@@ -69,7 +74,7 @@ class TogetherLLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format

View File

@@ -28,8 +28,12 @@ setup_config()
 class Memory(MemoryBase):
     def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.config = config
-        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
-        self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config)
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider, self.config.embedder.config
+        )
+        self.vector_store = VectorStoreFactory.create(
+            self.config.vector_store.provider, self.config.vector_store.config
+        )
         self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.vector_store.config.collection_name
@@ -172,7 +176,11 @@ class Memory(MemoryBase):
         if not memory:
             return None

-        filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
+        filters = {
+            key: memory.payload[key]
+            for key in ["user_id", "agent_id", "run_id"]
+            if memory.payload.get(key)
+        }

         # Prepare base memory item
         memory_item = MemoryItem(
@@ -184,8 +192,18 @@ class Memory(MemoryBase):
         ).model_dump(exclude={"score"})

         # Add metadata if there are additional keys
-        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
-        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
+        excluded_keys = {
+            "user_id",
+            "agent_id",
+            "run_id",
+            "hash",
+            "data",
+            "created_at",
+            "updated_at",
+        }
+        additional_metadata = {
+            k: v for k, v in memory.payload.items() if k not in excluded_keys
+        }
         if additional_metadata:
             memory_item["metadata"] = additional_metadata
@@ -211,7 +229,15 @@ class Memory(MemoryBase):
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit}) capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
memories = self.vector_store.list(filters=filters, limit=limit) memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [ return [
{ {
**MemoryItem( **MemoryItem(
@@ -221,9 +247,22 @@ class Memory(MemoryBase):
                     created_at=mem.payload.get("created_at"),
                     updated_at=mem.payload.get("updated_at"),
                 ).model_dump(exclude={"score"}),
-                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
-                **({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
-                if any(k for k in mem.payload if k not in excluded_keys) else {})
+                **{
+                    key: mem.payload[key]
+                    for key in ["user_id", "agent_id", "run_id"]
+                    if key in mem.payload
+                },
+                **(
+                    {
+                        "metadata": {
+                            k: v
+                            for k, v in mem.payload.items()
+                            if k not in excluded_keys
+                        }
+                    }
+                    if any(k for k in mem.payload if k not in excluded_keys)
+                    else {}
+                ),
             }
             for mem in memories[0]
         ]
@@ -255,9 +294,19 @@ class Memory(MemoryBase):
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit}) capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
embeddings = self.embedding_model.embed(query) embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters) memories = self.vector_store.search(
query=embeddings, limit=limit, filters=filters
)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [ return [
{ {
@@ -269,9 +318,22 @@ class Memory(MemoryBase):
                     updated_at=mem.payload.get("updated_at"),
                     score=mem.score,
                 ).model_dump(),
-                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
-                **({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
-                if any(k for k in mem.payload if k not in excluded_keys) else {})
+                **{
+                    key: mem.payload[key]
+                    for key in ["user_id", "agent_id", "run_id"]
+                    if key in mem.payload
+                },
+                **(
+                    {
+                        "metadata": {
+                            k: v
+                            for k, v in mem.payload.items()
+                            if k not in excluded_keys
+                        }
+                    }
+                    if any(k for k in mem.payload if k not in excluded_keys)
+                    else {}
+                ),
             }
             for mem in memories
         ]
@@ -289,7 +351,7 @@ class Memory(MemoryBase):
""" """
capture_event("mem0.update", self, {"memory_id": memory_id}) capture_event("mem0.update", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data) self._update_memory_tool(memory_id, data)
return {'message': 'Memory updated successfully!'} return {"message": "Memory updated successfully!"}
def delete(self, memory_id): def delete(self, memory_id):
""" """
@@ -300,7 +362,7 @@ class Memory(MemoryBase):
""" """
capture_event("mem0.delete", self, {"memory_id": memory_id}) capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id) self._delete_memory_tool(memory_id)
return {'message': 'Memory deleted successfully!'} return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None): def delete_all(self, user_id=None, agent_id=None, run_id=None):
""" """
@@ -328,7 +390,7 @@ class Memory(MemoryBase):
         memories = self.vector_store.list(filters=filters)[0]
         for memory in memories:
             self._delete_memory_tool(memory.id)
-        return {'message': 'Memories deleted successfully!'}
+        return {"message": "Memories deleted successfully!"}

     def history(self, memory_id):
         """
@@ -350,14 +412,16 @@ class Memory(MemoryBase):
         metadata = metadata or {}
         metadata["data"] = data
         metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
-        metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
+        metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

         self.vector_store.insert(
             vectors=[embeddings],
             ids=[memory_id],
             payloads=[metadata],
         )
-        self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
+        self.db.add_history(
+            memory_id, None, data, "ADD", created_at=metadata["created_at"]
+        )
         return memory_id

     def _update_memory_tool(self, memory_id, data, metadata=None):
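To make the stored record concrete, a sketch of the payload `_create_memory_tool` assembles before inserting into the vector store (the memory text is illustrative; note the md5 hash and US/Pacific ISO timestamp from the code above):

```python
import hashlib
from datetime import datetime

import pytz

data = "User prefers vegetarian food"
payload = {
    "data": data,
    "hash": hashlib.md5(data.encode()).hexdigest(),
    "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
    # ...plus any caller-supplied metadata keys, e.g. "user_id": "alice"
}
```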
@@ -368,7 +432,9 @@ class Memory(MemoryBase):
new_metadata["data"] = data new_metadata["data"] = data
new_metadata["hash"] = existing_memory.payload.get("hash") new_metadata["hash"] = existing_memory.payload.get("hash")
new_metadata["created_at"] = existing_memory.payload.get("created_at") new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat() new_metadata["updated_at"] = datetime.now(
pytz.timezone("US/Pacific")
).isoformat()
if "user_id" in existing_memory.payload: if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"] new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -384,7 +450,14 @@ class Memory(MemoryBase):
             payload=new_metadata,
         )
         logging.info(f"Updating memory with ID {memory_id=} with {data=}")
-        self.db.add_history(memory_id, prev_value, data, "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"])
+        self.db.add_history(
+            memory_id,
+            prev_value,
+            data,
+            "UPDATE",
+            created_at=new_metadata["created_at"],
+            updated_at=new_metadata["updated_at"],
+        )

     def _delete_memory_tool(self, memory_id):
         logging.info(f"Deleting memory with {memory_id=}")

View File

@@ -20,12 +20,12 @@ def setup_config():
 def get_user_id():
     config_path = os.path.join(mem0_dir, "config.json")
     if not os.path.exists(config_path):
         return "anonymous_user"

     try:
         with open(config_path, "r") as config_file:
             config = json.load(config_file)
             user_id = config.get("user_id")
             return user_id
-    except:
+    except Exception:
         return "anonymous_user"

View File

@@ -12,7 +12,9 @@ class SQLiteManager:
         with self.connection:
             cursor = self.connection.cursor()

-            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
+            )
             table_exists = cursor.fetchone() is not None

             if table_exists:
@@ -22,15 +24,15 @@ class SQLiteManager:
                 # Define the expected schema
                 expected_schema = {
-                    'id': 'TEXT',
-                    'memory_id': 'TEXT',
-                    'old_memory': 'TEXT',
-                    'new_memory': 'TEXT',
-                    'new_value': 'TEXT',
-                    'event': 'TEXT',
-                    'created_at': 'DATETIME',
-                    'updated_at': 'DATETIME',
-                    'is_deleted': 'INTEGER'
+                    "id": "TEXT",
+                    "memory_id": "TEXT",
+                    "old_memory": "TEXT",
+                    "new_memory": "TEXT",
+                    "new_value": "TEXT",
+                    "event": "TEXT",
+                    "created_at": "DATETIME",
+                    "updated_at": "DATETIME",
+                    "is_deleted": "INTEGER",
                 }

                 # Check if the schemas are the same
@@ -38,7 +40,8 @@ class SQLiteManager:
                     # Rename the old table
                     cursor.execute("ALTER TABLE history RENAME TO old_history")

-                    cursor.execute("""
+                    cursor.execute(
+                        """
                         CREATE TABLE IF NOT EXISTS history (
                             id TEXT PRIMARY KEY,
                             memory_id TEXT,
@@ -50,20 +53,22 @@ class SQLiteManager:
                             updated_at DATETIME,
                             is_deleted INTEGER
                         )
-                    """)
+                    """
+                    )

                     # Copy data from the old table to the new table
-                    cursor.execute("""
+                    cursor.execute(
+                        """
                         INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
                         SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
                         FROM old_history
-                    """)
+                    """
+                    )

                     cursor.execute("DROP TABLE old_history")

             self.connection.commit()

     def _create_history_table(self):
         with self.connection:
             self.connection.execute(
@@ -82,7 +87,16 @@ class SQLiteManager:
""" """
) )
def add_history(self, memory_id, old_memory, new_memory, event, created_at = None, updated_at = None, is_deleted=0): def add_history(
self,
memory_id,
old_memory,
new_memory,
event,
created_at=None,
updated_at=None,
is_deleted=0,
):
with self.connection: with self.connection:
self.connection.execute( self.connection.execute(
""" """

View File

@@ -1,18 +1,26 @@
 import httpx
 from typing import Optional, List, Union
 import threading

-import litellm
+try:
+    import litellm
+except ImportError:
+    raise ImportError(
+        "litellm requires extra dependencies. Install with `pip install litellm`"
+    ) from None

+from mem0.memory.telemetry import capture_client_event
 from mem0 import Memory, MemoryClient
 from mem0.configs.prompts import MEMORY_ANSWER_PROMPT


 class Mem0:
     def __init__(
         self,
         config: Optional[dict] = None,
         api_key: Optional[str] = None,
-        host: Optional[str] = None
+        host: Optional[str] = None,
     ):
         if api_key:
             self.mem0_client = MemoryClient(api_key, host)
         else:
@@ -77,13 +85,21 @@ class Completions:
raise ValueError("One of user_id, agent_id, run_id must be provided") raise ValueError("One of user_id, agent_id, run_id must be provided")
if not litellm.supports_function_calling(model): if not litellm.supports_function_calling(model):
raise ValueError(f"Model '{model}' does not support function calling. Please use a model that supports function calling.") raise ValueError(
f"Model '{model}' does not support function calling. Please use a model that supports function calling."
)
prepared_messages = self._prepare_messages(messages) prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user": if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters) self._async_add_to_memory(
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit) messages, user_id, agent_id, run_id, metadata, filters
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories) )
relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit
)
prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories
)
response = litellm.completion( response = litellm.completion(
model=model, model=model,
@@ -114,9 +130,9 @@ class Completions:
             base_url=base_url,
             api_version=api_version,
             api_key=api_key,
-            model_list=model_list
+            model_list=model_list,
         )
+        capture_client_event("mem0.chat.create", self)
         return response

     def _prepare_messages(self, messages: List[dict]) -> List[dict]:
@@ -125,7 +141,9 @@ class Completions:
messages[0]["content"] = MEMORY_ANSWER_PROMPT messages[0]["content"] = MEMORY_ANSWER_PROMPT
return messages return messages
def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters): def _async_add_to_memory(
self, messages, user_id, agent_id, run_id, metadata, filters
):
def add_task(): def add_task():
self.mem0_client.add( self.mem0_client.add(
messages=messages, messages=messages,
@@ -135,11 +153,16 @@ class Completions:
                 metadata=metadata,
                 filters=filters,
             )

         threading.Thread(target=add_task, daemon=True).start()

-    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
+    def _fetch_relevant_memories(
+        self, messages, user_id, agent_id, run_id, filters, limit
+    ):
         # Currently, only pass the last 6 messages to the search API to prevent long query
-        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
+        message_input = [
+            f"{message['role']}: {message['content']}" for message in messages
+        ][-6:]
         # TODO: Make it better by summarizing the past conversation
         return self.mem0_client.search(
             query="\n".join(message_input),

View File

@@ -3,6 +3,7 @@ import importlib
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.configs.embeddings.base import BaseEmbedderConfig

+
 def load_class(class_type):
     module_path, class_name = class_type.rsplit(".", 1)
     module = importlib.import_module(module_path)
@@ -30,6 +31,7 @@ class LlmFactory:
         else:
             raise ValueError(f"Unsupported Llm provider: {provider_name}")

+
 class EmbedderFactory:
     provider_to_class = {
         "openai": "mem0.embeddings.openai.OpenAIEmbedding",
@@ -48,11 +50,12 @@ class EmbedderFactory:
else: else:
raise ValueError(f"Unsupported Embedder provider: {provider_name}") raise ValueError(f"Unsupported Embedder provider: {provider_name}")
class VectorStoreFactory: class VectorStoreFactory:
provider_to_class = { provider_to_class = {
"qdrant": "mem0.vector_stores.qdrant.Qdrant", "qdrant": "mem0.vector_stores.qdrant.Qdrant",
"chroma": "mem0.vector_stores.chroma.ChromaDB", "chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector" "pgvector": "mem0.vector_stores.pgvector.PGVector",
} }
@classmethod @classmethod
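For readers skimming the factories: each `provider_to_class` entry is a dotted path that `load_class` resolves at call time, so heavy provider SDKs are imported only when actually selected. A minimal sketch of the pattern (the final `return getattr(...)` line is an assumption, since the diff truncates the helper):

```python
import importlib


def load_class(class_type):
    # Split "pkg.module.ClassName" into module path and class name,
    # import the module lazily, then pull the class off it.
    module_path, class_name = class_type.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)  # assumed; truncated in the diff


# Resolves the same way VectorStoreFactory would for provider="chroma"
# (requires the chromadb extra to be installed):
ChromaDB = load_class("mem0.vector_stores.chroma.ChromaDB")
```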
View File
@@ -7,7 +7,9 @@ try:
    import chromadb
    from chromadb.config import Settings
except ImportError:
-    raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
+    raise ImportError(
+        "Chromadb requires extra dependencies. Install with `pip install chromadb`"
+    ) from None

from mem0.vector_stores.base import VectorStoreBase
@@ -25,7 +27,7 @@ class ChromaDB(VectorStoreBase):
        client: Optional[chromadb.Client] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
-        path: Optional[str] = None
+        path: Optional[str] = None,
    ):
        """
        Initialize the Chromadb vector store.
@@ -68,7 +70,7 @@ class ChromaDB(VectorStoreBase):
        Returns:
            List[OutputData]: Parsed output data.
        """
-        keys = ['ids', 'distances', 'metadatas']
+        keys = ["ids", "distances", "metadatas"]
        values = []
        for key in keys:
@@ -78,14 +80,24 @@ class ChromaDB(VectorStoreBase):
            values.append(value)

        ids, distances, metadatas = values
-        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
+        max_length = max(
+            len(v) for v in values if isinstance(v, list) and v is not None
+        )

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
-                score=distances[i] if isinstance(distances, list) and distances and i < len(distances) else None,
-                payload=metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None,
+                score=(
+                    distances[i]
+                    if isinstance(distances, list) and distances and i < len(distances)
+                    else None
+                ),
+                payload=(
+                    metadatas[i]
+                    if isinstance(metadatas, list) and metadatas and i < len(metadatas)
+                    else None
+                ),
            )
            result.append(entry)
@@ -114,7 +126,12 @@ class ChromaDB(VectorStoreBase):
        )
        return collection

-    def insert(self, vectors: List[list], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
+    def insert(
+        self,
+        vectors: List[list],
+        payloads: Optional[List[Dict]] = None,
+        ids: Optional[List[str]] = None,
+    ):
        """
        Insert vectors into a collection.
@@ -125,7 +142,9 @@ class ChromaDB(VectorStoreBase):
        """
        self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)

-    def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+    def search(
+        self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
        """
        Search for similar vectors.
@@ -137,7 +156,9 @@ class ChromaDB(VectorStoreBase):
        Returns:
            List[OutputData]: Search results.
        """
-        results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
+        results = self.collection.query(
+            query_embeddings=query, where=filters, n_results=limit
+        )
        final_results = self._parse_output(results)
        return final_results
@@ -150,7 +171,12 @@ class ChromaDB(VectorStoreBase):
        """
        self.collection.delete(ids=vector_id)

-    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
+    def update(
+        self,
+        vector_id: str,
+        vector: Optional[List[float]] = None,
+        payload: Optional[Dict] = None,
+    ):
        """
        Update a vector and its payload.
@@ -198,7 +224,9 @@ class ChromaDB(VectorStoreBase):
        """
        return self.client.get_collection(name=self.collection_name)

-    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
+    def list(
+        self, filters: Optional[Dict] = None, limit: int = 100
+    ) -> List[OutputData]:
        """
        List all vectors in a collection.
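A hedged usage sketch for the store above. The keywords follow the `__init__` fragment in this diff, but `collection_name` as the leading parameter is an assumption, and the embedding values are made up:

```python
from mem0.vector_stores.chroma import ChromaDB

# Local on-disk Chroma instance; client/host/port/path are the knobs
# visible in the __init__ signature above.
store = ChromaDB(collection_name="mem0", path="/tmp/chroma")

store.insert(
    vectors=[[0.1, 0.2, 0.3]],
    payloads=[{"user_id": "alice", "memory": "likes hiking"}],
    ids=["11111111-1111-1111-1111-111111111111"],
)

# query is List[list] per the signature, i.e. a batch of embeddings.
hits = store.search(query=[[0.1, 0.2, 0.3]], limit=5, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload)
```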
View File
@@ -1,31 +1,34 @@
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator

+
class VectorStoreConfig(BaseModel):
    provider: str = Field(
        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
        default="qdrant",
    )
    config: Optional[Dict] = Field(
-        description="Configuration for the specific vector store",
-        default=None
+        description="Configuration for the specific vector store", default=None
    )

    _provider_configs: Dict[str, str] = {
        "qdrant": "QdrantConfig",
        "chroma": "ChromaDbConfig",
-        "pgvector": "PGVectorConfig"
+        "pgvector": "PGVectorConfig",
    }

    @model_validator(mode="after")
-    def validate_and_create_config(self) -> 'VectorStoreConfig':
+    def validate_and_create_config(self) -> "VectorStoreConfig":
        provider = self.provider
        config = self.config

        if provider not in self._provider_configs:
            raise ValueError(f"Unsupported vector store provider: {provider}")

-        module = __import__(f"mem0.configs.vector_stores.{provider}", fromlist=[self._provider_configs[provider]])
+        module = __import__(
+            f"mem0.configs.vector_stores.{provider}",
+            fromlist=[self._provider_configs[provider]],
+        )
        config_class = getattr(module, self._provider_configs[provider])

        if config is None:
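Spelled out for a single provider, the validator's dynamic import reduces to the following sketch (run in an environment where the matching provider extra is installed; instantiating the config class with defaults when `config is None` is an assumption, since the diff cuts off right after that check):

```python
provider = "chroma"
provider_configs = {
    "qdrant": "QdrantConfig",
    "chroma": "ChromaDbConfig",
    "pgvector": "PGVectorConfig",
}

# Import mem0.configs.vector_stores.chroma and fetch ChromaDbConfig off it,
# exactly as validate_and_create_config does above.
module = __import__(
    f"mem0.configs.vector_stores.{provider}",
    fromlist=[provider_configs[provider]],
)
config_class = getattr(module, provider_configs[provider])
config = config_class()  # assumed fallback when no explicit config is given
```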
View File
@@ -1,16 +1,19 @@
import json
-from typing import Optional, List, Dict, Any
+from typing import Optional, List

from pydantic import BaseModel

try:
    import psycopg2
    from psycopg2.extras import execute_values
except ImportError:
-    raise ImportError("PGVector requires extra dependencies. Install with `pip install psycopg2`") from None
+    raise ImportError(
+        "PGVector requires extra dependencies. Install with `pip install psycopg2`"
+    ) from None

from mem0.vector_stores.base import VectorStoreBase

+
class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
@@ -19,14 +22,7 @@ class OutputData(BaseModel):

class PGVector(VectorStoreBase):
    def __init__(
-        self,
-        dbname,
-        collection_name,
-        embedding_model_dims,
-        user,
-        password,
-        host,
-        port
+        self, dbname, collection_name, embedding_model_dims, user, password, host, port
    ):
        """
        Initialize the PGVector database.
@@ -43,11 +39,7 @@ class PGVector(VectorStoreBase):
        self.collection_name = collection_name

        self.conn = psycopg2.connect(
-            dbname=dbname,
-            user=user,
-            password=password,
-            host=host,
-            port=port
+            dbname=dbname, user=user, password=password, host=host, port=port
        )
        self.cur = self.conn.cursor()
@@ -63,16 +55,18 @@ class PGVector(VectorStoreBase):
            name (str): Name of the collection.
            embedding_model_dims (int, optional): Dimension of the embedding vector.
        """
-        self.cur.execute(f"""
+        self.cur.execute(
+            f"""
            CREATE TABLE IF NOT EXISTS {self.collection_name} (
                id UUID PRIMARY KEY,
                vector vector({embedding_model_dims}),
                payload JSONB
            );
-        """)
+        """
+        )
        self.conn.commit()

-    def insert(self, vectors, payloads = None, ids = None):
+    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.
@@ -83,11 +77,18 @@ class PGVector(VectorStoreBase):
        """
        json_payloads = [json.dumps(payload) for payload in payloads]

-        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
-        execute_values(self.cur, f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s", data)
+        data = [
+            (id, vector, payload)
+            for id, vector, payload in zip(ids, vectors, json_payloads)
+        ]
+        execute_values(
+            self.cur,
+            f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
+            data,
+        )
        self.conn.commit()

-    def search(self, query, limit = 5, filters = None):
+    def search(self, query, limit=5, filters=None):
        """
        Search for similar vectors.
@@ -104,21 +105,28 @@ class PGVector(VectorStoreBase):
        if filters:
            for k, v in filters.items():
-                filter_conditions.append(f"payload->>%s = %s")
+                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

-        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+        filter_clause = (
+            "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+        )

-        self.cur.execute(f"""
+        self.cur.execute(
+            f"""
            SELECT id, vector <-> %s::vector AS distance, payload
            FROM {self.collection_name}
            {filter_clause}
            ORDER BY distance
            LIMIT %s
-        """, (query, *filter_params, limit))
+        """,
+            (query, *filter_params, limit),
+        )

        results = self.cur.fetchall()
-        return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
+        return [
+            OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results
+        ]

    def delete(self, vector_id):
        """
@@ -127,10 +135,12 @@ class PGVector(VectorStoreBase):
        Args:
            vector_id (str): ID of the vector to delete.
        """
-        self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
+        self.cur.execute(
+            f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)
+        )
        self.conn.commit()

-    def update(self, vector_id, vector = None, payload = None):
+    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.
@@ -140,9 +150,15 @@ class PGVector(VectorStoreBase):
            payload (Dict, optional): Updated payload.
        """
        if vector:
-            self.cur.execute(f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s", (vector, vector_id))
+            self.cur.execute(
+                f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
+                (vector, vector_id),
+            )
        if payload:
-            self.cur.execute(f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s", (psycopg2.extras.Json(payload), vector_id))
+            self.cur.execute(
+                f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                (psycopg2.extras.Json(payload), vector_id),
+            )
        self.conn.commit()

    def get(self, vector_id) -> OutputData:
@@ -155,7 +171,10 @@ class PGVector(VectorStoreBase):
        Returns:
            OutputData: Retrieved vector.
        """
-        self.cur.execute(f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s", (vector_id,))
+        self.cur.execute(
+            f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
+            (vector_id,),
+        )
        result = self.cur.fetchone()
        if not result:
            return None
@@ -168,11 +187,13 @@ class PGVector(VectorStoreBase):
        Returns:
            List[str]: List of collection names.
        """
-        self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
+        self.cur.execute(
+            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
+        )
        return [row[0] for row in self.cur.fetchall()]

    def delete_col(self):
-        """ Delete a collection. """
+        """Delete a collection."""
        self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
        self.conn.commit()
@@ -183,22 +204,21 @@ class PGVector(VectorStoreBase):
        Returns:
            Dict[str, Any]: Collection information.
        """
-        self.cur.execute(f"""
+        self.cur.execute(
+            f"""
            SELECT
                table_name,
                (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
                (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = %s
-        """, (self.collection_name,))
+        """,
+            (self.collection_name,),
+        )
        result = self.cur.fetchone()
-        return {
-            "name": result[0],
-            "count": result[1],
-            "size": result[2]
-        }
+        return {"name": result[0], "count": result[1], "size": result[2]}

-    def list(self, filters = None, limit = 100):
+    def list(self, filters=None, limit=100):
        """
        List all vectors in a collection.
@@ -214,10 +234,12 @@ class PGVector(VectorStoreBase):
        if filters:
            for k, v in filters.items():
-                filter_conditions.append(f"payload->>%s = %s")
+                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

-        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+        filter_clause = (
+            "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+        )

        query = f"""
            SELECT id, vector, payload
@@ -235,7 +257,7 @@ class PGVector(VectorStoreBase):
        """
        Close the database connection when the object is deleted.
        """
-        if hasattr(self, 'cur'):
+        if hasattr(self, "cur"):
            self.cur.close()
-        if hasattr(self, 'conn'):
+        if hasattr(self, "conn"):
            self.conn.close()
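One detail worth calling out in `search()` and `list()`: the JSONB filter clause is built from placeholders only, so both the payload key and value travel as bound parameters instead of being interpolated into the SQL string. A standalone sketch of that construction:

```python
# Same construction as in search()/list() above; pure Python, no DB needed.
filters = {"user_id": "alice", "agent_id": "travel"}

filter_conditions, filter_params = [], []
for k, v in filters.items():
    filter_conditions.append("payload->>%s = %s")  # key and value both bound
    filter_params.extend([k, str(v)])

filter_clause = (
    "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
print(filter_clause)  # WHERE payload->>%s = %s AND payload->>%s = %s
print(filter_params)  # ['user_id', 'alice', 'agent_id', 'travel']
```

The table name itself is still an f-string interpolation, so `collection_name` must come from trusted configuration.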
View File
@@ -28,7 +28,7 @@ class Qdrant(VectorStoreBase):
        path: str = None,
        url: str = None,
        api_key: str = None,
-        on_disk: bool = False
+        on_disk: bool = False,
    ):
        """
        Initialize the Qdrant vector store.
@@ -66,7 +66,9 @@ class Qdrant(VectorStoreBase):
        self.collection_name = collection_name
        self.create_col(embedding_model_dims, on_disk)

-    def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
+    def create_col(
+        self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
+    ):
        """
        Create a new collection.
@@ -79,12 +81,16 @@ class Qdrant(VectorStoreBase):
        response = self.list_cols()
        for collection in response.collections:
            if collection.name == self.collection_name:
-                logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
+                logging.debug(
+                    f"Collection {self.collection_name} already exists. Skipping creation."
+                )
                return

        self.client.create_collection(
            collection_name=self.collection_name,
-            vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
+            vectors_config=VectorParams(
+                size=vector_size, distance=distance, on_disk=on_disk
+            ),
        )

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -202,7 +208,7 @@ class Qdrant(VectorStoreBase):
        return self.client.get_collections()

    def delete_col(self):
-        """ Delete a collection. """
+        """Delete a collection."""
        self.client.delete_collection(collection_name=self.collection_name)

    def col_info(self) -> dict:
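A hedged construction sketch for the Qdrant store: the keywords `path`, `url`, `api_key`, and `on_disk` appear in the `__init__` fragment above, while `collection_name` and `embedding_model_dims` are inferred from their use in `create_col`; the full signature is assumed.

```python
from mem0.vector_stores.qdrant import Qdrant

# Local embedded Qdrant persisted under /tmp/qdrant; pass url/api_key
# instead to target a remote server.
store = Qdrant(
    collection_name="mem0",
    embedding_model_dims=1536,
    path="/tmp/qdrant",
    on_disk=True,  # forwarded to VectorParams in create_col above
)
```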
poetry.lock (generated): 1222 lines changed
File diff suppressed because it is too large

View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
-version = "0.0.19"
+version = "0.0.20"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [
@@ -22,7 +22,6 @@ openai = "^1.33.0"
posthog = "^3.5.0"
pytz = "^2024.1"
sqlalchemy = "^2.0.31"
-litellm = "^1.42.7"

[tool.poetry.group.test.dependencies]
pytest = "^8.2.2"
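With `litellm` dropped from the base dependencies while `mem0.proxy` still calls `litellm.completion`, proxy users now have to install it themselves. A guarded import in the style this commit applies to chromadb and psycopg2 would look like the sketch below (an assumption; the diff does not show how `mem0/proxy/main.py` reacts to the missing package):

```python
# Hypothetical guard, mirroring the chroma/pgvector ImportError pattern above.
try:
    import litellm
except ImportError:
    raise ImportError(
        "The mem0 proxy requires extra dependencies. Install with `pip install litellm`"
    ) from None
```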
View File
@@ -2,7 +2,8 @@ import pytest
from unittest.mock import Mock, patch

from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
-from mem0 import Memory, MemoryClient, Mem0
+from mem0 import Memory, MemoryClient
+from mem0.proxy.main import Mem0
from mem0.proxy.main import Chat, Completions


@pytest.fixture