diff --git a/Makefile b/Makefile index dae34e4a..2d3763d2 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ install: install_all: poetry install - poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \ + poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \ google-generativeai elasticsearch opensearch-py vecs # Format code with ruff diff --git a/docs/components/vectordbs/dbs/weaviate.mdx b/docs/components/vectordbs/dbs/weaviate.mdx new file mode 100644 index 00000000..f5c36f4f --- /dev/null +++ b/docs/components/vectordbs/dbs/weaviate.mdx @@ -0,0 +1,47 @@ +[Weaviate](https://weaviate.io/) is an open-source vector search engine. It allows efficient storage and retrieval of high-dimensional vector embeddings, enabling powerful search and retrieval capabilities. + + +### Installation +```bash +pip install weaviate weaviate-client +``` + +### Usage + +```python Python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "weaviate", + "config": { + "collection_name": "test", + "cluster_url": "http://localhost:8080", + "auth_client_secret": None, + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."}, + {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! 
class WeaviateConfig(BaseModel):
    """Validated settings for the Weaviate vector store.

    ``cluster_url`` is mandatory even though its declared default is ``None``:
    ``check_connection_params`` rejects any input where it is missing or empty.
    Unknown keys are rejected by ``validate_extra_fields``.
    """

    # Imported in the class body and re-exposed as a ClassVar so that pydantic
    # does not treat the client type as a model field.
    from weaviate import WeaviateClient

    WeaviateClient: ClassVar[type] = WeaviateClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    cluster_url: Optional[str] = Field(None, description="URL for Weaviate server")
    auth_client_secret: Optional[str] = Field(None, description="API key for Weaviate authentication")
    additional_headers: Optional[Dict[str, str]] = Field(None, description="Additional headers for requests")

    @model_validator(mode="before")
    @classmethod
    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a non-empty ``cluster_url`` in the raw input.

        Args:
            values: Raw, not yet validated input mapping.

        Returns:
            The unchanged input mapping.

        Raises:
            ValueError: If ``cluster_url`` is missing or falsy.
        """
        cluster_url = values.get("cluster_url")

        if not cluster_url:
            raise ValueError("'cluster_url' must be provided.")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared fields of this model.

        Args:
            values: Raw, not yet validated input mapping.

        Returns:
            The unchanged input mapping.

        Raises:
            ValueError: If ``values`` contains keys that are not model fields.
        """
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields

        if extra_fields:
            # Sort both lists so the message is identical from run to run
            # (set iteration order is not stable across interpreter runs).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )

        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
class OutputData(BaseModel):
    """Single record returned by the Weaviate vector store."""

    id: Optional[str]  # Weaviate object UUID rendered as a string
    score: Optional[float]  # similarity score; 1.0 for non-search reads
    payload: Optional[Dict]  # returned properties plus an "id" entry


class Weaviate(VectorStoreBase):
    """Vector store that keeps memories in a single Weaviate collection."""

    # Properties requested on every read (search / get / list).
    _RETURN_PROPERTIES = ["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"]
    # Only these keys of a ``filters`` dict are turned into query conditions.
    _FILTERABLE_KEYS = ("user_id", "agent_id", "run_id")

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        cluster_url: str = None,
        auth_client_secret: str = None,
        additional_headers: dict = None,
    ):
        """
        Initialize the Weaviate vector store and make sure the collection exists.

        Args:
            collection_name (str): Name of the collection/class in Weaviate.
            embedding_model_dims (int): Dimensions of the embedding model.
            cluster_url (str): URL for Weaviate server. A URL containing "localhost"
                selects the local connection helper; anything else is treated as a
                hosted cluster URL.
            auth_client_secret (str, optional): API key for Weaviate authentication. Defaults to None.
            additional_headers (dict, optional): Additional headers for requests. Defaults to None.

        Raises:
            ValueError: If ``cluster_url`` is missing or empty.
        """
        if not cluster_url:
            # Fail with a clear message instead of a TypeError from `"localhost" in None`.
            raise ValueError("'cluster_url' must be provided.")

        if "localhost" in cluster_url:
            self.client = weaviate.connect_to_local(headers=additional_headers)
        else:
            self.client = weaviate.connect_to_wcs(
                cluster_url=cluster_url,
                # Only build API-key credentials when a secret was actually supplied.
                auth_credentials=Auth.api_key(auth_client_secret) if auth_client_secret else None,
                headers=additional_headers,
            )

        self.collection_name = collection_name
        self.create_col(embedding_model_dims)

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse a column-oriented result dict into a list of records.

        Args:
            data (Dict): Mapping with optional "ids", "distances" and "metadatas"
                lists (each may be wrapped in one extra outer list).

        Returns:
            List[OutputData]: One entry per position; missing values become None.
        """
        columns = []
        for key in ("ids", "distances", "metadatas"):
            value = data.get(key, [])
            # Unwrap a singly nested list such as [[a, b, c]].
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            columns.append(value)

        ids, distances, metadatas = columns
        # default=0 keeps an all-empty / non-list input from raising ValueError.
        max_length = max((len(col) for col in columns if isinstance(col, list)), default=0)

        def at(column, index):
            return column[index] if isinstance(column, list) and index < len(column) else None

        return [
            OutputData(id=at(ids, i), score=at(distances, i), payload=at(metadatas, i))
            for i in range(max_length)
        ]

    def _build_filter(self, filters: Optional[Dict]):
        """Combine the supported keys of ``filters`` into one AND filter, or return None."""
        if not filters:
            return None
        conditions = [
            Filter.by_property(key).equal(value)
            for key, value in filters.items()
            if value and key in self._FILTERABLE_KEYS
        ]
        return Filter.all_of(conditions) if conditions else None

    @staticmethod
    def _score_from_metadata(metadata) -> float:
        """Derive a similarity score from whichever metadata the query returned."""
        distance = getattr(metadata, "distance", None)
        if distance is not None:
            return 1 - distance  # Convert distance to score
        score = getattr(metadata, "score", None)
        return score if score is not None else 1.0

    def create_col(self, vector_size, distance="cosine"):
        """
        Create the collection with the memory schema unless it already exists.

        Args:
            vector_size (int): Size of the vectors to be stored. Currently not
                forwarded to Weaviate.
            distance (str, optional): Distance metric for vector similarity. Defaults to "cosine".
                Currently not forwarded to Weaviate.
        """
        if self.client.collections.exists(self.collection_name):
            logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
            return

        text_properties = [
            "ids",
            "hash",
            "metadata",
            "data",
            "created_at",
            "category",
            "updated_at",
            "user_id",
            "agent_id",
            "run_id",
        ]
        properties = [
            wvcc.Property(name="metadata", data_type=wvcc.DataType.TEXT, description="Additional metadata")
            if name == "metadata"
            else wvcc.Property(name=name, data_type=wvcc.DataType.TEXT)
            for name in text_properties
        ]

        self.client.collections.create(
            self.collection_name,
            # Vectors are supplied by the caller, so no server-side vectorizer.
            vectorizer_config=wvcc.Configure.Vectorizer.none(),
            vector_index_config=wvcc.Configure.VectorIndex.hnsw(),
            properties=properties,
        )

    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into the collection.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. Missing IDs are
                replaced by a freshly generated UUID. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        with self.client.batch.fixed_size(batch_size=100) as batch:
            for idx, vector in enumerate(vectors):
                object_id = ids[idx] if ids and idx < len(ids) else str(uuid.uuid4())
                object_id = get_valid_uuid(object_id)

                # Work on a copy so the caller's payload dict is never mutated.
                data_object = dict(payloads[idx] or {}) if payloads and idx < len(payloads) else {}

                # The identifier travels as the Weaviate object UUID, not as a property.
                data_object.pop("ids", None)

                batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)

    def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (List[float]): Query vector.
            limit (int, optional): Maximum number of results. Defaults to 5.
            filters (Dict, optional): Only "user_id", "agent_id" and "run_id" entries
                with truthy values are applied. Defaults to None.

        Returns:
            List[OutputData]: Matching objects with their similarity score.
        """
        collection = self.client.collections.get(str(self.collection_name))
        response = collection.query.hybrid(
            query="",
            vector=query,
            limit=limit,
            filters=self._build_filter(filters),
            return_properties=list(self._RETURN_PROPERTIES),
            return_metadata=MetadataQuery(score=True),
        )

        results = []
        for obj in response.objects:
            payload = obj.properties.copy()

            # Drop unset scoping identifiers rather than returning explicit None values.
            for id_field in ("run_id", "agent_id", "user_id"):
                if id_field in payload and payload[id_field] is None:
                    del payload[id_field]

            object_id = str(obj.uuid)
            payload["id"] = object_id  # Include the id in the payload
            results.append(
                OutputData(
                    id=object_id,
                    # Only `score` metadata is requested above, so fall back to it
                    # when no distance is present instead of always reporting 1.
                    score=self._score_from_metadata(obj.metadata),
                    payload=payload,
                )
            )
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id: ID of the vector to delete.
        """
        collection = self.client.collections.get(str(self.collection_name))
        collection.data.delete_by_id(vector_id)

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and/or its payload.

        Args:
            vector_id: ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        changes = {}
        if payload:
            changes["properties"] = payload
        if vector is not None and len(vector) > 0:
            changes["vector"] = vector
        if not changes:
            return

        # One call carrying exactly what the caller supplied. The previous
        # implementation re-read the object and wrote the OutputData wrapper
        # ("score"/"payload" keys) back as object properties.
        collection = self.client.collections.get(str(self.collection_name))
        collection.data.update(uuid=vector_id, **changes)

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id: ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved object with score 1.0, or None when the lookup
            yields no object.
        """
        vector_id = get_valid_uuid(vector_id)
        collection = self.client.collections.get(str(self.collection_name))

        response = collection.query.fetch_object_by_id(
            uuid=vector_id,
            return_properties=list(self._RETURN_PROPERTIES),
        )
        if response is None:
            return None

        object_id = str(response.uuid)
        payload = response.properties.copy()
        payload["id"] = object_id
        return OutputData(id=object_id, score=1.0, payload=payload)

    def list_cols(self):
        """
        List all collections.

        Returns:
            dict: ``{"collections": [{"name": <collection name>}, ...]}``.
        """
        collections = self.client.collections.list_all()
        logger.debug(f"collections: {collections}")

        # Accept both a name-keyed mapping and an iterable of objects exposing `.name`.
        if isinstance(collections, dict):
            names = list(collections)
        else:
            names = [col.name for col in collections]
        return {"collections": [{"name": name} for name in names]}

    def delete_col(self):
        """Delete a collection."""
        self.client.collections.delete(self.collection_name)

    def col_info(self):
        """
        Get information about a collection.

        Returns:
            The collection handle returned by the client for this collection name.
        """
        # Return the handle directly; a truthiness test on it is not a
        # meaningful existence check.
        return self.client.collections.get(self.collection_name)

    def list(self, filters=None, limit=100) -> List[List[OutputData]]:
        """
        List vectors in a collection.

        Args:
            filters (dict, optional): Only "user_id", "agent_id" and "run_id" entries
                with truthy values are applied. Defaults to None.
            limit (int, optional): Maximum number of objects to fetch. Defaults to 100.

        Returns:
            List[List[OutputData]]: A one-element list wrapping the list of records.
        """
        collection = self.client.collections.get(self.collection_name)
        response = collection.query.fetch_objects(
            limit=limit,
            filters=self._build_filter(filters),
            return_properties=list(self._RETURN_PROPERTIES),
        )

        results = []
        for obj in response.objects:
            object_id = str(obj.uuid)
            payload = obj.properties.copy()
            payload["id"] = object_id
            results.append(OutputData(id=object_id, score=1.0, payload=payload))
        return [results]
class TestWeaviateDB(unittest.TestCase):
    """Unit tests for the Weaviate vector store, run against a mocked client."""

    _ENV_KEYS = ("WEAVIATE_CLUSTER_URL", "WEAVIATE_API_KEY")

    @classmethod
    def setUpClass(cls):
        dotenv.load_dotenv()

        # Record the real pre-test values (None when unset) so tearDownClass can
        # restore the environment exactly instead of leaving test defaults behind.
        cls.original_env = {key: os.environ.get(key) for key in cls._ENV_KEYS}

        os.environ["WEAVIATE_CLUSTER_URL"] = "http://localhost:8080"
        os.environ["WEAVIATE_API_KEY"] = "test_api_key"

    def setUp(self):
        self.client_mock = MagicMock(spec=weaviate.WeaviateClient)
        self.client_mock.collections = MagicMock()
        self.client_mock.collections.exists.return_value = False
        self.client_mock.collections.create.return_value = None
        self.client_mock.collections.delete.return_value = None

        # The cluster URL contains "localhost", so the store connects through
        # connect_to_local; patch it to hand back the mocked client.
        patcher = patch("mem0.vector_stores.weaviate.weaviate.connect_to_local", return_value=self.client_mock)
        self.mock_weaviate = patcher.start()
        self.addCleanup(patcher.stop)

        self.weaviate_db = Weaviate(
            collection_name="test_collection",
            embedding_model_dims=1536,
            cluster_url=os.getenv("WEAVIATE_CLUSTER_URL"),
            auth_client_secret=os.getenv("WEAVIATE_API_KEY"),
            additional_headers={"X-OpenAI-Api-Key": "test_key"},
        )

        # Forget the calls made while constructing the store.
        self.client_mock.reset_mock()

    @classmethod
    def tearDownClass(cls):
        for key, value in cls.original_env.items():
            if value is not None:
                os.environ[key] = value
            else:
                os.environ.pop(key, None)

    def tearDown(self):
        self.client_mock.reset_mock()

    def test_create_col(self):
        self.client_mock.collections.exists.return_value = False
        self.weaviate_db.create_col(vector_size=1536)

        self.client_mock.collections.create.assert_called_once()

        self.client_mock.reset_mock()

        self.client_mock.collections.exists.return_value = True
        self.weaviate_db.create_col(vector_size=1536)

        self.client_mock.collections.create.assert_not_called()

    def test_insert(self):
        batch_mock = MagicMock()
        self.client_mock.batch = MagicMock()
        self.client_mock.batch.fixed_size.return_value.__enter__.return_value = batch_mock

        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = [str(uuid.uuid4()), str(uuid.uuid4())]

        self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)

        # One batched object per input vector, each sent to the configured collection.
        self.assertEqual(batch_mock.add_object.call_count, 2)
        for call, vector, payload in zip(batch_mock.add_object.call_args_list, vectors, payloads):
            self.assertEqual(call.kwargs["collection"], "test_collection")
            self.assertEqual(call.kwargs["vector"], vector)
            self.assertEqual(call.kwargs["properties"], payload)

    def test_get(self):
        valid_uuid = str(uuid.uuid4())

        mock_response = MagicMock()
        mock_response.properties = {
            "hash": "abc123",
            "created_at": "2025-03-08T12:00:00Z",
            "updated_at": "2025-03-08T13:00:00Z",
            "user_id": "user_123",
            "agent_id": "agent_456",
            "run_id": "run_789",
            "data": {"key": "value"},
            "category": "test",
        }
        mock_response.uuid = valid_uuid

        self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response

        result = self.weaviate_db.get(vector_id=valid_uuid)

        self.assertIsInstance(result, OutputData)
        self.assertEqual(result.id, valid_uuid)

        expected_payload = mock_response.properties.copy()
        expected_payload["id"] = valid_uuid

        self.assertEqual(result.payload, expected_payload)

    def test_get_not_found(self):
        mock_response = httpx.Response(status_code=404, json={"error": "Not found"})

        # Attach the failure to the call that `get` actually makes and check
        # that the client error reaches the caller.
        fetch_mock = self.client_mock.collections.get.return_value.query.fetch_object_by_id
        fetch_mock.side_effect = UnexpectedStatusCodeException("Not found", mock_response)

        with self.assertRaises(UnexpectedStatusCodeException):
            self.weaviate_db.get(vector_id=str(uuid.uuid4()))

        fetch_mock.assert_called_once()

    def test_search(self):
        mock_obj = MagicMock()
        mock_obj.uuid = "id1"
        mock_obj.properties = {"key1": "value1"}
        mock_obj.metadata = MagicMock()
        mock_obj.metadata.distance = 0.2

        mock_response = MagicMock()
        mock_response.objects = [mock_obj]

        mock_hybrid = MagicMock(return_value=mock_response)
        self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid

        query_vector = [0.1] * 1536
        results = self.weaviate_db.search(query=query_vector, limit=5)

        mock_hybrid.assert_called_once()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].id, "id1")
        # Score is computed as 1 - distance; compare floats with a tolerance.
        self.assertAlmostEqual(results[0].score, 0.8)

    def test_delete(self):
        self.weaviate_db.delete(vector_id="id1")

        self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1")

    def test_list(self):
        mock_obj1 = MagicMock()
        mock_obj1.uuid = "id1"
        mock_obj1.properties = {"key1": "value1"}

        mock_obj2 = MagicMock()
        mock_obj2.uuid = "id2"
        mock_obj2.properties = {"key2": "value2"}

        mock_response = MagicMock()
        mock_response.objects = [mock_obj1, mock_obj2]

        mock_fetch = MagicMock(return_value=mock_response)
        self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch

        results = self.weaviate_db.list(limit=10)

        mock_fetch.assert_called_once()

        # `list` wraps its records in a one-element outer list.
        self.assertEqual(len(results), 1)
        self.assertEqual(len(results[0]), 2)
        self.assertEqual(results[0][0].id, "id1")
        self.assertEqual(results[0][0].payload["key1"], "value1")
        self.assertEqual(results[0][1].id, "id2")
        self.assertEqual(results[0][1].payload["key2"], "value2")

    def test_list_cols(self):
        mock_collection1 = MagicMock()
        mock_collection1.name = "collection1"

        mock_collection2 = MagicMock()
        mock_collection2.name = "collection2"
        self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2]

        result = self.weaviate_db.list_cols()
        expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]}

        self.assertEqual(result, expected)

        self.client_mock.collections.list_all.assert_called_once()

    def test_delete_col(self):
        self.weaviate_db.delete_col()

        self.client_mock.collections.delete.assert_called_once_with("test_collection")


if __name__ == "__main__":
    unittest.main()