Formatting (#2750)
This commit is contained in:
@@ -8,9 +8,7 @@ import pytest
|
||||
try:
|
||||
from opensearchpy import AWSV4SignerAuth, OpenSearch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
|
||||
) from None
|
||||
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
|
||||
|
||||
from mem0.vector_stores.opensearch import OpenSearchDB
|
||||
|
||||
@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
dotenv.load_dotenv()
|
||||
cls.original_env = {
|
||||
'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
|
||||
'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
|
||||
'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
|
||||
"OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
|
||||
"OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
|
||||
"OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
|
||||
}
|
||||
os.environ['OS_URL'] = 'http://localhost'
|
||||
os.environ['OS_USERNAME'] = 'test_user'
|
||||
os.environ['OS_PASSWORD'] = 'test_password'
|
||||
os.environ["OS_URL"] = "http://localhost"
|
||||
os.environ["OS_USERNAME"] = "test_user"
|
||||
os.environ["OS_PASSWORD"] = "test_password"
|
||||
|
||||
def setUp(self):
|
||||
self.client_mock = MagicMock(spec=OpenSearch)
|
||||
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
self.client_mock.delete = MagicMock()
|
||||
self.client_mock.search = MagicMock()
|
||||
|
||||
patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
|
||||
patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
|
||||
self.mock_os = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
|
||||
self.os_db = OpenSearchDB(
|
||||
host=os.getenv('OS_URL'),
|
||||
host=os.getenv("OS_URL"),
|
||||
port=9200,
|
||||
collection_name="test_collection",
|
||||
embedding_model_dims=1536,
|
||||
user=os.getenv('OS_USERNAME'),
|
||||
password=os.getenv('OS_PASSWORD'),
|
||||
user=os.getenv("OS_USERNAME"),
|
||||
password=os.getenv("OS_PASSWORD"),
|
||||
verify_certs=False,
|
||||
use_ssl=False
|
||||
use_ssl=False,
|
||||
)
|
||||
self.client_mock.reset_mock()
|
||||
|
||||
@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
vectors = [[0.1] * 1536, [0.2] * 1536]
|
||||
payloads = [{"key1": "value1"}, {"key2": "value2"}]
|
||||
ids = ["id1", "id2"]
|
||||
|
||||
|
||||
# Mock the index method
|
||||
self.client_mock.index = MagicMock()
|
||||
|
||||
|
||||
results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)
|
||||
|
||||
|
||||
# Verify index was called twice (once for each vector)
|
||||
self.assertEqual(self.client_mock.index.call_count, 2)
|
||||
|
||||
|
||||
# Check first call
|
||||
first_call = self.client_mock.index.call_args_list[0]
|
||||
self.assertEqual(first_call[1]["index"], "test_collection")
|
||||
self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
|
||||
self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
|
||||
self.assertEqual(first_call[1]["body"]["id"], ids[0])
|
||||
|
||||
|
||||
# Check second call
|
||||
second_call = self.client_mock.index.call_args_list[1]
|
||||
self.assertEqual(second_call[1]["index"], "test_collection")
|
||||
self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
|
||||
self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
|
||||
self.assertEqual(second_call[1]["body"]["id"], ids[1])
|
||||
|
||||
|
||||
# Check results
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0].id, "id1")
|
||||
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
self.client_mock.search.return_value = {"hits": {"hits": []}}
|
||||
result = self.os_db.get("nonexistent")
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
def test_update(self):
|
||||
vector = [0.3] * 1536
|
||||
payload = {"key3": "value3"}
|
||||
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
self.assertEqual(result, ["test_collection"])
|
||||
|
||||
def test_search(self):
|
||||
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}}
|
||||
mock_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_id": "id1",
|
||||
"_score": 0.8,
|
||||
"_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
self.client_mock.search.return_value = mock_response
|
||||
vectors = [[0.1] * 1536]
|
||||
results = self.os_db.search(query="", vectors=vectors, limit=5)
|
||||
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
self.os_db.delete_col()
|
||||
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
|
||||
|
||||
|
||||
def test_init_with_http_auth(self):
|
||||
mock_credentials = MagicMock()
|
||||
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
|
||||
|
||||
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
|
||||
with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
|
||||
OpenSearchDB(
|
||||
host="localhost",
|
||||
port=9200,
|
||||
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
embedding_model_dims=1536,
|
||||
http_auth=mock_signer,
|
||||
verify_certs=True,
|
||||
use_ssl=True
|
||||
use_ssl=True,
|
||||
)
|
||||
|
||||
# Verify OpenSearch was initialized with correct params
|
||||
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=unittest.mock.ANY,
|
||||
pool_maxsize=20
|
||||
)
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user