221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
# import os
|
|
# import uuid
|
|
# import httpx
|
|
# import unittest
|
|
# from unittest.mock import MagicMock, patch
|
|
|
|
# import dotenv
|
|
# import weaviate
|
|
# from weaviate.classes.query import MetadataQuery, Filter
|
|
# from weaviate.exceptions import UnexpectedStatusCodeException
|
|
|
|
# from mem0.vector_stores.weaviate import Weaviate, OutputData
|
|
|
|
|
|
# class TestWeaviateDB(unittest.TestCase):
|
|
# @classmethod
|
|
# def setUpClass(cls):
|
|
# dotenv.load_dotenv()
|
|
|
|
# cls.original_env = {
|
|
# 'WEAVIATE_CLUSTER_URL': os.getenv('WEAVIATE_CLUSTER_URL', 'http://localhost:8080'),
|
|
# 'WEAVIATE_API_KEY': os.getenv('WEAVIATE_API_KEY', 'test_api_key'),
|
|
# }
|
|
|
|
# os.environ['WEAVIATE_CLUSTER_URL'] = 'http://localhost:8080'
|
|
# os.environ['WEAVIATE_API_KEY'] = 'test_api_key'
|
|
|
|
# def setUp(self):
|
|
# self.client_mock = MagicMock(spec=weaviate.WeaviateClient)
|
|
# self.client_mock.collections = MagicMock()
|
|
# self.client_mock.collections.exists.return_value = False
|
|
# self.client_mock.collections.create.return_value = None
|
|
# self.client_mock.collections.delete.return_value = None
|
|
|
|
# patcher = patch('mem0.vector_stores.weaviate.weaviate.connect_to_local', return_value=self.client_mock)
|
|
# self.mock_weaviate = patcher.start()
|
|
# self.addCleanup(patcher.stop)
|
|
|
|
# self.weaviate_db = Weaviate(
|
|
# collection_name="test_collection",
|
|
# embedding_model_dims=1536,
|
|
# cluster_url=os.getenv('WEAVIATE_CLUSTER_URL'),
|
|
# auth_client_secret=os.getenv('WEAVIATE_API_KEY'),
|
|
# additional_headers={"X-OpenAI-Api-Key": "test_key"},
|
|
# )
|
|
|
|
# self.client_mock.reset_mock()
|
|
|
|
# @classmethod
|
|
# def tearDownClass(cls):
|
|
# for key, value in cls.original_env.items():
|
|
# if value is not None:
|
|
# os.environ[key] = value
|
|
# else:
|
|
# os.environ.pop(key, None)
|
|
|
|
# def tearDown(self):
|
|
# self.client_mock.reset_mock()
|
|
|
|
# def test_create_col(self):
|
|
# self.client_mock.collections.exists.return_value = False
|
|
# self.weaviate_db.create_col(vector_size=1536)
|
|
|
|
|
|
# self.client_mock.collections.create.assert_called_once()
|
|
|
|
|
|
# self.client_mock.reset_mock()
|
|
|
|
# self.client_mock.collections.exists.return_value = True
|
|
# self.weaviate_db.create_col(vector_size=1536)
|
|
|
|
# self.client_mock.collections.create.assert_not_called()
|
|
|
|
# def test_insert(self):
|
|
# self.client_mock.batch = MagicMock()
|
|
|
|
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
|
|
|
|
# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
|
|
# "results": [{"id": "id1"}, {"id": "id2"}]
|
|
# }
|
|
|
|
# vectors = [[0.1] * 1536, [0.2] * 1536]
|
|
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
|
|
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
|
|
|
# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
|
|
|
|
# def test_get(self):
|
|
# valid_uuid = str(uuid.uuid4())
|
|
|
|
# mock_response = MagicMock()
|
|
# mock_response.properties = {
|
|
# "hash": "abc123",
|
|
# "created_at": "2025-03-08T12:00:00Z",
|
|
# "updated_at": "2025-03-08T13:00:00Z",
|
|
# "user_id": "user_123",
|
|
# "agent_id": "agent_456",
|
|
# "run_id": "run_789",
|
|
# "data": {"key": "value"},
|
|
# "category": "test",
|
|
# }
|
|
# mock_response.uuid = valid_uuid
|
|
|
|
# self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response
|
|
|
|
# result = self.weaviate_db.get(vector_id=valid_uuid)
|
|
|
|
# assert result.id == valid_uuid
|
|
|
|
# expected_payload = mock_response.properties.copy()
|
|
# expected_payload["id"] = valid_uuid
|
|
|
|
# assert result.payload == expected_payload
|
|
|
|
|
|
# def test_get_not_found(self):
|
|
# mock_response = httpx.Response(status_code=404, json={"error": "Not found"})
|
|
|
|
# self.client_mock.collections.get.return_value.data.get_by_id.side_effect = UnexpectedStatusCodeException(
|
|
# "Not found", mock_response
|
|
# )
|
|
|
|
|
|
# def test_search(self):
|
|
# mock_objects = [
|
|
# {
|
|
# "uuid": "id1",
|
|
# "properties": {"key1": "value1"},
|
|
# "metadata": {"distance": 0.2}
|
|
# }
|
|
# ]
|
|
|
|
# mock_response = MagicMock()
|
|
# mock_response.objects = []
|
|
|
|
# for obj in mock_objects:
|
|
# mock_obj = MagicMock()
|
|
# mock_obj.uuid = obj["uuid"]
|
|
# mock_obj.properties = obj["properties"]
|
|
# mock_obj.metadata = MagicMock()
|
|
# mock_obj.metadata.distance = obj["metadata"]["distance"]
|
|
# mock_response.objects.append(mock_obj)
|
|
|
|
# mock_hybrid = MagicMock()
|
|
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
|
|
# mock_hybrid.return_value = mock_response
|
|
|
|
# vectors = [[0.1] * 1536]
|
|
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
|
|
|
|
# mock_hybrid.assert_called_once()
|
|
|
|
# self.assertEqual(len(results), 1)
|
|
# self.assertEqual(results[0].id, "id1")
|
|
# self.assertEqual(results[0].score, 0.8)
|
|
|
|
# def test_delete(self):
|
|
# self.weaviate_db.delete(vector_id="id1")
|
|
|
|
# self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1")
|
|
|
|
# def test_list(self):
|
|
# mock_objects = []
|
|
|
|
# mock_obj1 = MagicMock()
|
|
# mock_obj1.uuid = "id1"
|
|
# mock_obj1.properties = {"key1": "value1"}
|
|
# mock_objects.append(mock_obj1)
|
|
|
|
# mock_obj2 = MagicMock()
|
|
# mock_obj2.uuid = "id2"
|
|
# mock_obj2.properties = {"key2": "value2"}
|
|
# mock_objects.append(mock_obj2)
|
|
|
|
# mock_response = MagicMock()
|
|
# mock_response.objects = mock_objects
|
|
|
|
# mock_fetch = MagicMock()
|
|
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
|
|
# mock_fetch.return_value = mock_response
|
|
|
|
# results = self.weaviate_db.list(limit=10)
|
|
|
|
# mock_fetch.assert_called_once()
|
|
|
|
# # Verify results
|
|
# self.assertEqual(len(results), 1)
|
|
# self.assertEqual(len(results[0]), 2)
|
|
# self.assertEqual(results[0][0].id, "id1")
|
|
# self.assertEqual(results[0][0].payload["key1"], "value1")
|
|
# self.assertEqual(results[0][1].id, "id2")
|
|
# self.assertEqual(results[0][1].payload["key2"], "value2")
|
|
|
|
|
|
# def test_list_cols(self):
|
|
# mock_collection1 = MagicMock()
|
|
# mock_collection1.name = "collection1"
|
|
|
|
# mock_collection2 = MagicMock()
|
|
# mock_collection2.name = "collection2"
|
|
# self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2]
|
|
|
|
# result = self.weaviate_db.list_cols()
|
|
# expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]}
|
|
|
|
# assert result == expected
|
|
|
|
# self.client_mock.collections.list_all.assert_called_once()
|
|
|
|
|
|
# def test_delete_col(self):
|
|
# self.weaviate_db.delete_col()
|
|
|
|
# self.client_mock.collections.delete.assert_called_once_with("test_collection")
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# unittest.main()
|