Feature/fix opensearch vector mapping (#2399)
This commit is contained in:
@@ -14,6 +14,7 @@ class OpenSearchConfig(BaseModel):
|
|||||||
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
|
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
|
||||||
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
|
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
|
||||||
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
||||||
|
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -23,7 +24,7 @@ class OpenSearchConfig(BaseModel):
|
|||||||
raise ValueError("Host must be provided for OpenSearch")
|
raise ValueError("Host must be provided for OpenSearch")
|
||||||
|
|
||||||
# Authentication: Either API key or user/password must be provided
|
# Authentication: Either API key or user/password must be provided
|
||||||
if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
|
if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
|
||||||
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
|
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from opensearchpy import OpenSearch
|
from opensearchpy import OpenSearch, RequestsHttpConnection
|
||||||
from opensearchpy.helpers import bulk
|
from opensearchpy.helpers import bulk
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
|
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
|
||||||
@@ -28,9 +28,10 @@ class OpenSearchDB(VectorStoreBase):
|
|||||||
# Initialize OpenSearch client
|
# Initialize OpenSearch client
|
||||||
self.client = OpenSearch(
|
self.client = OpenSearch(
|
||||||
hosts=[{"host": config.host, "port": config.port or 9200}],
|
hosts=[{"host": config.host, "port": config.port or 9200}],
|
||||||
http_auth=(config.user, config.password) if (config.user and config.password) else None,
|
http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
|
||||||
use_ssl=config.use_ssl,
|
use_ssl=config.use_ssl,
|
||||||
verify_certs=config.verify_certs,
|
verify_certs=config.verify_certs,
|
||||||
|
connection_class=RequestsHttpConnection
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collection_name = config.collection_name
|
self.collection_name = config.collection_name
|
||||||
@@ -43,14 +44,17 @@ class OpenSearchDB(VectorStoreBase):
|
|||||||
def create_index(self) -> None:
|
def create_index(self) -> None:
|
||||||
"""Create OpenSearch index with proper mappings if it doesn't exist."""
|
"""Create OpenSearch index with proper mappings if it doesn't exist."""
|
||||||
index_settings = {
|
index_settings = {
|
||||||
# ToDo change replicas to 1
|
|
||||||
"settings": {
|
"settings": {
|
||||||
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
|
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
|
||||||
},
|
},
|
||||||
"mappings": {
|
"mappings": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"text": {"type": "text"},
|
"text": {"type": "text"},
|
||||||
"vector": {"type": "knn_vector", "dimension": self.vector_dim},
|
"vector": {
|
||||||
|
"type": "knn_vector",
|
||||||
|
"dimension": self.vector_dim,
|
||||||
|
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
|
||||||
|
},
|
||||||
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
|
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from opensearchpy import OpenSearch
|
from opensearchpy import OpenSearch, AWSV4SignerAuth
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
|
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
|
||||||
@@ -148,3 +148,29 @@ class TestOpenSearchDB(unittest.TestCase):
|
|||||||
def test_delete_col(self):
|
def test_delete_col(self):
|
||||||
self.os_db.delete_col()
|
self.os_db.delete_col()
|
||||||
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
|
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_with_http_auth(self):
|
||||||
|
mock_credentials = MagicMock()
|
||||||
|
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
|
||||||
|
|
||||||
|
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
|
||||||
|
test_db = OpenSearchDB(
|
||||||
|
host="localhost",
|
||||||
|
port=9200,
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding_model_dims=1536,
|
||||||
|
http_auth=mock_signer,
|
||||||
|
verify_certs=True,
|
||||||
|
use_ssl=True,
|
||||||
|
auto_create_index=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify OpenSearch was initialized with correct params
|
||||||
|
mock_opensearch.assert_called_once_with(
|
||||||
|
hosts=[{"host": "localhost", "port": 9200}],
|
||||||
|
http_auth=mock_signer,
|
||||||
|
use_ssl=True,
|
||||||
|
verify_certs=True,
|
||||||
|
connection_class=unittest.mock.ANY
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user