Support for hybrid search in Azure AI vector store (#2408)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -50,6 +50,24 @@ config = {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Using hybrid search
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "azure_ai_search",
|
||||||
|
"config": {
|
||||||
|
"service_name": "ai-search-test",
|
||||||
|
"api_key": "*****",
|
||||||
|
"collection_name": "mem0",
|
||||||
|
"embedding_model_dims": 1536,
|
||||||
|
"hybrid_search": True,
|
||||||
|
"vector_filter_mode": "postFilter"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Configuration Parameters
|
## Configuration Parameters
|
||||||
|
|
||||||
| Parameter | Description | Default Value | Options |
|
| Parameter | Description | Default Value | Options |
|
||||||
@@ -60,6 +78,8 @@ config = {
|
|||||||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
|
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
|
||||||
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
|
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
|
||||||
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
|
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
|
||||||
|
| `vector_filter_mode` | Vector filter mode to use | `preFilter` | `postFilter`, `preFilter` |
|
||||||
|
| `hybrid_search` | Use hybrid search | `False` | `True`, `False` |
|
||||||
|
|
||||||
## Notes on Configuration Options
|
## Notes on Configuration Options
|
||||||
|
|
||||||
@@ -68,6 +88,10 @@ config = {
|
|||||||
- `scalar`: Scalar quantization with reasonable balance of speed and accuracy
|
- `scalar`: Scalar quantization with reasonable balance of speed and accuracy
|
||||||
- `binary`: Binary quantization for maximum compression with some accuracy trade-off
|
- `binary`: Binary quantization for maximum compression with some accuracy trade-off
|
||||||
|
|
||||||
|
- **vector_filter_mode**:
|
||||||
|
- `preFilter`: Applies filters before vector search (faster)
|
||||||
|
- `postFilter`: Applies filters after vector search (may provide better relevance)
|
||||||
|
|
||||||
- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
|
- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
|
||||||
|
|
||||||
- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
|
- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
|
||||||
@@ -8,21 +8,26 @@ class AzureAISearchConfig(BaseModel):
|
|||||||
api_key: str = Field(None, description="API key for the Azure AI Search service")
|
api_key: str = Field(None, description="API key for the Azure AI Search service")
|
||||||
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
|
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
|
||||||
compression_type: Optional[str] = Field(
|
compression_type: Optional[str] = Field(
|
||||||
None,
|
None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
|
||||||
description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
|
|
||||||
)
|
)
|
||||||
use_float16: bool = Field(
|
use_float16: bool = Field(
|
||||||
False,
|
False,
|
||||||
description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)"
|
description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
|
||||||
)
|
)
|
||||||
|
hybrid_search: bool = Field(
|
||||||
|
False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
|
||||||
|
)
|
||||||
|
vector_filter_mode: Optional[str] = Field(
|
||||||
|
"preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
allowed_fields = set(cls.model_fields.keys())
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
input_fields = set(values.keys())
|
input_fields = set(values.keys())
|
||||||
extra_fields = input_fields - allowed_fields
|
extra_fields = input_fields - allowed_fields
|
||||||
|
|
||||||
# Check for use_compression to provide a helpful error
|
# Check for use_compression to provide a helpful error
|
||||||
if "use_compression" in extra_fields:
|
if "use_compression" in extra_fields:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -30,13 +35,13 @@ class AzureAISearchConfig(BaseModel):
|
|||||||
"Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
|
"Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
|
||||||
"or 'compression_type=None' instead of 'use_compression=False'."
|
"or 'compression_type=None' instead of 'use_compression=False'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if extra_fields:
|
if extra_fields:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Extra fields not allowed: {', '.join(extra_fields)}. "
|
f"Extra fields not allowed: {', '.join(extra_fields)}. "
|
||||||
f"Please input only the following fields: {', '.join(allowed_fields)}"
|
f"Please input only the following fields: {', '.join(allowed_fields)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate compression_type values
|
# Validate compression_type values
|
||||||
if "compression_type" in values and values["compression_type"] is not None:
|
if "compression_type" in values and values["compression_type"] is not None:
|
||||||
valid_types = ["scalar", "binary"]
|
valid_types = ["scalar", "binary"]
|
||||||
@@ -45,9 +50,9 @@ class AzureAISearchConfig(BaseModel):
|
|||||||
f"Invalid compression_type: {values['compression_type']}. "
|
f"Invalid compression_type: {values['compression_type']}. "
|
||||||
f"Must be one of: {', '.join(valid_types)}, or None"
|
f"Must be one of: {', '.join(valid_types)}, or None"
|
||||||
)
|
)
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"arbitrary_types_allowed": True,
|
"arbitrary_types_allowed": True,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ class ElasticsearchConfig(BaseModel):
|
|||||||
use_ssl: bool = Field(True, description="Use SSL for connection")
|
use_ssl: bool = Field(True, description="Use SSL for connection")
|
||||||
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
|
||||||
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
|
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
|
||||||
None,
|
None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
|
||||||
description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ class GoogleMatchingEngineConfig(BaseModel):
|
|||||||
credentials_path: Optional[str] = Field(None, description="Path to service account credentials file")
|
credentials_path: Optional[str] = Field(None, description="Path to service account credentials file")
|
||||||
vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")
|
vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")
|
||||||
|
|
||||||
model_config = {
|
model_config = {"extra": "forbid"}
|
||||||
"extra": "forbid"
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -26,4 +24,4 @@ class GoogleMatchingEngineConfig(BaseModel):
|
|||||||
def model_post_init(self, _context) -> None:
|
def model_post_init(self, _context) -> None:
|
||||||
"""Set collection_name to index_id if not provided"""
|
"""Set collection_name to index_id if not provided"""
|
||||||
if self.collection_name is None:
|
if self.collection_name is None:
|
||||||
self.collection_name = self.index_id
|
self.collection_name = self.index_id
|
||||||
|
|||||||
@@ -71,13 +71,14 @@ class Memory(MemoryBase):
|
|||||||
if "vector_store" not in config_dict and "embedder" in config_dict:
|
if "vector_store" not in config_dict and "embedder" in config_dict:
|
||||||
config_dict["vector_store"] = {}
|
config_dict["vector_store"] = {}
|
||||||
config_dict["vector_store"]["config"] = {}
|
config_dict["vector_store"]["config"] = {}
|
||||||
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"]["embedding_dims"]
|
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
|
||||||
|
"embedding_dims"
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
return config_dict
|
return config_dict
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.error(f"Configuration validation error: {e}")
|
logger.error(f"Configuration validation error: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
@@ -204,7 +205,8 @@ class Memory(MemoryBase):
|
|||||||
messages_embeddings = self.embedding_model.embed(new_mem, "add")
|
messages_embeddings = self.embedding_model.embed(new_mem, "add")
|
||||||
new_message_embeddings[new_mem] = messages_embeddings
|
new_message_embeddings[new_mem] = messages_embeddings
|
||||||
existing_memories = self.vector_store.search(
|
existing_memories = self.vector_store.search(
|
||||||
query=messages_embeddings,
|
query=new_mem,
|
||||||
|
vectors=messages_embeddings,
|
||||||
limit=5,
|
limit=5,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
)
|
)
|
||||||
@@ -222,7 +224,9 @@ class Memory(MemoryBase):
|
|||||||
temp_uuid_mapping[str(idx)] = item["id"]
|
temp_uuid_mapping[str(idx)] = item["id"]
|
||||||
retrieved_old_memory[idx]["id"] = str(idx)
|
retrieved_old_memory[idx]["id"] = str(idx)
|
||||||
|
|
||||||
function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt)
|
function_calling_prompt = get_update_memory_messages(
|
||||||
|
retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_memories_with_actions = self.llm.generate_response(
|
new_memories_with_actions = self.llm.generate_response(
|
||||||
@@ -479,7 +483,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
def _search_vector_store(self, query, filters, limit):
|
def _search_vector_store(self, query, filters, limit):
|
||||||
embeddings = self.embedding_model.embed(query, "search")
|
embeddings = self.embedding_model.embed(query, "search")
|
||||||
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
|
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
|
||||||
|
|
||||||
excluded_keys = {
|
excluded_keys = {
|
||||||
"user_id",
|
"user_id",
|
||||||
|
|||||||
@@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
collection_name,
|
collection_name,
|
||||||
api_key,
|
api_key,
|
||||||
embedding_model_dims,
|
embedding_model_dims,
|
||||||
compression_type: Optional[str] = None,
|
compression_type: Optional[str] = None,
|
||||||
use_float16: bool = False,
|
use_float16: bool = False,
|
||||||
|
hybrid_search: bool = False,
|
||||||
|
vector_filter_mode: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Azure AI Search vector store.
|
Initialize the Azure AI Search vector store.
|
||||||
@@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
Allowed values are None (no quantization), "scalar", or "binary".
|
Allowed values are None (no quantization), "scalar", or "binary".
|
||||||
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
|
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
|
||||||
(Note: This flag is preserved from the initial implementation per feedback.)
|
(Note: This flag is preserved from the initial implementation per feedback.)
|
||||||
|
hybrid_search (bool): Whether to use hybrid search. Default is False.
|
||||||
|
vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
|
||||||
"""
|
"""
|
||||||
self.index_name = collection_name
|
self.index_name = collection_name
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.embedding_model_dims = embedding_model_dims
|
self.embedding_model_dims = embedding_model_dims
|
||||||
# If compression_type is None, treat it as "none".
|
# If compression_type is None, treat it as "none".
|
||||||
self.compression_type = (compression_type or "none").lower()
|
self.compression_type = (compression_type or "none").lower()
|
||||||
self.use_float16 = use_float16
|
self.use_float16 = use_float16
|
||||||
|
self.hybrid_search = hybrid_search
|
||||||
|
self.vector_filter_mode = vector_filter_mode
|
||||||
|
|
||||||
self.search_client = SearchClient(
|
self.search_client = SearchClient(
|
||||||
endpoint=f"https://{service_name}.search.windows.net",
|
endpoint=f"https://{service_name}.search.windows.net",
|
||||||
@@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
# If no compression is desired, compression_configurations remains empty.
|
# If no compression is desired, compression_configurations remains empty.
|
||||||
|
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
||||||
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
||||||
@@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
SearchField(
|
SearchField(
|
||||||
name="vector",
|
name="vector",
|
||||||
type=vector_type,
|
type=vector_type,
|
||||||
searchable=True,
|
searchable=True,
|
||||||
vector_search_dimensions=self.embedding_model_dims,
|
vector_search_dimensions=self.embedding_model_dims,
|
||||||
vector_search_profile_name="my-vector-config",
|
vector_search_profile_name="my-vector-config",
|
||||||
),
|
),
|
||||||
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
vector_search = VectorSearch(
|
vector_search = VectorSearch(
|
||||||
@@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
VectorSearchProfile(
|
VectorSearchProfile(
|
||||||
name="my-vector-config",
|
name="my-vector-config",
|
||||||
algorithm_configuration_name="my-algorithms-config",
|
algorithm_configuration_name="my-algorithms-config",
|
||||||
compression_name=compression_name if self.compression_type != "none" else None
|
compression_name=compression_name if self.compression_type != "none" else None,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
|
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
|
||||||
@@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
||||||
documents = [
|
documents = [
|
||||||
self._generate_document(vector, payload, id)
|
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
|
||||||
for id, vector, payload in zip(ids, vectors, payloads)
|
|
||||||
]
|
]
|
||||||
response = self.search_client.upload_documents(documents)
|
response = self.search_client.upload_documents(documents)
|
||||||
for doc in response:
|
for doc in response:
|
||||||
@@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
filter_expression = " and ".join(filter_conditions)
|
filter_expression = " and ".join(filter_conditions)
|
||||||
return filter_expression
|
return filter_expression
|
||||||
|
|
||||||
def search(self, query, limit=5, filters=None):
|
def search(self, query, vectors, limit=5, filters=None):
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (List[float]): Query vector.
|
query (str): Query.
|
||||||
|
vectors (List[float]): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
@@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
if filters:
|
if filters:
|
||||||
filter_expression = self._build_filter_expression(filters)
|
filter_expression = self._build_filter_expression(filters)
|
||||||
|
|
||||||
vector_query = VectorizedQuery(
|
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
|
||||||
vector=query, k_nearest_neighbors=limit, fields="vector"
|
if self.hybrid_search:
|
||||||
)
|
search_results = self.search_client.search(
|
||||||
search_results = self.search_client.search(
|
search_text=query,
|
||||||
vector_queries=[vector_query],
|
vector_queries=[vector_query],
|
||||||
filter=filter_expression,
|
filter=filter_expression,
|
||||||
top=limit
|
top=limit,
|
||||||
)
|
vector_filter_mode=self.vector_filter_mode,
|
||||||
|
search_fields=["payload"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
search_results = self.search_client.search(
|
||||||
|
vector_queries=[vector_query],
|
||||||
|
filter=filter_expression,
|
||||||
|
top=limit,
|
||||||
|
vector_filter_mode=self.vector_filter_mode,
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(result["payload"])
|
||||||
results.append(
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
OutputData(
|
|
||||||
id=result["id"], score=result["@search.score"], payload=payload
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def delete(self, vector_id):
|
def delete(self, vector_id):
|
||||||
@@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
result = self.search_client.get_document(key=vector_id)
|
result = self.search_client.get_document(key=vector_id)
|
||||||
except ResourceNotFoundError:
|
except ResourceNotFoundError:
|
||||||
return None
|
return None
|
||||||
return OutputData(
|
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
|
||||||
id=result["id"], score=None, payload=json.loads(result["payload"])
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_cols(self) -> List[str]:
|
def list_cols(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
@@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
if filters:
|
if filters:
|
||||||
filter_expression = self._build_filter_expression(filters)
|
filter_expression = self._build_filter_expression(filters)
|
||||||
|
|
||||||
search_results = self.search_client.search(
|
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||||
search_text="*", filter=filter_expression, top=limit
|
|
||||||
)
|
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(result["payload"])
|
||||||
results.append(
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
OutputData(
|
|
||||||
id=result["id"], score=result["@search.score"], payload=payload
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return [results]
|
return [results]
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class VectorStoreBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query, limit=5, filters=None):
|
def search(self, query, vectors, limit=5, filters=None):
|
||||||
"""Search for similar vectors."""
|
"""Search for similar vectors."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -127,19 +127,22 @@ class ChromaDB(VectorStoreBase):
|
|||||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||||
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
|
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
|
||||||
|
|
||||||
def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (List[list]): Query vector.
|
query (str): Query.
|
||||||
|
vectors (List[list]): List of vectors to search.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[OutputData]: Search results.
|
List[OutputData]: Search results.
|
||||||
"""
|
"""
|
||||||
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
|
results = self.collection.query(query_embeddings=vectors, where=filters, n_results=limit)
|
||||||
final_results = self._parse_output(results)
|
final_results = self._parse_output(results)
|
||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class ElasticsearchDB(VectorStoreBase):
|
|||||||
# Create index only if auto_create_index is True
|
# Create index only if auto_create_index is True
|
||||||
if config.auto_create_index:
|
if config.auto_create_index:
|
||||||
self.create_index()
|
self.create_index()
|
||||||
|
|
||||||
if config.custom_search_query:
|
if config.custom_search_query:
|
||||||
self.custom_search_query = config.custom_search_query
|
self.custom_search_query = config.custom_search_query
|
||||||
else:
|
else:
|
||||||
@@ -121,16 +121,20 @@ class ElasticsearchDB(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search with two options:
|
Search with two options:
|
||||||
1. Use custom search query if provided
|
1. Use custom search query if provided
|
||||||
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
|
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
|
||||||
"""
|
"""
|
||||||
if self.custom_search_query:
|
if self.custom_search_query:
|
||||||
search_query = self.custom_search_query(query, limit, filters)
|
search_query = self.custom_search_query(vectors, limit, filters)
|
||||||
else:
|
else:
|
||||||
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
|
search_query = {
|
||||||
|
"knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2}
|
||||||
|
}
|
||||||
if filters:
|
if filters:
|
||||||
filter_conditions = []
|
filter_conditions = []
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
|
|||||||
@@ -134,12 +134,13 @@ class MilvusDB(VectorStoreBase):
|
|||||||
|
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
|
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (List[float]): Query vector.
|
query (str): Query.
|
||||||
|
vectors (List[float]): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
@@ -149,7 +150,7 @@ class MilvusDB(VectorStoreBase):
|
|||||||
query_filter = self._create_filter(filters) if filters else None
|
query_filter = self._create_filter(filters) if filters else None
|
||||||
hits = self.client.search(
|
hits = self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[query],
|
data=[vectors],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filter=query_filter,
|
filter=query_filter,
|
||||||
output_fields=["*"],
|
output_fields=["*"],
|
||||||
|
|||||||
@@ -28,10 +28,12 @@ class OpenSearchDB(VectorStoreBase):
|
|||||||
# Initialize OpenSearch client
|
# Initialize OpenSearch client
|
||||||
self.client = OpenSearch(
|
self.client = OpenSearch(
|
||||||
hosts=[{"host": config.host, "port": config.port or 9200}],
|
hosts=[{"host": config.host, "port": config.port or 9200}],
|
||||||
http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
|
http_auth=config.http_auth
|
||||||
|
if config.http_auth
|
||||||
|
else ((config.user, config.password) if (config.user and config.password) else None),
|
||||||
use_ssl=config.use_ssl,
|
use_ssl=config.use_ssl,
|
||||||
verify_certs=config.verify_certs,
|
verify_certs=config.verify_certs,
|
||||||
connection_class=RequestsHttpConnection
|
connection_class=RequestsHttpConnection,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collection_name = config.collection_name
|
self.collection_name = config.collection_name
|
||||||
@@ -115,14 +117,16 @@ class OpenSearchDB(VectorStoreBase):
|
|||||||
results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
|
results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
|
"""Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
|
||||||
search_query = {
|
search_query = {
|
||||||
"size": limit,
|
"size": limit,
|
||||||
"query": {
|
"query": {
|
||||||
"knn": {
|
"knn": {
|
||||||
"vector": {
|
"vector": {
|
||||||
"vector": query,
|
"vector": vectors,
|
||||||
"k": limit,
|
"k": limit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,12 +120,13 @@ class PGVector(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def search(self, query, limit=5, filters=None):
|
def search(self, query, vectors, limit=5, filters=None):
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (List[float]): Query vector.
|
query (str): Query.
|
||||||
|
vectors (List[float]): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
@@ -150,7 +151,7 @@ class PGVector(VectorStoreBase):
|
|||||||
ORDER BY distance
|
ORDER BY distance
|
||||||
LIMIT %s
|
LIMIT %s
|
||||||
""",
|
""",
|
||||||
(query, *filter_params, limit),
|
(vectors, *filter_params, limit),
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.cur.fetchall()
|
results = self.cur.fetchall()
|
||||||
|
|||||||
@@ -127,12 +127,13 @@ class Qdrant(VectorStoreBase):
|
|||||||
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
|
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
|
||||||
return Filter(must=conditions) if conditions else None
|
return Filter(must=conditions) if conditions else None
|
||||||
|
|
||||||
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
|
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (list): Query vector.
|
query (str): Query.
|
||||||
|
vectors (list): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
@@ -142,7 +143,7 @@ class Qdrant(VectorStoreBase):
|
|||||||
query_filter = self._create_filter(filters) if filters else None
|
query_filter = self._create_filter(filters) if filters else None
|
||||||
hits = self.client.query_points(
|
hits = self.client.query_points(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query=query,
|
query=vectors,
|
||||||
query_filter=query_filter,
|
query_filter=query_filter,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -101,12 +101,12 @@ class RedisDB(VectorStoreBase):
|
|||||||
data.append(entry)
|
data.append(entry)
|
||||||
self.index.load(data, id_field="memory_id")
|
self.index.load(data, id_field="memory_id")
|
||||||
|
|
||||||
def search(self, query: list, limit: int = 5, filters: dict = None):
|
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
|
||||||
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
|
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
|
||||||
filter = reduce(lambda x, y: x & y, conditions)
|
filter = reduce(lambda x, y: x & y, conditions)
|
||||||
|
|
||||||
v = VectorQuery(
|
v = VectorQuery(
|
||||||
vector=np.array(query, dtype=np.float32).tobytes(),
|
vector=np.array(vectors, dtype=np.float32).tobytes(),
|
||||||
vector_field_name="embedding",
|
vector_field_name="embedding",
|
||||||
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
|
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
|
||||||
filter_expression=filter,
|
filter_expression=filter,
|
||||||
|
|||||||
@@ -112,16 +112,18 @@ class Supabase(VectorStoreBase):
|
|||||||
payloads = [{} for _ in vectors]
|
payloads = [{} for _ in vectors]
|
||||||
|
|
||||||
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
|
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
|
||||||
print(records)
|
|
||||||
|
|
||||||
self.collection.upsert(records)
|
self.collection.upsert(records)
|
||||||
|
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (List[float]): Query vector
|
query (str): Query.
|
||||||
|
vectors (List[float]): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
@@ -129,11 +131,9 @@ class Supabase(VectorStoreBase):
|
|||||||
List[OutputData]: Search results
|
List[OutputData]: Search results
|
||||||
"""
|
"""
|
||||||
filters = self._preprocess_filters(filters)
|
filters = self._preprocess_filters(filters)
|
||||||
print(filters)
|
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
|
data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
|
||||||
)
|
)
|
||||||
print(results)
|
|
||||||
|
|
||||||
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
|
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
|
||||||
|
|
||||||
|
|||||||
@@ -32,19 +32,19 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""Initialize Google Matching Engine client."""
|
"""Initialize Google Matching Engine client."""
|
||||||
logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
|
logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
|
||||||
|
|
||||||
# If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
|
# If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
|
||||||
if 'collection_name' in kwargs and 'deployment_index_id' not in kwargs:
|
if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
|
||||||
kwargs['deployment_index_id'] = kwargs['collection_name']
|
kwargs["deployment_index_id"] = kwargs["collection_name"]
|
||||||
logger.debug("Using collection_name as deployment_index_id: %s", kwargs['deployment_index_id'])
|
logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
|
||||||
elif 'deployment_index_id' in kwargs and 'collection_name' not in kwargs:
|
elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
|
||||||
kwargs['collection_name'] = kwargs['deployment_index_id']
|
kwargs["collection_name"] = kwargs["deployment_index_id"]
|
||||||
logger.debug("Using deployment_index_id as collection_name: %s", kwargs['collection_name'])
|
logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = GoogleMatchingEngineConfig(**kwargs)
|
config = GoogleMatchingEngineConfig(**kwargs)
|
||||||
logger.debug("Config created: %s", config.model_dump())
|
logger.debug("Config created: %s", config.model_dump())
|
||||||
logger.debug("Config collection_name: %s", getattr(config, 'collection_name', None))
|
logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to validate config: %s", str(e))
|
logger.error("Failed to validate config: %s", str(e))
|
||||||
raise
|
raise
|
||||||
@@ -57,41 +57,37 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
self.deployment_index_id = config.deployment_index_id # The deployment-specific ID
|
self.deployment_index_id = config.deployment_index_id # The deployment-specific ID
|
||||||
self.collection_name = config.collection_name
|
self.collection_name = config.collection_name
|
||||||
self.vector_search_api_endpoint = config.vector_search_api_endpoint
|
self.vector_search_api_endpoint = config.vector_search_api_endpoint
|
||||||
|
|
||||||
logger.debug("Using project=%s, location=%s", self.project_id, self.region)
|
logger.debug("Using project=%s, location=%s", self.project_id, self.region)
|
||||||
|
|
||||||
# Initialize Vertex AI with credentials if provided
|
# Initialize Vertex AI with credentials if provided
|
||||||
init_args = {
|
init_args = {
|
||||||
"project": self.project_id,
|
"project": self.project_id,
|
||||||
"location": self.region,
|
"location": self.region,
|
||||||
}
|
}
|
||||||
if hasattr(config, 'credentials_path') and config.credentials_path:
|
if hasattr(config, "credentials_path") and config.credentials_path:
|
||||||
logger.debug("Using credentials from: %s", config.credentials_path)
|
logger.debug("Using credentials from: %s", config.credentials_path)
|
||||||
credentials = service_account.Credentials.from_service_account_file(
|
credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
|
||||||
config.credentials_path
|
|
||||||
)
|
|
||||||
init_args["credentials"] = credentials
|
init_args["credentials"] = credentials
|
||||||
|
|
||||||
try:
|
try:
|
||||||
aiplatform.init(**init_args)
|
aiplatform.init(**init_args)
|
||||||
logger.debug("Vertex AI initialized successfully")
|
logger.debug("Vertex AI initialized successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize Vertex AI: %s", str(e))
|
logger.error("Failed to initialize Vertex AI: %s", str(e))
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Format the index path properly using the configured index_id
|
# Format the index path properly using the configured index_id
|
||||||
index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
|
index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
|
||||||
logger.debug("Initializing index with path: %s", index_path)
|
logger.debug("Initializing index with path: %s", index_path)
|
||||||
self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
|
self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
|
||||||
logger.debug("Index initialized successfully")
|
logger.debug("Index initialized successfully")
|
||||||
|
|
||||||
# Format the endpoint name properly
|
# Format the endpoint name properly
|
||||||
endpoint_name = self.endpoint_id
|
endpoint_name = self.endpoint_id
|
||||||
logger.debug("Initializing endpoint with name: %s", endpoint_name)
|
logger.debug("Initializing endpoint with name: %s", endpoint_name)
|
||||||
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
|
||||||
index_endpoint_name=endpoint_name
|
|
||||||
)
|
|
||||||
logger.debug("Endpoint initialized successfully")
|
logger.debug("Endpoint initialized successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize Matching Engine components: %s", str(e))
|
logger.error("Failed to initialize Matching Engine components: %s", str(e))
|
||||||
@@ -119,47 +115,36 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
|
|
||||||
def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
|
def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
|
||||||
"""Create a restriction object for the Matching Engine index.
|
"""Create a restriction object for the Matching Engine index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The namespace/key for the restriction
|
key: The namespace/key for the restriction
|
||||||
value: The value to restrict on
|
value: The value to restrict on
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Restriction object for the index
|
Restriction object for the index
|
||||||
"""
|
"""
|
||||||
str_value = str(value) if value is not None else ""
|
str_value = str(value) if value is not None else ""
|
||||||
return aiplatform_v1.types.index.IndexDatapoint.Restriction(
|
return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
|
||||||
namespace=key,
|
|
||||||
allow_list=[str_value]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_datapoint(
|
def _create_datapoint(
|
||||||
self,
|
self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
|
||||||
vector_id: str,
|
|
||||||
vector: List[float],
|
|
||||||
payload: Optional[Dict] = None
|
|
||||||
) -> aiplatform_v1.types.index.IndexDatapoint:
|
) -> aiplatform_v1.types.index.IndexDatapoint:
|
||||||
"""Create a datapoint object for the Matching Engine index.
|
"""Create a datapoint object for the Matching Engine index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_id: The ID for the datapoint
|
vector_id: The ID for the datapoint
|
||||||
vector: The vector to store
|
vector: The vector to store
|
||||||
payload: Optional metadata to store with the vector
|
payload: Optional metadata to store with the vector
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IndexDatapoint object
|
IndexDatapoint object
|
||||||
"""
|
"""
|
||||||
restrictions = []
|
restrictions = []
|
||||||
if payload:
|
if payload:
|
||||||
restrictions = [
|
restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
|
||||||
self._create_restriction(key, value)
|
|
||||||
for key, value in payload.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
return aiplatform_v1.types.index.IndexDatapoint(
|
return aiplatform_v1.types.index.IndexDatapoint(
|
||||||
datapoint_id=vector_id,
|
datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
|
||||||
feature_vector=vector,
|
|
||||||
restricts=restrictions
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert(
|
def insert(
|
||||||
@@ -169,41 +154,41 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Insert vectors into the Matching Engine index.
|
"""Insert vectors into the Matching Engine index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vectors: List of vectors to insert
|
vectors: List of vectors to insert
|
||||||
payloads: Optional list of metadata dictionaries
|
payloads: Optional list of metadata dictionaries
|
||||||
ids: Optional list of IDs for the vectors
|
ids: Optional list of IDs for the vectors
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If vectors is empty or lengths don't match
|
ValueError: If vectors is empty or lengths don't match
|
||||||
GoogleAPIError: If the API call fails
|
GoogleAPIError: If the API call fails
|
||||||
"""
|
"""
|
||||||
if not vectors:
|
if not vectors:
|
||||||
raise ValueError("No vectors provided for insertion")
|
raise ValueError("No vectors provided for insertion")
|
||||||
|
|
||||||
if payloads and len(payloads) != len(vectors):
|
if payloads and len(payloads) != len(vectors):
|
||||||
raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
|
raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
|
||||||
|
|
||||||
if ids and len(ids) != len(vectors):
|
if ids and len(ids) != len(vectors):
|
||||||
raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
|
raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
|
||||||
|
|
||||||
logger.debug("Starting insert of %d vectors", len(vectors))
|
logger.debug("Starting insert of %d vectors", len(vectors))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
datapoints = [
|
datapoints = [
|
||||||
self._create_datapoint(
|
self._create_datapoint(
|
||||||
vector_id=ids[i] if ids else str(uuid.uuid4()),
|
vector_id=ids[i] if ids else str(uuid.uuid4()),
|
||||||
vector=vector,
|
vector=vector,
|
||||||
payload=payloads[i] if payloads and i < len(payloads) else None
|
payload=payloads[i] if payloads and i < len(payloads) else None,
|
||||||
)
|
)
|
||||||
for i, vector in enumerate(vectors)
|
for i, vector in enumerate(vectors)
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.debug("Created %d datapoints", len(datapoints))
|
logger.debug("Created %d datapoints", len(datapoints))
|
||||||
self.index.upsert_datapoints(datapoints=datapoints)
|
self.index.upsert_datapoints(datapoints=datapoints)
|
||||||
logger.debug("Successfully inserted datapoints")
|
logger.debug("Successfully inserted datapoints")
|
||||||
|
|
||||||
except google.api_core.exceptions.GoogleAPIError as e:
|
except google.api_core.exceptions.GoogleAPIError as e:
|
||||||
logger.error("Failed to insert vectors: %s", str(e))
|
logger.error("Failed to insert vectors: %s", str(e))
|
||||||
raise
|
raise
|
||||||
@@ -212,21 +197,22 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def search(
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
Args:
|
Args:
|
||||||
query (List[float]): Query vector.
|
query (str): Query.
|
||||||
|
vectors (List[float]): Query vector.
|
||||||
limit (int, optional): Number of results to return. Defaults to 5.
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
||||||
Returns:
|
Returns:
|
||||||
List[OutputData]: Search results (unwrapped)
|
List[OutputData]: Search results (unwrapped)
|
||||||
"""
|
"""
|
||||||
logger.debug("Starting search")
|
logger.debug("Starting search")
|
||||||
logger.debug("Query type: %s, length: %d", type(query), len(query))
|
|
||||||
logger.debug("Limit: %d, Filters: %s", limit, filters)
|
logger.debug("Limit: %d, Filters: %s", limit, filters)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
filter_namespaces = []
|
filter_namespaces = []
|
||||||
if filters:
|
if filters:
|
||||||
@@ -235,53 +221,42 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
|
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
|
||||||
if isinstance(value, (str, int, float)):
|
if isinstance(value, (str, int, float)):
|
||||||
logger.debug("Adding simple filter for %s", key)
|
logger.debug("Adding simple filter for %s", key)
|
||||||
filter_namespaces.append(
|
filter_namespaces.append(Namespace(key, [str(value)], []))
|
||||||
Namespace(key, [str(value)], [])
|
|
||||||
)
|
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
logger.debug("Adding complex filter for %s", key)
|
logger.debug("Adding complex filter for %s", key)
|
||||||
includes = value.get('include', [])
|
includes = value.get("include", [])
|
||||||
excludes = value.get('exclude', [])
|
excludes = value.get("exclude", [])
|
||||||
filter_namespaces.append(
|
filter_namespaces.append(Namespace(key, includes, excludes))
|
||||||
Namespace(key, includes, excludes)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Final filter_namespaces: %s", filter_namespaces)
|
logger.debug("Final filter_namespaces: %s", filter_namespaces)
|
||||||
|
|
||||||
response = self.index_endpoint.find_neighbors(
|
response = self.index_endpoint.find_neighbors(
|
||||||
deployed_index_id=self.deployment_index_id,
|
deployed_index_id=self.deployment_index_id,
|
||||||
queries=[query],
|
queries=[vectors],
|
||||||
num_neighbors=limit,
|
num_neighbors=limit,
|
||||||
filter=filter_namespaces if filter_namespaces else None,
|
filter=filter_namespaces if filter_namespaces else None,
|
||||||
return_full_datapoint=True
|
return_full_datapoint=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response or len(response) == 0 or len(response[0]) == 0:
|
if not response or len(response) == 0 or len(response[0]) == 0:
|
||||||
logger.debug("No results found")
|
logger.debug("No results found")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for neighbor in response[0]:
|
for neighbor in response[0]:
|
||||||
logger.debug("Processing neighbor - id: %s, distance: %s",
|
logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
|
||||||
neighbor.id, neighbor.distance)
|
|
||||||
|
|
||||||
payload = {}
|
payload = {}
|
||||||
if hasattr(neighbor, 'restricts'):
|
if hasattr(neighbor, "restricts"):
|
||||||
logger.debug("Processing restricts")
|
logger.debug("Processing restricts")
|
||||||
for restrict in neighbor.restricts:
|
for restrict in neighbor.restricts:
|
||||||
if (hasattr(restrict, 'name') and
|
if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
|
||||||
hasattr(restrict, 'allow_tokens') and
|
|
||||||
restrict.allow_tokens):
|
|
||||||
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
|
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
|
||||||
payload[restrict.name] = restrict.allow_tokens[0]
|
payload[restrict.name] = restrict.allow_tokens[0]
|
||||||
|
|
||||||
output_data = OutputData(
|
output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
|
||||||
id=neighbor.id,
|
|
||||||
score=neighbor.distance,
|
|
||||||
payload=payload
|
|
||||||
)
|
|
||||||
results.append(output_data)
|
results.append(output_data)
|
||||||
|
|
||||||
logger.debug("Returning %d results", len(results))
|
logger.debug("Returning %d results", len(results))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -291,7 +266,6 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
|
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete vectors from the Matching Engine index.
|
Delete vectors from the Matching Engine index.
|
||||||
@@ -326,14 +300,13 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
except google.api_core.exceptions.InvalidArgument as e:
|
except google.api_core.exceptions.InvalidArgument as e:
|
||||||
logger.error("Invalid argument: %s", str(e))
|
logger.error("Invalid argument: %s", str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error occurred: %s", str(e))
|
logger.error("Error occurred: %s", str(e))
|
||||||
logger.error("Error type: %s", type(e))
|
logger.error("Error type: %s", type(e))
|
||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
vector_id: str,
|
vector_id: str,
|
||||||
@@ -341,42 +314,40 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
payload: Optional[Dict] = None,
|
payload: Optional[Dict] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update a vector and its payload.
|
"""Update a vector and its payload.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_id: ID of the vector to update
|
vector_id: ID of the vector to update
|
||||||
vector: Optional new vector values
|
vector: Optional new vector values
|
||||||
payload: Optional new metadata payload
|
payload: Optional new metadata payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if update was successful
|
bool: True if update was successful
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If neither vector nor payload is provided
|
ValueError: If neither vector nor payload is provided
|
||||||
GoogleAPIError: If the API call fails
|
GoogleAPIError: If the API call fails
|
||||||
"""
|
"""
|
||||||
logger.debug("Starting update for vector_id: %s", vector_id)
|
logger.debug("Starting update for vector_id: %s", vector_id)
|
||||||
|
|
||||||
if vector is None and payload is None:
|
if vector is None and payload is None:
|
||||||
raise ValueError("Either vector or payload must be provided for update")
|
raise ValueError("Either vector or payload must be provided for update")
|
||||||
|
|
||||||
# First check if the vector exists
|
# First check if the vector exists
|
||||||
try:
|
try:
|
||||||
existing = self.get(vector_id)
|
existing = self.get(vector_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
logger.error("Vector ID not found: %s", vector_id)
|
logger.error("Vector ID not found: %s", vector_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
datapoint = self._create_datapoint(
|
datapoint = self._create_datapoint(
|
||||||
vector_id=vector_id,
|
vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
|
||||||
vector=vector if vector is not None else [],
|
|
||||||
payload=payload
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Upserting datapoint: %s", datapoint)
|
logger.debug("Upserting datapoint: %s", datapoint)
|
||||||
self.index.upsert_datapoints(datapoints=[datapoint])
|
self.index.upsert_datapoints(datapoints=[datapoint])
|
||||||
logger.debug("Update completed successfully")
|
logger.debug("Update completed successfully")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except google.api_core.exceptions.GoogleAPIError as e:
|
except google.api_core.exceptions.GoogleAPIError as e:
|
||||||
logger.error("API error during update: %s", str(e))
|
logger.error("API error during update: %s", str(e))
|
||||||
return False
|
return False
|
||||||
@@ -385,7 +356,6 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||||
"""
|
"""
|
||||||
Retrieve a vector by ID.
|
Retrieve a vector by ID.
|
||||||
@@ -395,24 +365,17 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
Optional[OutputData]: Retrieved vector or None if not found.
|
Optional[OutputData]: Retrieved vector or None if not found.
|
||||||
"""
|
"""
|
||||||
logger.debug("Starting get for vector_id: %s", vector_id)
|
logger.debug("Starting get for vector_id: %s", vector_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.vector_search_api_endpoint:
|
if not self.vector_search_api_endpoint:
|
||||||
raise ValueError("vector_search_api_endpoint is required for get operation")
|
raise ValueError("vector_search_api_endpoint is required for get operation")
|
||||||
|
|
||||||
vector_search_client = aiplatform_v1.MatchServiceClient(
|
vector_search_client = aiplatform_v1.MatchServiceClient(
|
||||||
client_options={
|
client_options={"api_endpoint": self.vector_search_api_endpoint},
|
||||||
"api_endpoint": self.vector_search_api_endpoint
|
|
||||||
},
|
|
||||||
)
|
|
||||||
datapoint = aiplatform_v1.IndexDatapoint(
|
|
||||||
datapoint_id=vector_id
|
|
||||||
)
|
)
|
||||||
|
datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
|
||||||
|
|
||||||
query = aiplatform_v1.FindNeighborsRequest.Query(
|
query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
|
||||||
datapoint=datapoint,
|
|
||||||
neighbor_count=1
|
|
||||||
)
|
|
||||||
request = aiplatform_v1.FindNeighborsRequest(
|
request = aiplatform_v1.FindNeighborsRequest(
|
||||||
index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
|
index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
|
||||||
deployed_index_id=self.deployment_index_id,
|
deployed_index_id=self.deployment_index_id,
|
||||||
@@ -423,41 +386,36 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
try:
|
try:
|
||||||
response = vector_search_client.find_neighbors(request)
|
response = vector_search_client.find_neighbors(request)
|
||||||
logger.debug("Got response")
|
logger.debug("Got response")
|
||||||
|
|
||||||
if response and response.nearest_neighbors:
|
if response and response.nearest_neighbors:
|
||||||
nearest = response.nearest_neighbors[0]
|
nearest = response.nearest_neighbors[0]
|
||||||
if nearest.neighbors:
|
if nearest.neighbors:
|
||||||
neighbor = nearest.neighbors[0]
|
neighbor = nearest.neighbors[0]
|
||||||
|
|
||||||
payload = {}
|
payload = {}
|
||||||
if hasattr(neighbor.datapoint, 'restricts'):
|
if hasattr(neighbor.datapoint, "restricts"):
|
||||||
for restrict in neighbor.datapoint.restricts:
|
for restrict in neighbor.datapoint.restricts:
|
||||||
if restrict.allow_list:
|
if restrict.allow_list:
|
||||||
payload[restrict.namespace] = restrict.allow_list[0]
|
payload[restrict.namespace] = restrict.allow_list[0]
|
||||||
|
|
||||||
return OutputData(
|
return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
|
||||||
id=neighbor.datapoint.datapoint_id,
|
|
||||||
score=neighbor.distance,
|
|
||||||
payload=payload
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("No results found")
|
logger.debug("No results found")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except google.api_core.exceptions.NotFound:
|
except google.api_core.exceptions.NotFound:
|
||||||
logger.debug("Datapoint not found")
|
logger.debug("Datapoint not found")
|
||||||
return None
|
return None
|
||||||
except google.api_core.exceptions.PermissionDenied as e:
|
except google.api_core.exceptions.PermissionDenied as e:
|
||||||
logger.error("Permission denied: %s", str(e))
|
logger.error("Permission denied: %s", str(e))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error occurred: %s", str(e))
|
logger.error("Error occurred: %s", str(e))
|
||||||
logger.error("Error type: %s", type(e))
|
logger.error("Error type: %s", type(e))
|
||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def list_cols(self) -> List[str]:
|
def list_cols(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
List all collections (indexes).
|
List all collections (indexes).
|
||||||
@@ -466,7 +424,6 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
"""
|
"""
|
||||||
return [self.deployment_index_id]
|
return [self.deployment_index_id]
|
||||||
|
|
||||||
|
|
||||||
def delete_col(self):
|
def delete_col(self):
|
||||||
"""
|
"""
|
||||||
Delete a collection (index).
|
Delete a collection (index).
|
||||||
@@ -475,7 +432,6 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.warning("Delete collection operation is not supported for Google Matching Engine")
|
logger.warning("Delete collection operation is not supported for Google Matching Engine")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def col_info(self) -> Dict:
|
def col_info(self) -> Dict:
|
||||||
"""
|
"""
|
||||||
Get information about a collection (index).
|
Get information about a collection (index).
|
||||||
@@ -486,17 +442,16 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
"index_id": self.index_id,
|
"index_id": self.index_id,
|
||||||
"endpoint_id": self.endpoint_id,
|
"endpoint_id": self.endpoint_id,
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
"region": self.region
|
"region": self.region,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
|
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
|
||||||
"""List vectors matching the given filters.
|
"""List vectors matching the given filters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filters: Optional filters to apply
|
filters: Optional filters to apply
|
||||||
limit: Optional maximum number of results to return
|
limit: Optional maximum number of results to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[OutputData]]: List of matching vectors wrapped in an extra array
|
List[List[OutputData]]: List of matching vectors wrapped in an extra array
|
||||||
to match the interface
|
to match the interface
|
||||||
@@ -504,36 +459,31 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.debug("Starting list operation")
|
logger.debug("Starting list operation")
|
||||||
logger.debug("Filters: %s", filters)
|
logger.debug("Filters: %s", filters)
|
||||||
logger.debug("Limit: %s", limit)
|
logger.debug("Limit: %s", limit)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use a zero vector for the search
|
# Use a zero vector for the search
|
||||||
dimension = 768 # This should be configurable based on the model
|
dimension = 768 # This should be configurable based on the model
|
||||||
zero_vector = [0.0] * dimension
|
zero_vector = [0.0] * dimension
|
||||||
|
|
||||||
# Use a large limit if none specified
|
# Use a large limit if none specified
|
||||||
search_limit = limit if limit is not None else 10000
|
search_limit = limit if limit is not None else 10000
|
||||||
|
|
||||||
results = self.search(
|
results = self.search(query=zero_vector, limit=search_limit, filters=filters)
|
||||||
query=zero_vector,
|
|
||||||
limit=search_limit,
|
|
||||||
filters=filters
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Found %d results", len(results))
|
logger.debug("Found %d results", len(results))
|
||||||
return [results] # Wrap in extra array to match interface
|
return [results] # Wrap in extra array to match interface
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in list operation: %s", str(e))
|
logger.error("Error in list operation: %s", str(e))
|
||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def create_col(self, name=None, vector_size=None, distance=None):
|
def create_col(self, name=None, vector_size=None, distance=None):
|
||||||
"""
|
"""
|
||||||
Create a new collection. For Google Matching Engine, collections (indexes)
|
Create a new collection. For Google Matching Engine, collections (indexes)
|
||||||
are created through the Google Cloud Console or API separately.
|
are created through the Google Cloud Console or API separately.
|
||||||
This method is a no-op since indexes are pre-created.
|
This method is a no-op since indexes are pre-created.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Ignored for Google Matching Engine
|
name: Ignored for Google Matching Engine
|
||||||
vector_size: Ignored for Google Matching Engine
|
vector_size: Ignored for Google Matching Engine
|
||||||
@@ -543,41 +493,35 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
# This method is included only to satisfy the abstract base class
|
# This method is included only to satisfy the abstract base class
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
|
def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
|
||||||
logger.debug("Starting add operation")
|
logger.debug("Starting add operation")
|
||||||
logger.debug("Text: %s", text)
|
logger.debug("Text: %s", text)
|
||||||
logger.debug("Metadata: %s", metadata)
|
logger.debug("Metadata: %s", metadata)
|
||||||
logger.debug("User ID: %s", user_id)
|
logger.debug("User ID: %s", user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate a unique ID for this entry
|
# Generate a unique ID for this entry
|
||||||
vector_id = str(uuid.uuid4())
|
vector_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Create the payload with all necessary fields
|
# Create the payload with all necessary fields
|
||||||
payload = {
|
payload = {
|
||||||
"data": text, # Store the text in the data field
|
"data": text, # Store the text in the data field
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
**(metadata or {})
|
**(metadata or {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get the embedding
|
# Get the embedding
|
||||||
vector = self.embedder.embed_query(text)
|
vector = self.embedder.embed_query(text)
|
||||||
|
|
||||||
# Insert using the insert method
|
# Insert using the insert method
|
||||||
self.insert(
|
self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
|
||||||
vectors=[vector],
|
|
||||||
payloads=[payload],
|
|
||||||
ids=[vector_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
return vector_id
|
return vector_id
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error occurred: %s", str(e))
|
logger.error("Error occurred: %s", str(e))
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
@@ -585,47 +529,45 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Add texts to the vector store.
|
"""Add texts to the vector store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to add
|
texts: List of texts to add
|
||||||
metadatas: Optional list of metadata dicts
|
metadatas: Optional list of metadata dicts
|
||||||
ids: Optional list of IDs to use
|
ids: Optional list of IDs to use
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: List of IDs of the added texts
|
List[str]: List of IDs of the added texts
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If texts is empty or lengths don't match
|
ValueError: If texts is empty or lengths don't match
|
||||||
"""
|
"""
|
||||||
if not texts:
|
if not texts:
|
||||||
raise ValueError("No texts provided")
|
raise ValueError("No texts provided")
|
||||||
|
|
||||||
if metadatas and len(metadatas) != len(texts):
|
if metadatas and len(metadatas) != len(texts):
|
||||||
raise ValueError(f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})")
|
raise ValueError(
|
||||||
|
f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
|
||||||
|
)
|
||||||
|
|
||||||
if ids and len(ids) != len(texts):
|
if ids and len(ids) != len(texts):
|
||||||
raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
|
raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
|
||||||
|
|
||||||
logger.debug("Starting add_texts operation")
|
logger.debug("Starting add_texts operation")
|
||||||
logger.debug("Number of texts: %d", len(texts))
|
logger.debug("Number of texts: %d", len(texts))
|
||||||
logger.debug("Has metadatas: %s", metadatas is not None)
|
logger.debug("Has metadatas: %s", metadatas is not None)
|
||||||
logger.debug("Has ids: %s", ids is not None)
|
logger.debug("Has ids: %s", ids is not None)
|
||||||
|
|
||||||
if ids is None:
|
if ids is None:
|
||||||
ids = [str(uuid.uuid4()) for _ in texts]
|
ids = [str(uuid.uuid4()) for _ in texts]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get embeddings
|
# Get embeddings
|
||||||
embeddings = self.embedder.embed_documents(texts)
|
embeddings = self.embedder.embed_documents(texts)
|
||||||
|
|
||||||
# Add to store
|
# Add to store
|
||||||
self.insert(
|
self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
|
||||||
vectors=embeddings,
|
|
||||||
payloads=metadatas if metadatas else [{}] * len(texts),
|
|
||||||
ids=ids
|
|
||||||
)
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in add_texts: %s", str(e))
|
logger.error("Error in add_texts: %s", str(e))
|
||||||
logger.error("Stack trace: %s", traceback.format_exc())
|
logger.error("Stack trace: %s", traceback.format_exc())
|
||||||
@@ -657,18 +599,12 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.debug("Query: %s", query)
|
logger.debug("Query: %s", query)
|
||||||
logger.debug("k: %d", k)
|
logger.debug("k: %d", k)
|
||||||
logger.debug("Filter: %s", filter)
|
logger.debug("Filter: %s", filter)
|
||||||
|
|
||||||
embedding = self.embedder.embed_query(query)
|
embedding = self.embedder.embed_query(query)
|
||||||
results = self.search(query=embedding, limit=k, filters=filter)
|
results = self.search(query=embedding, limit=k, filters=filter)
|
||||||
|
|
||||||
docs_and_scores = [
|
docs_and_scores = [
|
||||||
(
|
(Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
|
||||||
Document(
|
|
||||||
page_content=result.payload.get("text", ""),
|
|
||||||
metadata=result.payload
|
|
||||||
),
|
|
||||||
result.score
|
|
||||||
)
|
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
logger.debug("Found %d results", len(docs_and_scores))
|
logger.debug("Found %d results", len(docs_and_scores))
|
||||||
@@ -684,4 +620,3 @@ class GoogleMatchingEngine(VectorStoreBase):
|
|||||||
logger.debug("Starting similarity search")
|
logger.debug("Starting similarity search")
|
||||||
docs_and_scores = self.similarity_search_with_score(query, k, filter)
|
docs_and_scores = self.similarity_search_with_score(query, k, filter)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
|
|||||||
@@ -154,7 +154,9 @@ class Weaviate(VectorStoreBase):
|
|||||||
|
|
||||||
batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)
|
batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)
|
||||||
|
|
||||||
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
"""
|
"""
|
||||||
@@ -167,7 +169,7 @@ class Weaviate(VectorStoreBase):
|
|||||||
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
|
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
|
||||||
response = collection.query.hybrid(
|
response = collection.query.hybrid(
|
||||||
query="",
|
query="",
|
||||||
vector=query,
|
vector=vectors,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filters=combined_filter,
|
filters=combined_filter,
|
||||||
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
|
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
|
||||||
|
|||||||
@@ -35,12 +35,11 @@ def test_search_vectors(chromadb_instance, mock_chromadb_client):
|
|||||||
}
|
}
|
||||||
chromadb_instance.collection.query.return_value = mock_result
|
chromadb_instance.collection.query.return_value = mock_result
|
||||||
|
|
||||||
query = [[0.1, 0.2, 0.3]]
|
vectors = [[0.1, 0.2, 0.3]]
|
||||||
results = chromadb_instance.search(query=query, limit=2)
|
results = chromadb_instance.search(query="", vectors=vectors, limit=2)
|
||||||
|
|
||||||
chromadb_instance.collection.query.assert_called_once_with(query_embeddings=query, where=None, n_results=2)
|
chromadb_instance.collection.query.assert_called_once_with(query_embeddings=vectors, where=None, n_results=2)
|
||||||
|
|
||||||
print(results, type(results))
|
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert results[0].id == "id1"
|
assert results[0].id == "id1"
|
||||||
assert results[0].score == 0.1
|
assert results[0].score == 0.1
|
||||||
|
|||||||
@@ -196,8 +196,8 @@ class TestElasticsearchDB(unittest.TestCase):
|
|||||||
self.client_mock.search.return_value = mock_response
|
self.client_mock.search.return_value = mock_response
|
||||||
|
|
||||||
# Perform search
|
# Perform search
|
||||||
query_vector = [0.1] * 1536
|
vectors = [[0.1] * 1536]
|
||||||
results = self.es_db.search(query=query_vector, limit=5)
|
results = self.es_db.search(query="", vectors=vectors, limit=5)
|
||||||
|
|
||||||
# Verify search call
|
# Verify search call
|
||||||
self.client_mock.search.assert_called_once()
|
self.client_mock.search.assert_called_once()
|
||||||
@@ -210,7 +210,7 @@ class TestElasticsearchDB(unittest.TestCase):
|
|||||||
# Verify KNN query structure
|
# Verify KNN query structure
|
||||||
self.assertIn("knn", body)
|
self.assertIn("knn", body)
|
||||||
self.assertEqual(body["knn"]["field"], "vector")
|
self.assertEqual(body["knn"]["field"], "vector")
|
||||||
self.assertEqual(body["knn"]["query_vector"], query_vector)
|
self.assertEqual(body["knn"]["query_vector"], vectors)
|
||||||
self.assertEqual(body["knn"]["k"], 5)
|
self.assertEqual(body["knn"]["k"], 5)
|
||||||
self.assertEqual(body["knn"]["num_candidates"], 10)
|
self.assertEqual(body["knn"]["num_candidates"], 10)
|
||||||
|
|
||||||
@@ -226,13 +226,13 @@ class TestElasticsearchDB(unittest.TestCase):
|
|||||||
self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"}
|
self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"}
|
||||||
|
|
||||||
# Perform search
|
# Perform search
|
||||||
query_vector = [0.1] * 1536
|
vectors = [[0.1] * 1536]
|
||||||
limit = 5
|
limit = 5
|
||||||
filters = {"key1": "value1"}
|
filters = {"key1": "value1"}
|
||||||
self.es_db.search(query=query_vector, limit=limit, filters=filters)
|
self.es_db.search(query="", vectors=vectors, limit=limit, filters=filters)
|
||||||
|
|
||||||
# Verify custom search query function was called
|
# Verify custom search query function was called
|
||||||
self.es_db.custom_search_query.assert_called_once_with(query_vector, limit, filters)
|
self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)
|
||||||
|
|
||||||
# Verify custom search query was used
|
# Verify custom search query was used
|
||||||
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
|
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
|
||||||
|
|||||||
@@ -126,15 +126,15 @@ class TestOpenSearchDB(unittest.TestCase):
|
|||||||
def test_search(self):
|
def test_search(self):
|
||||||
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}]}}
|
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}]}}
|
||||||
self.client_mock.search.return_value = mock_response
|
self.client_mock.search.return_value = mock_response
|
||||||
query_vector = [0.1] * 1536
|
vectors = [[0.1] * 1536]
|
||||||
results = self.os_db.search(query=query_vector, limit=5)
|
results = self.os_db.search(query="", vectors=vectors, limit=5)
|
||||||
self.client_mock.search.assert_called_once()
|
self.client_mock.search.assert_called_once()
|
||||||
search_args = self.client_mock.search.call_args[1]
|
search_args = self.client_mock.search.call_args[1]
|
||||||
self.assertEqual(search_args["index"], "test_collection")
|
self.assertEqual(search_args["index"], "test_collection")
|
||||||
body = search_args["body"]
|
body = search_args["body"]
|
||||||
self.assertIn("knn", body["query"])
|
self.assertIn("knn", body["query"])
|
||||||
self.assertIn("vector", body["query"]["knn"])
|
self.assertIn("vector", body["query"]["knn"])
|
||||||
self.assertEqual(body["query"]["knn"]["vector"]["vector"], query_vector)
|
self.assertEqual(body["query"]["knn"]["vector"]["vector"], vectors)
|
||||||
self.assertEqual(body["query"]["knn"]["vector"]["k"], 5)
|
self.assertEqual(body["query"]["knn"]["vector"]["k"], 5)
|
||||||
self.assertEqual(len(results), 1)
|
self.assertEqual(len(results), 1)
|
||||||
self.assertEqual(results[0].id, "id1")
|
self.assertEqual(results[0].id, "id1")
|
||||||
|
|||||||
@@ -50,15 +50,15 @@ class TestQdrant(unittest.TestCase):
|
|||||||
self.assertEqual(points[0].payload, payloads[0])
|
self.assertEqual(points[0].payload, payloads[0])
|
||||||
|
|
||||||
def test_search(self):
|
def test_search(self):
|
||||||
query_vector = [0.1, 0.2]
|
vectors = [[0.1, 0.2]]
|
||||||
mock_point = MagicMock(id=str(uuid.uuid4()), score=0.95, payload={"key": "value"})
|
mock_point = MagicMock(id=str(uuid.uuid4()), score=0.95, payload={"key": "value"})
|
||||||
self.client_mock.query_points.return_value = MagicMock(points=[mock_point])
|
self.client_mock.query_points.return_value = MagicMock(points=[mock_point])
|
||||||
|
|
||||||
results = self.qdrant.search(query=query_vector, limit=1)
|
results = self.qdrant.search(query="", vectors=vectors, limit=1)
|
||||||
|
|
||||||
self.client_mock.query_points.assert_called_once_with(
|
self.client_mock.query_points.assert_called_once_with(
|
||||||
collection_name="test_collection",
|
collection_name="test_collection",
|
||||||
query=query_vector,
|
query=vectors,
|
||||||
query_filter=None,
|
query_filter=None,
|
||||||
limit=1,
|
limit=1,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -77,12 +77,12 @@ def test_search_vectors(supabase_instance, mock_collection):
|
|||||||
]
|
]
|
||||||
mock_collection.query.return_value = mock_results
|
mock_collection.query.return_value = mock_results
|
||||||
|
|
||||||
query = [0.1, 0.2, 0.3]
|
vectors = [[0.1, 0.2, 0.3]]
|
||||||
filters = {"category": "test"}
|
filters = {"category": "test"}
|
||||||
results = supabase_instance.search(query=query, limit=2, filters=filters)
|
results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)
|
||||||
|
|
||||||
mock_collection.query.assert_called_once_with(
|
mock_collection.query.assert_called_once_with(
|
||||||
data=query,
|
data=vectors,
|
||||||
limit=2,
|
limit=2,
|
||||||
filters={"category": {"$eq": "test"}},
|
filters={"category": {"$eq": "test"}},
|
||||||
include_metadata=True,
|
include_metadata=True,
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ def test_insert_vectors(vector_store, mock_vertex_ai):
|
|||||||
|
|
||||||
def test_search_vectors(vector_store, mock_vertex_ai):
|
def test_search_vectors(vector_store, mock_vertex_ai):
|
||||||
"""Test searching vectors with filters"""
|
"""Test searching vectors with filters"""
|
||||||
query = [0.1, 0.2, 0.3]
|
vectors = [[0.1, 0.2, 0.3]]
|
||||||
filters = {"user_id": "test_user"}
|
filters = {"user_id": "test_user"}
|
||||||
|
|
||||||
mock_datapoint = Mock()
|
mock_datapoint = Mock()
|
||||||
mock_datapoint.datapoint_id = "test-id"
|
mock_datapoint.datapoint_id = "test-id"
|
||||||
mock_datapoint.feature_vector = query
|
mock_datapoint.feature_vector = vectors
|
||||||
|
|
||||||
mock_restrict = Mock()
|
mock_restrict = Mock()
|
||||||
mock_restrict.namespace = "user_id"
|
mock_restrict.namespace = "user_id"
|
||||||
@@ -96,11 +96,11 @@ def test_search_vectors(vector_store, mock_vertex_ai):
|
|||||||
|
|
||||||
mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
|
mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
|
||||||
|
|
||||||
results = vector_store.search(query=query, filters=filters, limit=1)
|
results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)
|
||||||
|
|
||||||
mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
|
mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
|
||||||
deployed_index_id=vector_store.deployment_index_id,
|
deployed_index_id=vector_store.deployment_index_id,
|
||||||
queries=[query],
|
queries=[vectors],
|
||||||
num_neighbors=1,
|
num_neighbors=1,
|
||||||
filter=[Namespace("user_id", ["test_user"], [])],
|
filter=[Namespace("user_id", ["test_user"], [])],
|
||||||
return_full_datapoint=True
|
return_full_datapoint=True
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ class TestWeaviateDB(unittest.TestCase):
|
|||||||
self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
|
self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
|
||||||
mock_hybrid.return_value = mock_response
|
mock_hybrid.return_value = mock_response
|
||||||
|
|
||||||
query_vector = [0.1] * 1536
|
vectors = [[0.1] * 1536]
|
||||||
results = self.weaviate_db.search(query=query_vector, limit=5)
|
results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
|
||||||
|
|
||||||
mock_hybrid.assert_called_once()
|
mock_hybrid.assert_called_once()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user