Feature/fix opensearch vector mapping (#2399)

This commit is contained in:
Mauricio A
2025-03-20 00:07:57 -04:00
committed by GitHub
parent 6d5889d98f
commit 7b516328a8
3 changed files with 37 additions and 6 deletions

View File

@@ -14,6 +14,7 @@ class OpenSearchConfig(BaseModel):
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)") verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)") use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
auto_create_index: bool = Field(True, description="Automatically create index during initialization") auto_create_index: bool = Field(True, description="Automatically create index during initialization")
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -23,7 +24,7 @@ class OpenSearchConfig(BaseModel):
raise ValueError("Host must be provided for OpenSearch") raise ValueError("Host must be provided for OpenSearch")
# Authentication: Either API key or user/password must be provided # Authentication: Either API key or user/password must be provided
if not any([values.get("api_key"), (values.get("user") and values.get("password"))]): if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication") raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
return values return values

View File

@@ -2,7 +2,7 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
try: try:
from opensearchpy import OpenSearch from opensearchpy import OpenSearch, RequestsHttpConnection
from opensearchpy.helpers import bulk from opensearchpy.helpers import bulk
except ImportError: except ImportError:
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
@@ -28,9 +28,10 @@ class OpenSearchDB(VectorStoreBase):
# Initialize OpenSearch client # Initialize OpenSearch client
self.client = OpenSearch( self.client = OpenSearch(
hosts=[{"host": config.host, "port": config.port or 9200}], hosts=[{"host": config.host, "port": config.port or 9200}],
http_auth=(config.user, config.password) if (config.user and config.password) else None, http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
use_ssl=config.use_ssl, use_ssl=config.use_ssl,
verify_certs=config.verify_certs, verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection
) )
self.collection_name = config.collection_name self.collection_name = config.collection_name
@@ -43,14 +44,17 @@ class OpenSearchDB(VectorStoreBase):
def create_index(self) -> None: def create_index(self) -> None:
"""Create OpenSearch index with proper mappings if it doesn't exist.""" """Create OpenSearch index with proper mappings if it doesn't exist."""
index_settings = { index_settings = {
# ToDo change replicas to 1
"settings": { "settings": {
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True} "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
}, },
"mappings": { "mappings": {
"properties": { "properties": {
"text": {"type": "text"}, "text": {"type": "text"},
"vector": {"type": "knn_vector", "dimension": self.vector_dim}, "vector": {
"type": "knn_vector",
"dimension": self.vector_dim,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}}, "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
} }
}, },

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import dotenv import dotenv
try: try:
from opensearchpy import OpenSearch from opensearchpy import OpenSearch, AWSV4SignerAuth
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`" "OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
@@ -148,3 +148,29 @@ class TestOpenSearchDB(unittest.TestCase):
def test_delete_col(self): def test_delete_col(self):
self.os_db.delete_col() self.os_db.delete_col()
self.client_mock.indices.delete.assert_called_once_with(index="test_collection") self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
def test_init_with_http_auth(self):
mock_credentials = MagicMock()
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
test_db = OpenSearchDB(
host="localhost",
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
http_auth=mock_signer,
verify_certs=True,
use_ssl=True,
auto_create_index=False
)
# Verify OpenSearch was initialized with correct params
mock_opensearch.assert_called_once_with(
hosts=[{"host": "localhost", "port": 9200}],
http_auth=mock_signer,
use_ssl=True,
verify_certs=True,
connection_class=unittest.mock.ANY
)