diff --git a/docs/components/vectordbs/dbs/elasticsearch.mdx b/docs/components/vectordbs/dbs/elasticsearch.mdx index e3f19e5d..d8918112 100644 --- a/docs/components/vectordbs/dbs/elasticsearch.mdx +++ b/docs/components/vectordbs/dbs/elasticsearch.mdx @@ -54,6 +54,7 @@ Let's see the available parameters for the `elasticsearch` config: | `password` | Password for basic authentication | `None` | | `verify_certs` | Whether to verify SSL certificates | `True` | | `auto_create_index` | Whether to automatically create the index | `True` | +| `custom_search_query` | Function returning a custom search query | `None` | ### Features @@ -62,3 +63,46 @@ Let's see the available parameters for the `elasticsearch` config: - Multiple authentication methods (Basic Auth, API Key) - Automatic index creation with optimized mappings for vector search - Memory isolation through payload filtering +- Custom search query function to customize the search query + +### Custom Search Query + +The `custom_search_query` parameter allows you to customize the search query when `Memory.search` is called. + +__Example__ +```python +import os +from typing import List, Optional, Dict +from mem0 import Memory + +def custom_search_query(query: List[float], limit: int, filters: Optional[Dict]) -> Dict: + return { + "knn": { + "field": "vector", + "query_vector": query, + "k": limit, + "num_candidates": limit * 2 + } + } + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "elasticsearch", + "config": { + "collection_name": "mem0", + "host": "localhost", + "port": 9200, + "embedding_model_dims": 1536, + "custom_search_query": custom_search_query + } + } +} +``` +It should be a function that takes the following parameters: +- `query`: a query vector used in `Memory.search` +- `limit`: a number of results used in `Memory.search` +- `filters`: a dictionary of key-value pairs used in `Memory.search`. You can add custom pairs for the custom search query. + +The function should return a query body for the Elasticsearch search API. \ No newline at end of file diff --git a/mem0/configs/vector_stores/elasticsearch.py b/mem0/configs/vector_stores/elasticsearch.py index 0406dc9f..b0d3ef29 100644 --- a/mem0/configs/vector_stores/elasticsearch.py +++ b/mem0/configs/vector_stores/elasticsearch.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional +from collections.abc import Callable +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, model_validator @@ -15,6 +16,10 @@ class ElasticsearchConfig(BaseModel): verify_certs: bool = Field(True, description="Verify SSL certificates") use_ssl: bool = Field(True, description="Use SSL for connection") auto_create_index: bool = Field(True, description="Automatically create index during initialization") + custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field( + None, + description="Custom search query function. Parameters: (query, limit, filters) -> Dict" + ) @model_validator(mode="before") @classmethod diff --git a/mem0/vector_stores/elasticsearch.py b/mem0/vector_stores/elasticsearch.py index 0de5fb23..ce349740 100644 --- a/mem0/vector_stores/elasticsearch.py +++ b/mem0/vector_stores/elasticsearch.py @@ -45,6 +45,11 @@ class ElasticsearchDB(VectorStoreBase): # Create index only if auto_create_index is True if config.auto_create_index: self.create_index() + + if config.custom_search_query: + self.custom_search_query = config.custom_search_query + else: + self.custom_search_query = None def create_index(self) -> None: """Create Elasticsearch index with proper mappings if it doesn't exist""" @@ -117,25 +122,20 @@ class ElasticsearchDB(VectorStoreBase): return results def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: - """Search for similar vectors using KNN search with pre-filtering.""" - if not filters: - # If no filters, just do KNN search - search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}} + """ + Search with two options: + 1. Use custom search query if provided + 2. Use KNN search on vectors with pre-filtering if no custom search query is provided + """ + if self.custom_search_query: + search_query = self.custom_search_query(query, limit, filters) else: - # If filters exist, apply them with KNN search - filter_conditions = [] - for key, value in filters.items(): - filter_conditions.append({"term": {f"metadata.{key}": value}}) - - search_query = { - "knn": { - "field": "vector", - "query_vector": query, - "k": limit, - "num_candidates": limit * 2, - "filter": {"bool": {"must": filter_conditions}}, - } - } + search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}} + if filters: + filter_conditions = [] + for key, value in filters.items(): + filter_conditions.append({"term": {f"metadata.{key}": value}}) + search_query["filter"] = {"bool": {"must": filter_conditions}} response = self.client.search(index=self.collection_name, body=search_query) diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py index cf8e8a45..be8134e6 100644 --- a/tests/vector_stores/test_elasticsearch.py +++ b/tests/vector_stores/test_elasticsearch.py @@ -1,6 +1,6 @@ import os import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import dotenv @@ -220,6 +220,23 @@ class TestElasticsearchDB(unittest.TestCase): self.assertEqual(results[0].score, 0.8) self.assertEqual(results[0].payload, {"key1": "value1"}) + def test_custom_search_query(self): + # Mock custom search query + self.es_db.custom_search_query = Mock() + self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"} + + # Perform search + query_vector = [0.1] * 1536 + limit = 5 + filters = {"key1": "value1"} + self.es_db.search(query=query_vector, limit=limit, filters=filters) + + # Verify custom search query function was called + self.es_db.custom_search_query.assert_called_once_with(query_vector, limit, filters) + + # Verify custom search query was used + self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"}) + def test_get(self): # Mock get response with correct structure mock_response = {