Support Custom Search Query for Elasticsearch (#2372)

This commit is contained in:
Wonbin Kim
2025-03-18 14:04:34 +09:00
committed by GitHub
parent 00a2ea9ff0
commit b8f40f728f
4 changed files with 86 additions and 20 deletions

View File

@@ -54,6 +54,7 @@ Let's see the available parameters for the `elasticsearch` config:
| `password` | Password for basic authentication | `None` |
| `verify_certs` | Whether to verify SSL certificates | `True` |
| `auto_create_index` | Whether to automatically create the index | `True` |
| `custom_search_query` | Function returning a custom search query | `None` |
### Features
@@ -62,3 +63,46 @@ Let's see the available parameters for the `elasticsearch` config:
- Multiple authentication methods (Basic Auth, API Key)
- Automatic index creation with optimized mappings for vector search
- Memory isolation through payload filtering
- Custom search query function to customize the search query
### Custom Search Query
The `custom_search_query` parameter allows you to customize the search query when `Memory.search` is called.
__Example__
```python
import os
from typing import List, Optional, Dict
from mem0 import Memory
def custom_search_query(query: List[float], limit: int, filters: Optional[Dict]) -> Dict:
return {
"knn": {
"field": "vector",
"query_vector": query,
"k": limit,
"num_candidates": limit * 2
}
}
os.environ["OPENAI_API_KEY"] = "sk-xx"
config = {
"vector_store": {
"provider": "elasticsearch",
"config": {
"collection_name": "mem0",
"host": "localhost",
"port": 9200,
"embedding_model_dims": 1536,
"custom_search_query": custom_search_query
}
}
}
```
It should be a function that takes the following parameters:
- `query`: a query vector used in `Memory.search`
- `limit`: a number of results used in `Memory.search`
- `filters`: a dictionary of key-value pairs used in `Memory.search`. You can add custom pairs for the custom search query.
The function should return a query body for the Elasticsearch search API.

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
from collections.abc import Callable
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
@@ -15,6 +16,10 @@ class ElasticsearchConfig(BaseModel):
verify_certs: bool = Field(True, description="Verify SSL certificates")
use_ssl: bool = Field(True, description="Use SSL for connection")
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
None,
description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
)
@model_validator(mode="before")
@classmethod

View File

@@ -45,6 +45,11 @@ class ElasticsearchDB(VectorStoreBase):
# Create index only if auto_create_index is True
if config.auto_create_index:
self.create_index()
if config.custom_search_query:
self.custom_search_query = config.custom_search_query
else:
self.custom_search_query = None
def create_index(self) -> None:
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
@@ -117,25 +122,20 @@ class ElasticsearchDB(VectorStoreBase):
return results
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
"""Search for similar vectors using KNN search with pre-filtering."""
if not filters:
# If no filters, just do KNN search
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
"""
Search with two options:
1. Use custom search query if provided
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
"""
if self.custom_search_query:
search_query = self.custom_search_query(query, limit, filters)
else:
# If filters exist, apply them with KNN search
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"term": {f"metadata.{key}": value}})
search_query = {
"knn": {
"field": "vector",
"query_vector": query,
"k": limit,
"num_candidates": limit * 2,
"filter": {"bool": {"must": filter_conditions}},
}
}
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
if filters:
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"term": {f"metadata.{key}": value}})
search_query["filter"] = {"bool": {"must": filter_conditions}}
response = self.client.search(index=self.collection_name, body=search_query)

View File

@@ -1,6 +1,6 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import dotenv
@@ -220,6 +220,23 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(results[0].score, 0.8)
self.assertEqual(results[0].payload, {"key1": "value1"})
def test_custom_search_query(self):
# Mock custom search query
self.es_db.custom_search_query = Mock()
self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"}
# Perform search
query_vector = [0.1] * 1536
limit = 5
filters = {"key1": "value1"}
self.es_db.search(query=query_vector, limit=limit, filters=filters)
# Verify custom search query function was called
self.es_db.custom_search_query.assert_called_once_with(query_vector, limit, filters)
# Verify custom search query was used
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
def test_get(self):
# Mock get response with correct structure
mock_response = {