Code formatting (#1986)

This commit is contained in:
Dev Khant
2024-10-29 11:32:07 +05:30
committed by GitHub
parent dca74a1ec0
commit 605558da9d
13 changed files with 119 additions and 149 deletions

View File

@@ -1,6 +1,6 @@
from unittest.mock import Mock, patch
import pytest
from mem0.vector_stores.chroma import ChromaDB, OutputData
from mem0.vector_stores.chroma import ChromaDB
@pytest.fixture
@@ -12,13 +12,9 @@ def mock_chromadb_client():
@pytest.fixture
def chromadb_instance(mock_chromadb_client):
mock_collection = Mock()
mock_chromadb_client.return_value.get_or_create_collection.return_value = (
mock_collection
)
mock_chromadb_client.return_value.get_or_create_collection.return_value = mock_collection
return ChromaDB(
collection_name="test_collection", client=mock_chromadb_client.return_value
)
return ChromaDB(collection_name="test_collection", client=mock_chromadb_client.return_value)
def test_insert_vectors(chromadb_instance, mock_chromadb_client):
@@ -28,9 +24,7 @@ def test_insert_vectors(chromadb_instance, mock_chromadb_client):
chromadb_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
chromadb_instance.collection.add.assert_called_once_with(
ids=ids, embeddings=vectors, metadatas=payloads
)
chromadb_instance.collection.add.assert_called_once_with(ids=ids, embeddings=vectors, metadatas=payloads)
def test_search_vectors(chromadb_instance, mock_chromadb_client):
@@ -44,9 +38,7 @@ def test_search_vectors(chromadb_instance, mock_chromadb_client):
query = [[0.1, 0.2, 0.3]]
results = chromadb_instance.search(query=query, limit=2)
chromadb_instance.collection.query.assert_called_once_with(
query_embeddings=query, where=None, n_results=2
)
chromadb_instance.collection.query.assert_called_once_with(query_embeddings=query, where=None, n_results=2)
print(results, type(results))
assert len(results) == 2
@@ -68,9 +60,7 @@ def test_update_vector(chromadb_instance):
new_vector = [0.7, 0.8, 0.9]
new_payload = {"name": "updated_vector"}
chromadb_instance.update(
vector_id=vector_id, vector=new_vector, payload=new_payload
)
chromadb_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
chromadb_instance.collection.update.assert_called_once_with(
ids=vector_id, embeddings=new_vector, metadatas=new_payload

View File

@@ -1,5 +1,5 @@
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import (
@@ -51,9 +51,7 @@ class TestQdrant(unittest.TestCase):
def test_search(self):
query_vector = [0.1, 0.2]
self.client_mock.search.return_value = [
{"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}}
]
self.client_mock.search.return_value = [{"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}}]
results = self.qdrant.search(query=query_vector, limit=1)
@@ -83,9 +81,7 @@ class TestQdrant(unittest.TestCase):
updated_vector = [0.2, 0.3]
updated_payload = {"key": "updated_value"}
self.qdrant.update(
vector_id=vector_id, vector=updated_vector, payload=updated_payload
)
self.qdrant.update(vector_id=vector_id, vector=updated_vector, payload=updated_payload)
self.client_mock.upsert.assert_called_once()
point = self.client_mock.upsert.call_args[1]["points"][0]
@@ -95,9 +91,7 @@ class TestQdrant(unittest.TestCase):
def test_get(self):
vector_id = str(uuid.uuid4())
self.client_mock.retrieve.return_value = [
{"id": vector_id, "payload": {"key": "value"}}
]
self.client_mock.retrieve.return_value = [{"id": vector_id, "payload": {"key": "value"}}]
result = self.qdrant.get(vector_id=vector_id)
@@ -108,23 +102,17 @@ class TestQdrant(unittest.TestCase):
self.assertEqual(result["payload"], {"key": "value"})
def test_list_cols(self):
self.client_mock.get_collections.return_value = MagicMock(
collections=[{"name": "test_collection"}]
)
self.client_mock.get_collections.return_value = MagicMock(collections=[{"name": "test_collection"}])
result = self.qdrant.list_cols()
self.assertEqual(result.collections[0]["name"], "test_collection")
def test_delete_col(self):
self.qdrant.delete_col()
self.client_mock.delete_collection.assert_called_once_with(
collection_name="test_collection"
)
self.client_mock.delete_collection.assert_called_once_with(collection_name="test_collection")
def test_col_info(self):
self.qdrant.col_info()
self.client_mock.get_collection.assert_called_once_with(
collection_name="test_collection"
)
self.client_mock.get_collection.assert_called_once_with(collection_name="test_collection")
def tearDown(self):
del self.qdrant