Feature/fix opensearch vector mapping (#2399)
This commit is contained in:
@@ -14,6 +14,7 @@ class OpenSearchConfig(BaseModel):
|
||||
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
|
||||
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
|
||||
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
||||
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -23,7 +24,7 @@ class OpenSearchConfig(BaseModel):
|
||||
raise ValueError("Host must be provided for OpenSearch")
|
||||
|
||||
# Authentication: Either API key or user/password must be provided
|
||||
if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
|
||||
if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
|
||||
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
|
||||
|
||||
return values
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy import OpenSearch, RequestsHttpConnection
|
||||
from opensearchpy.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
|
||||
@@ -28,9 +28,10 @@ class OpenSearchDB(VectorStoreBase):
|
||||
# Initialize OpenSearch client
|
||||
self.client = OpenSearch(
|
||||
hosts=[{"host": config.host, "port": config.port or 9200}],
|
||||
http_auth=(config.user, config.password) if (config.user and config.password) else None,
|
||||
http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
|
||||
use_ssl=config.use_ssl,
|
||||
verify_certs=config.verify_certs,
|
||||
connection_class=RequestsHttpConnection
|
||||
)
|
||||
|
||||
self.collection_name = config.collection_name
|
||||
@@ -43,14 +44,17 @@ class OpenSearchDB(VectorStoreBase):
|
||||
def create_index(self) -> None:
|
||||
"""Create OpenSearch index with proper mappings if it doesn't exist."""
|
||||
index_settings = {
|
||||
# ToDo change replicas to 1
|
||||
"settings": {
|
||||
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"vector": {"type": "knn_vector", "dimension": self.vector_dim},
|
||||
"vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": self.vector_dim,
|
||||
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
|
||||
},
|
||||
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
import dotenv
|
||||
|
||||
try:
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy import OpenSearch, AWSV4SignerAuth
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
|
||||
@@ -148,3 +148,29 @@ class TestOpenSearchDB(unittest.TestCase):
|
||||
def test_delete_col(self):
|
||||
self.os_db.delete_col()
|
||||
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
|
||||
|
||||
|
||||
def test_init_with_http_auth(self):
|
||||
mock_credentials = MagicMock()
|
||||
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
|
||||
|
||||
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
|
||||
test_db = OpenSearchDB(
|
||||
host="localhost",
|
||||
port=9200,
|
||||
collection_name="test_collection",
|
||||
embedding_model_dims=1536,
|
||||
http_auth=mock_signer,
|
||||
verify_certs=True,
|
||||
use_ssl=True,
|
||||
auto_create_index=False
|
||||
)
|
||||
|
||||
# Verify OpenSearch was initialized with correct params
|
||||
mock_opensearch.assert_called_once_with(
|
||||
hosts=[{"host": "localhost", "port": 9200}],
|
||||
http_auth=mock_signer,
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=unittest.mock.ANY
|
||||
)
|
||||
Reference in New Issue
Block a user