added vector store test cases (#1868)
Co-authored-by: Dev Khant <devkhant24@gmail.com>
This commit is contained in:
2
Makefile
2
Makefile
@@ -12,7 +12,7 @@ install:
|
||||
|
||||
install_all:
|
||||
poetry install
|
||||
poetry run pip install groq together boto3 litellm ollama sentence_transformers
|
||||
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers
|
||||
|
||||
# Format code with ruff
|
||||
format:
|
||||
|
||||
111
tests/vector_stores/test_chroma.py
Normal file
111
tests/vector_stores/test_chroma.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from mem0.vector_stores.chroma import ChromaDB, OutputData
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chromadb_client():
|
||||
with patch("chromadb.Client") as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chromadb_instance(mock_chromadb_client):
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.return_value.get_or_create_collection.return_value = (
|
||||
mock_collection
|
||||
)
|
||||
|
||||
return ChromaDB(
|
||||
collection_name="test_collection", client=mock_chromadb_client.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_insert_vectors(chromadb_instance, mock_chromadb_client):
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
payloads = [{"name": "vector1"}, {"name": "vector2"}]
|
||||
ids = ["id1", "id2"]
|
||||
|
||||
chromadb_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
|
||||
|
||||
chromadb_instance.collection.add.assert_called_once_with(
|
||||
ids=ids, embeddings=vectors, metadatas=payloads
|
||||
)
|
||||
|
||||
|
||||
def test_search_vectors(chromadb_instance, mock_chromadb_client):
|
||||
mock_result = {
|
||||
"ids": [["id1", "id2"]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
"metadatas": [[{"name": "vector1"}, {"name": "vector2"}]],
|
||||
}
|
||||
chromadb_instance.collection.query.return_value = mock_result
|
||||
|
||||
query = [[0.1, 0.2, 0.3]]
|
||||
results = chromadb_instance.search(query=query, limit=2)
|
||||
|
||||
chromadb_instance.collection.query.assert_called_once_with(
|
||||
query_embeddings=query, where=None, n_results=2
|
||||
)
|
||||
|
||||
print(results, type(results))
|
||||
assert len(results) == 2
|
||||
assert results[0].id == "id1"
|
||||
assert results[0].score == 0.1
|
||||
assert results[0].payload == {"name": "vector1"}
|
||||
|
||||
|
||||
def test_delete_vector(chromadb_instance):
|
||||
vector_id = "id1"
|
||||
|
||||
chromadb_instance.delete(vector_id=vector_id)
|
||||
|
||||
chromadb_instance.collection.delete.assert_called_once_with(ids=vector_id)
|
||||
|
||||
|
||||
def test_update_vector(chromadb_instance):
|
||||
vector_id = "id1"
|
||||
new_vector = [0.7, 0.8, 0.9]
|
||||
new_payload = {"name": "updated_vector"}
|
||||
|
||||
chromadb_instance.update(
|
||||
vector_id=vector_id, vector=new_vector, payload=new_payload
|
||||
)
|
||||
|
||||
chromadb_instance.collection.update.assert_called_once_with(
|
||||
ids=vector_id, embeddings=new_vector, metadatas=new_payload
|
||||
)
|
||||
|
||||
|
||||
def test_get_vector(chromadb_instance):
|
||||
mock_result = {
|
||||
"ids": [["id1"]],
|
||||
"distances": [[0.1]],
|
||||
"metadatas": [[{"name": "vector1"}]],
|
||||
}
|
||||
chromadb_instance.collection.get.return_value = mock_result
|
||||
|
||||
result = chromadb_instance.get(vector_id="id1")
|
||||
|
||||
chromadb_instance.collection.get.assert_called_once_with(ids=["id1"])
|
||||
|
||||
assert result.id == "id1"
|
||||
assert result.score == 0.1
|
||||
assert result.payload == {"name": "vector1"}
|
||||
|
||||
|
||||
def test_list_vectors(chromadb_instance):
|
||||
mock_result = {
|
||||
"ids": [["id1", "id2"]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
"metadatas": [[{"name": "vector1"}, {"name": "vector2"}]],
|
||||
}
|
||||
chromadb_instance.collection.get.return_value = mock_result
|
||||
|
||||
results = chromadb_instance.list(limit=2)
|
||||
|
||||
chromadb_instance.collection.get.assert_called_once_with(where=None, limit=2)
|
||||
|
||||
assert len(results[0]) == 2
|
||||
assert results[0][0].id == "id1"
|
||||
assert results[0][1].id == "id2"
|
||||
130
tests/vector_stores/test_qdrant.py
Normal file
130
tests/vector_stores/test_qdrant.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
PointStruct,
|
||||
VectorParams,
|
||||
PointIdsList,
|
||||
)
|
||||
from mem0.vector_stores.qdrant import Qdrant
|
||||
|
||||
|
||||
class TestQdrant(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.client_mock = MagicMock(spec=QdrantClient)
|
||||
self.qdrant = Qdrant(
|
||||
collection_name="test_collection",
|
||||
embedding_model_dims=128,
|
||||
client=self.client_mock,
|
||||
path="test_path",
|
||||
on_disk=True,
|
||||
)
|
||||
|
||||
def test_create_col(self):
|
||||
self.client_mock.get_collections.return_value = MagicMock(collections=[])
|
||||
|
||||
self.qdrant.create_col(vector_size=128, on_disk=True)
|
||||
|
||||
expected_config = VectorParams(size=128, distance=Distance.COSINE, on_disk=True)
|
||||
|
||||
self.client_mock.create_collection.assert_called_with(
|
||||
collection_name="test_collection", vectors_config=expected_config
|
||||
)
|
||||
|
||||
def test_insert(self):
|
||||
vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
payloads = [{"key": "value1"}, {"key": "value2"}]
|
||||
ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
self.qdrant.insert(vectors=vectors, payloads=payloads, ids=ids)
|
||||
|
||||
self.client_mock.upsert.assert_called_once()
|
||||
points = self.client_mock.upsert.call_args[1]["points"]
|
||||
|
||||
self.assertEqual(len(points), 2)
|
||||
for point in points:
|
||||
self.assertIsInstance(point, PointStruct)
|
||||
|
||||
self.assertEqual(points[0].payload, payloads[0])
|
||||
|
||||
def test_search(self):
|
||||
query_vector = [0.1, 0.2]
|
||||
self.client_mock.search.return_value = [
|
||||
{"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}}
|
||||
]
|
||||
|
||||
results = self.qdrant.search(query=query_vector, limit=1)
|
||||
|
||||
self.client_mock.search.assert_called_once_with(
|
||||
collection_name="test_collection",
|
||||
query_vector=query_vector,
|
||||
query_filter=None,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertIn("id", results[0])
|
||||
self.assertIn("score", results[0])
|
||||
self.assertIn("payload", results[0])
|
||||
|
||||
def test_delete(self):
|
||||
vector_id = str(uuid.uuid4())
|
||||
self.qdrant.delete(vector_id=vector_id)
|
||||
|
||||
self.client_mock.delete.assert_called_once_with(
|
||||
collection_name="test_collection",
|
||||
points_selector=PointIdsList(points=[vector_id]),
|
||||
)
|
||||
|
||||
def test_update(self):
|
||||
vector_id = str(uuid.uuid4())
|
||||
updated_vector = [0.2, 0.3]
|
||||
updated_payload = {"key": "updated_value"}
|
||||
|
||||
self.qdrant.update(
|
||||
vector_id=vector_id, vector=updated_vector, payload=updated_payload
|
||||
)
|
||||
|
||||
self.client_mock.upsert.assert_called_once()
|
||||
point = self.client_mock.upsert.call_args[1]["points"][0]
|
||||
self.assertEqual(point.id, vector_id)
|
||||
self.assertEqual(point.vector, updated_vector)
|
||||
self.assertEqual(point.payload, updated_payload)
|
||||
|
||||
def test_get(self):
|
||||
vector_id = str(uuid.uuid4())
|
||||
self.client_mock.retrieve.return_value = [
|
||||
{"id": vector_id, "payload": {"key": "value"}}
|
||||
]
|
||||
|
||||
result = self.qdrant.get(vector_id=vector_id)
|
||||
|
||||
self.client_mock.retrieve.assert_called_once_with(
|
||||
collection_name="test_collection", ids=[vector_id], with_payload=True
|
||||
)
|
||||
self.assertEqual(result["id"], vector_id)
|
||||
self.assertEqual(result["payload"], {"key": "value"})
|
||||
|
||||
def test_list_cols(self):
|
||||
self.client_mock.get_collections.return_value = MagicMock(
|
||||
collections=[{"name": "test_collection"}]
|
||||
)
|
||||
result = self.qdrant.list_cols()
|
||||
self.assertEqual(result.collections[0]["name"], "test_collection")
|
||||
|
||||
def test_delete_col(self):
|
||||
self.qdrant.delete_col()
|
||||
self.client_mock.delete_collection.assert_called_once_with(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
def test_col_info(self):
|
||||
self.qdrant.col_info()
|
||||
self.client_mock.get_collection.assert_called_once_with(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
del self.qdrant
|
||||
Reference in New Issue
Block a user