diff --git a/Makefile b/Makefile index e905addd..76a25de1 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ install: install_all: poetry install - poetry run pip install groq together boto3 litellm ollama sentence_transformers + poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers # Format code with ruff format: diff --git a/tests/vector_stores/test_chroma.py b/tests/vector_stores/test_chroma.py new file mode 100644 index 00000000..3d0c20b3 --- /dev/null +++ b/tests/vector_stores/test_chroma.py @@ -0,0 +1,111 @@ +from unittest.mock import Mock, patch +import pytest +from mem0.vector_stores.chroma import ChromaDB, OutputData + + +@pytest.fixture +def mock_chromadb_client(): + with patch("chromadb.Client") as mock_client: + yield mock_client + + +@pytest.fixture +def chromadb_instance(mock_chromadb_client): + mock_collection = Mock() + mock_chromadb_client.return_value.get_or_create_collection.return_value = ( + mock_collection + ) + + return ChromaDB( + collection_name="test_collection", client=mock_chromadb_client.return_value + ) + + +def test_insert_vectors(chromadb_instance, mock_chromadb_client): + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"name": "vector1"}, {"name": "vector2"}] + ids = ["id1", "id2"] + + chromadb_instance.insert(vectors=vectors, payloads=payloads, ids=ids) + + chromadb_instance.collection.add.assert_called_once_with( + ids=ids, embeddings=vectors, metadatas=payloads + ) + + +def test_search_vectors(chromadb_instance, mock_chromadb_client): + mock_result = { + "ids": [["id1", "id2"]], + "distances": [[0.1, 0.2]], + "metadatas": [[{"name": "vector1"}, {"name": "vector2"}]], + } + chromadb_instance.collection.query.return_value = mock_result + + query = [[0.1, 0.2, 0.3]] + results = chromadb_instance.search(query=query, limit=2) + + chromadb_instance.collection.query.assert_called_once_with( + query_embeddings=query, where=None, n_results=2 + ) + + print(results, type(results)) + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].score == 0.1 + assert results[0].payload == {"name": "vector1"} + + +def test_delete_vector(chromadb_instance): + vector_id = "id1" + + chromadb_instance.delete(vector_id=vector_id) + + chromadb_instance.collection.delete.assert_called_once_with(ids=vector_id) + + +def test_update_vector(chromadb_instance): + vector_id = "id1" + new_vector = [0.7, 0.8, 0.9] + new_payload = {"name": "updated_vector"} + + chromadb_instance.update( + vector_id=vector_id, vector=new_vector, payload=new_payload + ) + + chromadb_instance.collection.update.assert_called_once_with( + ids=vector_id, embeddings=new_vector, metadatas=new_payload + ) + + +def test_get_vector(chromadb_instance): + mock_result = { + "ids": [["id1"]], + "distances": [[0.1]], + "metadatas": [[{"name": "vector1"}]], + } + chromadb_instance.collection.get.return_value = mock_result + + result = chromadb_instance.get(vector_id="id1") + + chromadb_instance.collection.get.assert_called_once_with(ids=["id1"]) + + assert result.id == "id1" + assert result.score == 0.1 + assert result.payload == {"name": "vector1"} + + +def test_list_vectors(chromadb_instance): + mock_result = { + "ids": [["id1", "id2"]], + "distances": [[0.1, 0.2]], + "metadatas": [[{"name": "vector1"}, {"name": "vector2"}]], + } + chromadb_instance.collection.get.return_value = mock_result + + results = chromadb_instance.list(limit=2) + + chromadb_instance.collection.get.assert_called_once_with(where=None, limit=2) + + assert len(results[0]) == 2 + assert results[0][0].id == "id1" + assert results[0][1].id == "id2" diff --git a/tests/vector_stores/test_qdrant.py b/tests/vector_stores/test_qdrant.py new file mode 100644 index 00000000..b398335f --- /dev/null +++ b/tests/vector_stores/test_qdrant.py @@ -0,0 +1,130 @@ +import unittest +from unittest.mock import MagicMock, patch +import uuid +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, + PointStruct, + VectorParams, + PointIdsList, +) +from mem0.vector_stores.qdrant import Qdrant + + +class TestQdrant(unittest.TestCase): + def setUp(self): + self.client_mock = MagicMock(spec=QdrantClient) + self.qdrant = Qdrant( + collection_name="test_collection", + embedding_model_dims=128, + client=self.client_mock, + path="test_path", + on_disk=True, + ) + + def test_create_col(self): + self.client_mock.get_collections.return_value = MagicMock(collections=[]) + + self.qdrant.create_col(vector_size=128, on_disk=True) + + expected_config = VectorParams(size=128, distance=Distance.COSINE, on_disk=True) + + self.client_mock.create_collection.assert_called_with( + collection_name="test_collection", vectors_config=expected_config + ) + + def test_insert(self): + vectors = [[0.1, 0.2], [0.3, 0.4]] + payloads = [{"key": "value1"}, {"key": "value2"}] + ids = [str(uuid.uuid4()), str(uuid.uuid4())] + + self.qdrant.insert(vectors=vectors, payloads=payloads, ids=ids) + + self.client_mock.upsert.assert_called_once() + points = self.client_mock.upsert.call_args[1]["points"] + + self.assertEqual(len(points), 2) + for point in points: + self.assertIsInstance(point, PointStruct) + + self.assertEqual(points[0].payload, payloads[0]) + + def test_search(self): + query_vector = [0.1, 0.2] + self.client_mock.search.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}} + ] + + results = self.qdrant.search(query=query_vector, limit=1) + + self.client_mock.search.assert_called_once_with( + collection_name="test_collection", + query_vector=query_vector, + query_filter=None, + limit=1, + ) + + self.assertEqual(len(results), 1) + self.assertIn("id", results[0]) + self.assertIn("score", results[0]) + self.assertIn("payload", results[0]) + + def test_delete(self): + vector_id = str(uuid.uuid4()) + self.qdrant.delete(vector_id=vector_id) + + self.client_mock.delete.assert_called_once_with( + collection_name="test_collection", + points_selector=PointIdsList(points=[vector_id]), + ) + + def test_update(self): + vector_id = str(uuid.uuid4()) + updated_vector = [0.2, 0.3] + updated_payload = {"key": "updated_value"} + + self.qdrant.update( + vector_id=vector_id, vector=updated_vector, payload=updated_payload + ) + + self.client_mock.upsert.assert_called_once() + point = self.client_mock.upsert.call_args[1]["points"][0] + self.assertEqual(point.id, vector_id) + self.assertEqual(point.vector, updated_vector) + self.assertEqual(point.payload, updated_payload) + + def test_get(self): + vector_id = str(uuid.uuid4()) + self.client_mock.retrieve.return_value = [ + {"id": vector_id, "payload": {"key": "value"}} + ] + + result = self.qdrant.get(vector_id=vector_id) + + self.client_mock.retrieve.assert_called_once_with( + collection_name="test_collection", ids=[vector_id], with_payload=True + ) + self.assertEqual(result["id"], vector_id) + self.assertEqual(result["payload"], {"key": "value"}) + + def test_list_cols(self): + self.client_mock.get_collections.return_value = MagicMock( + collections=[{"name": "test_collection"}] + ) + result = self.qdrant.list_cols() + self.assertEqual(result.collections[0]["name"], "test_collection") + + def test_delete_col(self): + self.qdrant.delete_col() + self.client_mock.delete_collection.assert_called_once_with( + collection_name="test_collection" + ) + + def test_col_info(self): + self.qdrant.col_info() + self.client_mock.get_collection.assert_called_once_with( + collection_name="test_collection" + ) + + def tearDown(self): + del self.qdrant