Support Custom Search Query for Elasticsearch (#2372)
This commit is contained in:
@@ -54,6 +54,7 @@ Let's see the available parameters for the `elasticsearch` config:
|
||||
| `password` | Password for basic authentication | `None` |
|
||||
| `verify_certs` | Whether to verify SSL certificates | `True` |
|
||||
| `auto_create_index` | Whether to automatically create the index | `True` |
|
||||
| `custom_search_query` | Function returning a custom search query | `None` |
|
||||
|
||||
### Features
|
||||
|
||||
@@ -62,3 +63,46 @@ Let's see the available parameters for the `elasticsearch` config:
|
||||
- Multiple authentication methods (Basic Auth, API Key)
|
||||
- Automatic index creation with optimized mappings for vector search
|
||||
- Memory isolation through payload filtering
|
||||
- Optional custom search query function that replaces the default KNN search query
|
||||
|
||||
### Custom Search Query
|
||||
|
||||
The `custom_search_query` parameter allows you to customize the search query when `Memory.search` is called.
|
||||
|
||||
__Example__
|
||||
```python
|
||||
import os
|
||||
from typing import List, Optional, Dict
|
||||
from mem0 import Memory
|
||||
|
||||
def custom_search_query(query: List[float], limit: int, filters: Optional[Dict]) -> Dict:
    """Build the Elasticsearch search body used in place of the default query.

    Args:
        query: Query vector passed to the search.
        limit: Number of results requested.
        filters: Optional key-value pairs passed to the search; may be None.

    Returns:
        A search body containing a single ``knn`` clause.
    """
    knn = {
        "field": "vector",
        "query_vector": query,
        "k": limit,
        "num_candidates": limit * 2,
    }
    if filters:
        # Keep the caller's filters as a kNN pre-filter so a custom query
        # does not silently drop result scoping.
        knn["filter"] = {
            "bool": {
                "must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
            }
        }
    return {"knn": knn}
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "elasticsearch",
|
||||
"config": {
|
||||
"collection_name": "mem0",
|
||||
"host": "localhost",
|
||||
"port": 9200,
|
||||
"embedding_model_dims": 1536,
|
||||
"custom_search_query": custom_search_query
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
`custom_search_query` must be a function that takes the following parameters:
|
||||
- `query`: the query vector (a list of floats) that `Memory.search` passes to the vector store
|
||||
- `limit`: the number of results requested in `Memory.search`
|
||||
- `filters`: a dictionary of key-value pairs used in `Memory.search`. You can add custom pairs for the custom search query.
|
||||
|
||||
The function should return a query body for the Elasticsearch search API.
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -15,6 +16,10 @@ class ElasticsearchConfig(BaseModel):
|
||||
verify_certs: bool = Field(True, description="Verify SSL certificates")
|
||||
use_ssl: bool = Field(True, description="Use SSL for connection")
|
||||
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
||||
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
|
||||
None,
|
||||
description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -45,6 +45,11 @@ class ElasticsearchDB(VectorStoreBase):
|
||||
# Create index only if auto_create_index is True
|
||||
if config.auto_create_index:
|
||||
self.create_index()
|
||||
|
||||
if config.custom_search_query:
|
||||
self.custom_search_query = config.custom_search_query
|
||||
else:
|
||||
self.custom_search_query = None
|
||||
|
||||
def create_index(self) -> None:
|
||||
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
|
||||
@@ -117,25 +122,20 @@ class ElasticsearchDB(VectorStoreBase):
|
||||
return results
|
||||
|
||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
||||
"""Search for similar vectors using KNN search with pre-filtering."""
|
||||
if not filters:
|
||||
# If no filters, just do KNN search
|
||||
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
|
||||
"""
|
||||
Search with two options:
|
||||
1. Use custom search query if provided
|
||||
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
|
||||
"""
|
||||
if self.custom_search_query:
|
||||
search_query = self.custom_search_query(query, limit, filters)
|
||||
else:
|
||||
# If filters exist, apply them with KNN search
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
||||
|
||||
search_query = {
|
||||
"knn": {
|
||||
"field": "vector",
|
||||
"query_vector": query,
|
||||
"k": limit,
|
||||
"num_candidates": limit * 2,
|
||||
"filter": {"bool": {"must": filter_conditions}},
|
||||
}
|
||||
}
|
||||
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
|
||||
if filters:
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
||||
# Pre-filters belong inside the "knn" clause (knn.filter), not at the top level of the search body.
search_query["knn"]["filter"] = {"bool": {"must": filter_conditions}}
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=search_query)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import dotenv
|
||||
|
||||
@@ -220,6 +220,23 @@ class TestElasticsearchDB(unittest.TestCase):
|
||||
self.assertEqual(results[0].score, 0.8)
|
||||
self.assertEqual(results[0].payload, {"key1": "value1"})
|
||||
|
||||
def test_custom_search_query(self):
    """search() delegates query construction to custom_search_query when it is set."""
    # Swap in a mock hook that returns a recognisable body
    self.es_db.custom_search_query = Mock(return_value={"custom_key": "custom_value"})

    # Run a search with explicit arguments
    vector = [0.1] * 1536
    max_results = 5
    search_filters = {"key1": "value1"}
    self.es_db.search(query=vector, limit=max_results, filters=search_filters)

    # The hook must receive the search arguments positionally, exactly once
    self.es_db.custom_search_query.assert_called_once_with(vector, max_results, search_filters)

    # The body returned by the hook must be what is sent to the client
    self.client_mock.search.assert_called_once_with(
        index=self.es_db.collection_name,
        body={"custom_key": "custom_value"},
    )
|
||||
|
||||
def test_get(self):
|
||||
# Mock get response with correct structure
|
||||
mock_response = {
|
||||
|
||||
Reference in New Issue
Block a user