Support Custom Search Query for Elasticsearch (#2372)
This commit is contained in:
@@ -54,6 +54,7 @@ Let's see the available parameters for the `elasticsearch` config:
|
|||||||
| `password` | Password for basic authentication | `None` |
|
| `password` | Password for basic authentication | `None` |
|
||||||
| `verify_certs` | Whether to verify SSL certificates | `True` |
|
| `verify_certs` | Whether to verify SSL certificates | `True` |
|
||||||
| `auto_create_index` | Whether to automatically create the index | `True` |
|
| `auto_create_index` | Whether to automatically create the index | `True` |
|
||||||
|
| `custom_search_query` | Function returning a custom search query | `None` |
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
|
|
||||||
@@ -62,3 +63,46 @@ Let's see the available parameters for the `elasticsearch` config:
|
|||||||
- Multiple authentication methods (Basic Auth, API Key)
|
- Multiple authentication methods (Basic Auth, API Key)
|
||||||
- Automatic index creation with optimized mappings for vector search
|
- Automatic index creation with optimized mappings for vector search
|
||||||
- Memory isolation through payload filtering
|
- Memory isolation through payload filtering
|
||||||
|
- Optional custom search query function that replaces the default KNN search query
|
||||||
|
|
||||||
|
### Custom Search Query
|
||||||
|
|
||||||
|
The `custom_search_query` parameter allows you to customize the search query when `Memory.search` is called.
|
||||||
|
|
||||||
|
__Example__
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
def custom_search_query(query: List[float], limit: int, filters: Optional[Dict]) -> Dict:
    """Build the Elasticsearch search body used instead of the default query.

    Args:
        query: Query embedding vector passed by `Memory.search`.
        limit: Number of results requested by `Memory.search`.
        filters: Key-value pairs passed by `Memory.search`, or None.

    Returns:
        A search request body containing a single `knn` clause.
    """
    knn = {
        "field": "vector",
        "query_vector": query,
        "k": limit,
        "num_candidates": limit * 2,
    }
    # A custom query replaces the built-in one entirely, so the filters have
    # to be applied here; otherwise results are no longer restricted by them.
    if filters:
        knn["filter"] = {
            "bool": {
                "must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
            }
        }
    return {"knn": knn}
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "elasticsearch",
|
||||||
|
"config": {
|
||||||
|
"collection_name": "mem0",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 9200,
|
||||||
|
"embedding_model_dims": 1536,
|
||||||
|
"custom_search_query": custom_search_query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
`custom_search_query` should be a function that takes the following parameters:
|
||||||
|
- `query`: a query vector used in `Memory.search`
|
||||||
|
- `limit`: the number of results requested in `Memory.search`
|
||||||
|
- `filters`: a dictionary of key-value pairs used in `Memory.search`, or `None` when no filters are given. You can add custom pairs for use by the custom search query.
|
||||||
|
|
||||||
|
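The function should return a query body for the Elasticsearch search API; it is sent as the `body` of the search request. When `custom_search_query` is set, the default query is not built, so `filters` are not applied automatically and the function must apply them itself if results should be restricted.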
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Any, Dict, Optional
|
from collections.abc import Callable
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
@@ -15,6 +16,10 @@ class ElasticsearchConfig(BaseModel):
|
|||||||
verify_certs: bool = Field(True, description="Verify SSL certificates")
|
verify_certs: bool = Field(True, description="Verify SSL certificates")
|
||||||
use_ssl: bool = Field(True, description="Use SSL for connection")
|
use_ssl: bool = Field(True, description="Use SSL for connection")
|
||||||
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
||||||
|
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
|
||||||
|
None,
|
||||||
|
description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ class ElasticsearchDB(VectorStoreBase):
|
|||||||
# Create index only if auto_create_index is True
|
# Create index only if auto_create_index is True
|
||||||
if config.auto_create_index:
|
if config.auto_create_index:
|
||||||
self.create_index()
|
self.create_index()
|
||||||
|
|
||||||
|
if config.custom_search_query:
|
||||||
|
self.custom_search_query = config.custom_search_query
|
||||||
|
else:
|
||||||
|
self.custom_search_query = None
|
||||||
|
|
||||||
def create_index(self) -> None:
|
def create_index(self) -> None:
|
||||||
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
|
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
|
||||||
@@ -117,25 +122,20 @@ class ElasticsearchDB(VectorStoreBase):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
||||||
"""Search for similar vectors using KNN search with pre-filtering."""
|
"""
|
||||||
if not filters:
|
Search with two options:
|
||||||
# If no filters, just do KNN search
|
1. Use custom search query if provided
|
||||||
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
|
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
|
||||||
|
"""
|
||||||
|
if self.custom_search_query:
|
||||||
|
search_query = self.custom_search_query(query, limit, filters)
|
||||||
else:
|
else:
|
||||||
# If filters exist, apply them with KNN search
|
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
|
||||||
filter_conditions = []
|
if filters:
|
||||||
for key, value in filters.items():
|
filter_conditions = []
|
||||||
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
for key, value in filters.items():
|
||||||
|
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
||||||
search_query = {
|
search_query["knn"]["filter"] = {"bool": {"must": filter_conditions}}
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"query_vector": query,
|
|
||||||
"k": limit,
|
|
||||||
"num_candidates": limit * 2,
|
|
||||||
"filter": {"bool": {"must": filter_conditions}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
response = self.client.search(index=self.collection_name, body=search_query)
|
response = self.client.search(index=self.collection_name, body=search_query)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
@@ -220,6 +220,23 @@ class TestElasticsearchDB(unittest.TestCase):
|
|||||||
self.assertEqual(results[0].score, 0.8)
|
self.assertEqual(results[0].score, 0.8)
|
||||||
self.assertEqual(results[0].payload, {"key1": "value1"})
|
self.assertEqual(results[0].payload, {"key1": "value1"})
|
||||||
|
|
||||||
|
def test_custom_search_query(self):
    """A custom query builder set on the store is called and its result is sent to the client."""
    # Replace the query builder with a mock that returns a recognizable body
    query_builder = Mock(return_value={"custom_key": "custom_value"})
    self.es_db.custom_search_query = query_builder

    # Run a search with a vector, a limit, and filters
    vector = [0.1] * 1536
    top_k = 5
    search_filters = {"key1": "value1"}
    self.es_db.search(query=vector, limit=top_k, filters=search_filters)

    # The builder must receive the search arguments positionally, exactly once
    self.es_db.custom_search_query.assert_called_once_with(vector, top_k, search_filters)

    # The body produced by the builder must be what reaches the client
    self.client_mock.search.assert_called_once_with(
        index=self.es_db.collection_name,
        body={"custom_key": "custom_value"},
    )
|
||||||
|
|
||||||
def test_get(self):
|
def test_get(self):
|
||||||
# Mock get response with correct structure
|
# Mock get response with correct structure
|
||||||
mock_response = {
|
mock_response = {
|
||||||
|
|||||||
Reference in New Issue
Block a user