From f4c0f98fde88f4dff22e16f010bc20086e452bfb Mon Sep 17 00:00:00 2001
From: Seetha Rama Guptha
Date: Thu, 20 Feb 2025 11:42:12 +0530
Subject: [PATCH] Adding Native OpenSearch support for Mem0 (#2211)

---
 Makefile                                     |   2 +-
 docs/components/vectordbs/dbs/opensearch.mdx |  59 ++++++
 docs/components/vectordbs/overview.mdx       |   1 +
 docs/docs.json                               |   3 +-
 mem0/configs/vector_stores/opensearch.py     |  42 +++++
 mem0/utils/factory.py                        |   1 +
 mem0/vector_stores/configs.py                |   1 +
 mem0/vector_stores/opensearch.py             | 189 +++++++++++++++++++
 tests/vector_stores/test_opensearch.py       | 150 +++++++++++++++
 9 files changed, 446 insertions(+), 2 deletions(-)
 create mode 100644 docs/components/vectordbs/dbs/opensearch.mdx
 create mode 100644 mem0/configs/vector_stores/opensearch.py
 create mode 100644 mem0/vector_stores/opensearch.py
 create mode 100644 tests/vector_stores/test_opensearch.py

diff --git a/Makefile b/Makefile
index 31b69c50..5acac2c6 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,7 @@ install:
 install_all:
 	poetry install
 	poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
-		google-generativeai elasticsearch
+		google-generativeai elasticsearch opensearch-py

 # Format code with ruff
 format:

diff --git a/docs/components/vectordbs/dbs/opensearch.mdx b/docs/components/vectordbs/dbs/opensearch.mdx
new file mode 100644
index 00000000..507d0861
--- /dev/null
+++ b/docs/components/vectordbs/dbs/opensearch.mdx
@@ -0,0 +1,59 @@
[OpenSearch](https://opensearch.org/) is an open-source, enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) search and allows you to store and retrieve high-dimensional vector embeddings efficiently.

### Installation

OpenSearch support requires additional dependencies. Install them with:

```bash
pip install "opensearch-py>=2.8.0"
```

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "opensearch",
        "config": {
            "collection_name": "mem0",
            "host": "localhost",
            "port": 9200,
            "embedding_model_dims": 1536,
            "user": "admin",
            "password": "admin"
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

Note that the config validator requires authentication: supply either an `api_key` or a `user`/`password` pair.

### Config

Let's see the available parameters for the `opensearch` config:

| Parameter              | Description                                     | Default Value |
| ---------------------- | ----------------------------------------------- | ------------- |
| `collection_name`      | The name of the index to store the vectors     | `mem0`        |
| `embedding_model_dims` | Dimensions of the embedding model              | `1536`        |
| `host`                 | The host where the OpenSearch server is running | `localhost`   |
| `port`                 | The port where the OpenSearch server is running | `9200`        |
| `api_key`              | API key for authentication                     | `None`        |
| `user`                 | Username for basic authentication              | `None`        |
| `password`             | Password for basic authentication              | `None`        |
| `verify_certs`         | Whether to verify SSL certificates             | `False`       |
| `auto_create_index`    | Whether to automatically create the index      | `True`        |
| `use_ssl`              | Whether to use SSL for the connection          | `False`       |

### Features

- Fast and efficient vector search
- Can be deployed on-premises, in containers, or on cloud platforms such as AWS OpenSearch Service
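- Multiple authentication and security methods (basic authentication, API keys, LDAP, SAML, and OpenID Connect)
- Automatic index creation with optimized mappings for vector search
- Memory optimization through disk-based vector search and quantization
- Real-time analytics and observability

As a companion to the Usage snippet, here is a minimal retrieval sketch. It reuses the `m` instance from above; `search` and `get_all` are the standard `Memory` retrieval methods, and the query text is only illustrative:

```python
# Continuing the Usage example: read memories back from OpenSearch.
related = m.search("What does Alice do on weekends?", user_id="alice")

# Or list everything stored for this user.
all_memories = m.get_all(user_id="alice")
```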
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index 61725466..b19395ac 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -18,6 +18,7 @@ See the list of supported vector databases below.
+
 ## Usage

diff --git a/docs/docs.json b/docs/docs.json
index a8420812..3ea7ea7c 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -122,7 +122,8 @@
         "components/vectordbs/dbs/milvus",
         "components/vectordbs/dbs/azure_ai_search",
         "components/vectordbs/dbs/redis",
-        "components/vectordbs/dbs/elasticsearch"
+        "components/vectordbs/dbs/elasticsearch",
+        "components/vectordbs/dbs/opensearch"
       ]
     }
   ]

diff --git a/mem0/configs/vector_stores/opensearch.py b/mem0/configs/vector_stores/opensearch.py
new file mode 100644
index 00000000..082416d2
--- /dev/null
+++ b/mem0/configs/vector_stores/opensearch.py
@@ -0,0 +1,42 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class OpenSearchConfig(BaseModel):
    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Reject an explicitly empty host; an omitted host falls back to the
        # field default ("localhost"), which this mode="before" hook never sees.
        if "host" in values and not values["host"]:
            raise ValueError("Host must be provided for OpenSearch")

        # Authentication: either an API key or a user/password pair must be provided
        if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
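As a quick illustration of the two validators above, consider this sketch. It assumes pydantic v2, where `ValidationError` derives from `ValueError`, so a plain `except ValueError` catches the failures:

```python
from mem0.configs.vector_stores.opensearch import OpenSearchConfig

# Accepted: a user/password pair satisfies validate_auth.
cfg = OpenSearchConfig(host="localhost", port=9200, user="admin", password="secret")

# Rejected: neither api_key nor a user/password pair was supplied.
try:
    OpenSearchConfig(host="localhost", port=9200)
except ValueError as err:
    print(err)  # "Either api_key or user/password must be provided ..."

# Rejected: validate_extra_fields refuses unknown keys such as "timeout".
try:
    OpenSearchConfig(host="localhost", user="admin", password="secret", timeout=30)
except ValueError as err:
    print(err)  # "Extra fields not allowed: timeout. ..."
```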
" + f"Allowed fields: {', '.join(allowed_fields)}" + ) + return values diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 4ff9d15e..408d2c19 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -68,6 +68,7 @@ class VectorStoreFactory: "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch", "redis": "mem0.vector_stores.redis.RedisDB", "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB", + "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB" } @classmethod diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 4a020f50..b6d1c86a 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -18,6 +18,7 @@ class VectorStoreConfig(BaseModel): "azure_ai_search": "AzureAISearchConfig", "redis": "RedisDBConfig", "elasticsearch": "ElasticsearchConfig", + "opensearch": "OpenSearchConfig", } @model_validator(mode="after") diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py new file mode 100644 index 00000000..76b73dee --- /dev/null +++ b/mem0/vector_stores/opensearch.py @@ -0,0 +1,189 @@ +import logging +from typing import Any, Dict, List, Optional + +try: + from opensearchpy import OpenSearch + from opensearchpy.helpers import bulk +except ImportError: + raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None + +from pydantic import BaseModel + +from mem0.configs.vector_stores.opensearch import OpenSearchConfig +from mem0.vector_stores.base import VectorStoreBase + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: str + score: float + payload: Dict + + +class OpenSearchDB(VectorStoreBase): + def __init__(self, **kwargs): + config = OpenSearchConfig(**kwargs) + + # Initialize OpenSearch client + self.client = OpenSearch( + hosts=[{"host": config.host, "port": config.port or 9200}], + http_auth=(config.user, config.password) if (config.user and config.password) else None, + use_ssl=config.use_ssl, + verify_certs=config.verify_certs, + ) + + self.collection_name = config.collection_name + self.vector_dim = config.embedding_model_dims + + # Create index only if auto_create_index is True + if config.auto_create_index: + self.create_index() + + def create_index(self) -> None: + """Create OpenSearch index with proper mappings if it doesn't exist.""" + index_settings = { + # ToDo change replicas to 1 + "settings": { + "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True} + }, + "mappings": { + "properties": { + "text": {"type": "text"}, + "vector": { + "type": "knn_vector", + "dimension": self.vector_dim + }, + "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}}, + } + }, + } + + if not self.client.indices.exists(index=self.collection_name): + self.client.indices.create(index=self.collection_name, body=index_settings) + logger.info(f"Created index {self.collection_name}") + else: + logger.info(f"Index {self.collection_name} already exists") + + def create_col(self, name: str, vector_size: int) -> None: + """Create a new collection (index in OpenSearch).""" + index_settings = { + "mappings": { + "properties": { + "vector": { + "type": "knn_vector", + "dimension": vector_size, + "method": { "engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"}, + }, + "payload": {"type": "object"}, + "id": {"type": "keyword"}, + } + } + } + + if not self.client.indices.exists(index=name): + 
diff --git a/tests/vector_stores/test_opensearch.py b/tests/vector_stores/test_opensearch.py
new file mode 100644
index 00000000..40d0959b
--- /dev/null
+++ b/tests/vector_stores/test_opensearch.py
@@ -0,0 +1,150 @@
import os
import unittest
from unittest.mock import MagicMock, patch

import dotenv

try:
    from opensearchpy import OpenSearch
except ImportError:
    raise ImportError(
        "OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
    ) from None

from mem0.vector_stores.opensearch import OpenSearchDB


class TestOpenSearchDB(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        dotenv.load_dotenv()
        cls.original_env = {
            'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
            'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
            'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
        }
        os.environ['OS_URL'] = 'http://localhost'
        os.environ['OS_USERNAME'] = 'test_user'
        os.environ['OS_PASSWORD'] = 'test_password'

    def setUp(self):
        self.client_mock = MagicMock(spec=OpenSearch)
        self.client_mock.indices = MagicMock()
        self.client_mock.indices.exists = MagicMock(return_value=False)
        self.client_mock.indices.create = MagicMock()
        self.client_mock.indices.delete = MagicMock()
        self.client_mock.indices.get_alias = MagicMock()
        self.client_mock.get = MagicMock()
        self.client_mock.update = MagicMock()
        self.client_mock.delete = MagicMock()
        self.client_mock.search = MagicMock()

        patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
        self.mock_os = patcher.start()
        self.addCleanup(patcher.stop)

        self.os_db = OpenSearchDB(
            host=os.getenv('OS_URL'),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
            user=os.getenv('OS_USERNAME'),
            password=os.getenv('OS_PASSWORD'),
            verify_certs=False,
            use_ssl=False,
            auto_create_index=False
        )
        self.client_mock.reset_mock()

    @classmethod
    def tearDownClass(cls):
        for key, value in cls.original_env.items():
            if value is not None:
                os.environ[key] = value
            else:
                os.environ.pop(key, None)

    def tearDown(self):
        self.client_mock.reset_mock()

    def test_create_index(self):
        self.client_mock.indices.exists.return_value = False
        self.os_db.create_index()
        self.client_mock.indices.create.assert_called_once()
        create_args = self.client_mock.indices.create.call_args[1]
        self.assertEqual(create_args["index"], "test_collection")
        mappings = create_args["body"]["mappings"]["properties"]
        self.assertEqual(mappings["vector"]["type"], "knn_vector")
        self.assertEqual(mappings["vector"]["dimension"], 1536)

        self.client_mock.reset_mock()
        self.client_mock.indices.exists.return_value = True
        self.os_db.create_index()
        self.client_mock.indices.create.assert_not_called()

    def test_insert(self):
        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = ["id1", "id2"]

        with patch('mem0.vector_stores.opensearch.bulk') as mock_bulk:
            mock_bulk.return_value = (2, [])
            results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)

            mock_bulk.assert_called_once()
            actions = mock_bulk.call_args[0][1]
            self.assertEqual(actions[0]["_index"], "test_collection")
            self.assertEqual(actions[0]["_id"], "id1")
            self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
            self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])

            self.assertEqual(len(results), 2)
            self.assertEqual(results[0].id, "id1")
            self.assertEqual(results[0].payload, payloads[0])

    def test_get(self):
        mock_response = {"_id": "id1", "_source": {"metadata": {"key1": "value1"}}}
        self.client_mock.get.return_value = mock_response

        result = self.os_db.get("id1")

        self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")
        self.assertIsNotNone(result)
        self.assertEqual(result.id, "id1")
        self.assertEqual(result.payload, {"key1": "value1"})
{"key1": "value1"}}} + self.client_mock.get.return_value = mock_response + result = self.os_db.get("id1") + self.client_mock.get.assert_called_once_with(index="test_collection", id="id1") + self.assertIsNotNone(result) + self.assertEqual(result.id, "id1") + self.assertEqual(result.payload, {"key1": "value1"}) + + def test_update(self): + vector = [0.3] * 1536 + payload = {"key3": "value3"} + self.os_db.update("id1", vector=vector, payload=payload) + self.client_mock.update.assert_called_once() + update_args = self.client_mock.update.call_args[1] + self.assertEqual(update_args["index"], "test_collection") + self.assertEqual(update_args["id"], "id1") + self.assertEqual(update_args["body"], {"doc": {"vector": vector, "metadata": payload}}) + + def test_list_cols(self): + self.client_mock.indices.get_alias.return_value = {"test_collection": {}} + result = self.os_db.list_cols() + self.client_mock.indices.get_alias.assert_called_once() + self.assertEqual(result, ["test_collection"]) + + def test_search(self): + mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}]}} + self.client_mock.search.return_value = mock_response + query_vector = [0.1] * 1536 + results = self.os_db.search(query=query_vector, limit=5) + self.client_mock.search.assert_called_once() + search_args = self.client_mock.search.call_args[1] + self.assertEqual(search_args["index"], "test_collection") + body = search_args["body"] + self.assertIn("knn", body["query"]) + self.assertIn("vector", body["query"]["knn"]) + self.assertEqual(body["query"]["knn"]["vector"]["vector"], query_vector) + self.assertEqual(body["query"]["knn"]["vector"]["k"], 5) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].id, "id1") + self.assertEqual(results[0].score, 0.8) + self.assertEqual(results[0].payload, {"key1": "value1"}) + + def test_delete(self): + self.os_db.delete(vector_id="id1") + self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1") + + def test_delete_col(self): + self.os_db.delete_col() + self.client_mock.indices.delete.assert_called_once_with(index="test_collection")