Support for hybrid search in Azure AI vector store (#2408)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Dev Khant
Date: 2025-03-20 22:57:00 +05:30
Committed by: GitHub
Parent: 8b9a8e5825
Commit: 8e6a08aa83
24 changed files with 275 additions and 294 deletions

View File

@@ -50,6 +50,24 @@ config = {
}
```
## Using Hybrid Search
Hybrid search combines full-text (keyword) matching with vector similarity in a single query, which can improve relevance for natural-language queries:
```python
config = {
"vector_store": {
"provider": "azure_ai_search",
"config": {
"service_name": "ai-search-test",
"api_key": "*****",
"collection_name": "mem0",
"embedding_model_dims": 1536,
"hybrid_search": True,
"vector_filter_mode": "postFilter"
}
}
}
```
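With this configuration, searches issued through mem0 run as hybrid queries: the query text drives Azure AI Search's keyword matching over the stored payload while its embedding drives the vector match. A minimal sketch of end-to-end usage (assuming the standard `Memory` API and an embedder configured via environment variables such as `OPENAI_API_KEY`):

```python
from mem0 import Memory

# `config` is the hybrid-search configuration shown above.
m = Memory.from_config(config)
m.add("I prefer window seats on long flights", user_id="alice")

# The query text is used for keyword matching and is also embedded
# for the vector half of the hybrid query.
results = m.search("seating preferences", user_id="alice")
```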
## Configuration Parameters
| Parameter | Description | Default Value | Options |
@@ -60,6 +78,8 @@ config = {
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
| `vector_filter_mode` | When filters are applied relative to the vector query | `preFilter` | `postFilter`, `preFilter` |
| `hybrid_search` | Enable hybrid (keyword + vector) search | `False` | `True`, `False` |
## Notes on Configuration Options
@@ -68,6 +88,10 @@ config = {
- `scalar`: Scalar quantization with reasonable balance of speed and accuracy
- `binary`: Binary quantization for maximum compression with some accuracy trade-off
- **vector_filter_mode**:
- `preFilter`: Applies filters before vector search (faster)
- `postFilter`: Applies filters after vector search (may provide better relevance)
- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
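For reference, the filters passed to `search` become an OData filter expression over these fields. A hypothetical standalone version of the translation (mirroring the `_build_filter_expression` helper changed later in this commit; the exact quoting of string values shown here is an assumption):

```python
def build_filter_expression(filters: dict) -> str:
    # Each payload field becomes an OData equality clause, joined with "and".
    filter_conditions = [f"{key} eq '{value}'" for key, value in filters.items()]
    return " and ".join(filter_conditions)

print(build_filter_expression({"user_id": "alice", "run_id": "run-1"}))
# -> user_id eq 'alice' and run_id eq 'run-1'
```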

View File

@@ -8,21 +8,26 @@ class AzureAISearchConfig(BaseModel):
    api_key: str = Field(None, description="API key for the Azure AI Search service")
    embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
    compression_type: Optional[str] = Field(
        None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
    )
    use_float16: bool = Field(
        False,
        description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
    )
    hybrid_search: bool = Field(
        False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
    )
    vector_filter_mode: Optional[str] = Field(
        "preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
    )
    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        # Check for use_compression to provide a helpful error
        if "use_compression" in extra_fields:
            raise ValueError(
@@ -30,13 +35,13 @@ class AzureAISearchConfig(BaseModel):
"Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' " "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
"or 'compression_type=None' instead of 'use_compression=False'." "or 'compression_type=None' instead of 'use_compression=False'."
) )
if extra_fields: if extra_fields:
raise ValueError( raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Please input only the following fields: {', '.join(allowed_fields)}" f"Please input only the following fields: {', '.join(allowed_fields)}"
) )
# Validate compression_type values # Validate compression_type values
if "compression_type" in values and values["compression_type"] is not None: if "compression_type" in values and values["compression_type"] is not None:
valid_types = ["scalar", "binary"] valid_types = ["scalar", "binary"]
@@ -45,9 +50,9 @@ class AzureAISearchConfig(BaseModel):
f"Invalid compression_type: {values['compression_type']}. " f"Invalid compression_type: {values['compression_type']}. "
f"Must be one of: {', '.join(valid_types)}, or None" f"Must be one of: {', '.join(valid_types)}, or None"
) )
return values return values
model_config = { model_config = {
"arbitrary_types_allowed": True, "arbitrary_types_allowed": True,
} }

View File

@@ -17,8 +17,7 @@ class ElasticsearchConfig(BaseModel):
    use_ssl: bool = Field(True, description="Use SSL for connection")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
    custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
        None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
    )
    @model_validator(mode="before")

View File

@@ -14,9 +14,7 @@ class GoogleMatchingEngineConfig(BaseModel):
    credentials_path: Optional[str] = Field(None, description="Path to service account credentials file")
    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")
    model_config = {"extra": "forbid"}
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -26,4 +24,4 @@ class GoogleMatchingEngineConfig(BaseModel):
    def model_post_init(self, _context) -> None:
        """Set collection_name to index_id if not provided"""
        if self.collection_name is None:
            self.collection_name = self.index_id

View File

@@ -71,13 +71,14 @@ class Memory(MemoryBase):
if "vector_store" not in config_dict and "embedder" in config_dict: if "vector_store" not in config_dict and "embedder" in config_dict:
config_dict["vector_store"] = {} config_dict["vector_store"] = {}
config_dict["vector_store"]["config"] = {} config_dict["vector_store"]["config"] = {}
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"]["embedding_dims"] config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
"embedding_dims"
]
try: try:
return config_dict return config_dict
except ValidationError as e: except ValidationError as e:
logger.error(f"Configuration validation error: {e}") logger.error(f"Configuration validation error: {e}")
raise raise
    def add(
        self,
@@ -204,7 +205,8 @@ class Memory(MemoryBase):
            messages_embeddings = self.embedding_model.embed(new_mem, "add")
            new_message_embeddings[new_mem] = messages_embeddings
            existing_memories = self.vector_store.search(
                query=new_mem,
                vectors=messages_embeddings,
                limit=5,
                filters=filters,
            )
@@ -222,7 +224,9 @@ class Memory(MemoryBase):
            temp_uuid_mapping[str(idx)] = item["id"]
            retrieved_old_memory[idx]["id"] = str(idx)
        function_calling_prompt = get_update_memory_messages(
            retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt
        )
        try:
            new_memories_with_actions = self.llm.generate_response(
@@ -479,7 +483,7 @@ class Memory(MemoryBase):
    def _search_vector_store(self, query, filters, limit):
        embeddings = self.embedding_model.embed(query, "search")
        memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
        excluded_keys = {
            "user_id",

View File

@@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase):
        collection_name,
        api_key,
        embedding_model_dims,
        compression_type: Optional[str] = None,
        use_float16: bool = False,
        hybrid_search: bool = False,
        vector_filter_mode: Optional[str] = None,
    ):
        """
        Initialize the Azure AI Search vector store.
@@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase):
                Allowed values are None (no quantization), "scalar", or "binary".
            use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
                (Note: This flag is preserved from the initial implementation per feedback.)
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
        """
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        # If compression_type is None, treat it as "none".
        self.compression_type = (compression_type or "none").lower()
        self.use_float16 = use_float16
        self.hybrid_search = hybrid_search
        self.vector_filter_mode = vector_filter_mode
        self.search_client = SearchClient(
            endpoint=f"https://{service_name}.search.windows.net",
@@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase):
            )
        ]
        # If no compression is desired, compression_configurations remains empty.
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
@@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase):
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]
        vector_search = VectorSearch(
@@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase):
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                    compression_name=compression_name if self.compression_type != "none" else None,
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
@@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase):
""" """
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [ documents = [
self._generate_document(vector, payload, id) self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
for id, vector, payload in zip(ids, vectors, payloads)
] ]
response = self.search_client.upload_documents(documents) response = self.search_client.upload_documents(documents)
for doc in response: for doc in response:
@@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase):
        filter_expression = " and ".join(filter_conditions)
        return filter_expression
    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.
        Args:
            query (str): Query text.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase):
        if filters:
            filter_expression = self._build_filter_expression(filters)
        vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
        if self.hybrid_search:
            search_results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
                search_fields=["payload"],
            )
        else:
            search_results = self.search_client.search(
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
            )
        results = []
        for result in search_results:
            payload = json.loads(result["payload"])
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results
    def delete(self, vector_id):
@@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase):
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
    def list_cols(self) -> List[str]:
        """
@@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase):
        if filters:
            filter_expression = self._build_filter_expression(filters)
        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
        results = []
        for result in search_results:
            payload = json.loads(result["payload"])
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return [results]
    def __del__(self):

View File

@@ -13,7 +13,7 @@ class VectorStoreBase(ABC):
        pass
    @abstractmethod
    def search(self, query, vectors, limit=5, filters=None):
        """Search for similar vectors."""
        pass

View File

@@ -127,19 +127,22 @@ class ChromaDB(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: def search(
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
""" """
Search for similar vectors. Search for similar vectors.
Args: Args:
query (List[list]): Query vector. query (str): Query.
vectors (List[list]): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5. limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns: Returns:
List[OutputData]: Search results. List[OutputData]: Search results.
""" """
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit) results = self.collection.query(query_embeddings=vectors, where=filters, n_results=limit)
final_results = self._parse_output(results) final_results = self._parse_output(results)
return final_results return final_results

View File

@@ -45,7 +45,7 @@ class ElasticsearchDB(VectorStoreBase):
        # Create index only if auto_create_index is True
        if config.auto_create_index:
            self.create_index()
        if config.custom_search_query:
            self.custom_search_query = config.custom_search_query
        else:
@@ -121,16 +121,20 @@ class ElasticsearchDB(VectorStoreBase):
            )
        return results
    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
""" """
Search with two options: Search with two options:
1. Use custom search query if provided 1. Use custom search query if provided
2. Use KNN search on vectors with pre-filtering if no custom search query is provided 2. Use KNN search on vectors with pre-filtering if no custom search query is provided
""" """
if self.custom_search_query: if self.custom_search_query:
search_query = self.custom_search_query(query, limit, filters) search_query = self.custom_search_query(vectors, limit, filters)
else: else:
search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}} search_query = {
"knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2}
}
if filters: if filters:
filter_conditions = [] filter_conditions = []
for key, value in filters.items(): for key, value in filters.items():

View File

@@ -134,12 +134,13 @@ class MilvusDB(VectorStoreBase):
        return memory
    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.
        Args:
            query (str): Query text.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -149,7 +150,7 @@ class MilvusDB(VectorStoreBase):
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[vectors],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],

View File

@@ -28,10 +28,12 @@ class OpenSearchDB(VectorStoreBase):
        # Initialize OpenSearch client
        self.client = OpenSearch(
            hosts=[{"host": config.host, "port": config.port or 9200}],
            http_auth=config.http_auth
            if config.http_auth
            else ((config.user, config.password) if (config.user and config.password) else None),
            use_ssl=config.use_ssl,
            verify_certs=config.verify_certs,
            connection_class=RequestsHttpConnection,
        )
        self.collection_name = config.collection_name
@@ -115,14 +117,16 @@ class OpenSearchDB(VectorStoreBase):
            results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
        return results
    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
        search_query = {
            "size": limit,
            "query": {
                "knn": {
                    "vector": {
                        "vector": vectors,
                        "k": limit,
                    }
                }

View File

@@ -120,12 +120,13 @@ class PGVector(VectorStoreBase):
        )
        self.conn.commit()
    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.
        Args:
            query (str): Query text.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -150,7 +151,7 @@ class PGVector(VectorStoreBase):
            ORDER BY distance
            LIMIT %s
            """,
            (vectors, *filter_params, limit),
        )
        results = self.cur.fetchall()

View File

@@ -127,12 +127,13 @@ class Qdrant(VectorStoreBase):
            conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
        return Filter(must=conditions) if conditions else None
    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.
        Args:
            query (str): Query text.
            vectors (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
@@ -142,7 +143,7 @@ class Qdrant(VectorStoreBase):
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.query_points(
            collection_name=self.collection_name,
            query=vectors,
            query_filter=query_filter,
            limit=limit,
        )

View File

@@ -101,12 +101,12 @@ class RedisDB(VectorStoreBase):
            data.append(entry)
        self.index.load(data, id_field="memory_id")
    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        v = VectorQuery(
            vector=np.array(vectors, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,

View File

@@ -112,16 +112,18 @@ class Supabase(VectorStoreBase):
            payloads = [{} for _ in vectors]
        records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
        self.collection.upsert(records)
    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
    ) -> List[OutputData]:
""" """
Search for similar vectors. Search for similar vectors.
Args: Args:
query (List[float]): Query vector query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5. limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None. filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -129,11 +131,9 @@ class Supabase(VectorStoreBase):
            List[OutputData]: Search results
        """
        filters = self._preprocess_filters(filters)
        results = self.collection.query(
            data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
        )
        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]

View File

@@ -32,19 +32,19 @@ class GoogleMatchingEngine(VectorStoreBase):
    def __init__(self, **kwargs):
        """Initialize Google Matching Engine client."""
        logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
        # If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
        if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
            kwargs["deployment_index_id"] = kwargs["collection_name"]
            logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
        elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
            kwargs["collection_name"] = kwargs["deployment_index_id"]
            logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
        try:
            config = GoogleMatchingEngineConfig(**kwargs)
            logger.debug("Config created: %s", config.model_dump())
            logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
        except Exception as e:
            logger.error("Failed to validate config: %s", str(e))
            raise
@@ -57,41 +57,37 @@ class GoogleMatchingEngine(VectorStoreBase):
        self.deployment_index_id = config.deployment_index_id  # The deployment-specific ID
        self.collection_name = config.collection_name
        self.vector_search_api_endpoint = config.vector_search_api_endpoint
        logger.debug("Using project=%s, location=%s", self.project_id, self.region)
        # Initialize Vertex AI with credentials if provided
        init_args = {
            "project": self.project_id,
            "location": self.region,
        }
        if hasattr(config, "credentials_path") and config.credentials_path:
            logger.debug("Using credentials from: %s", config.credentials_path)
            credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
            init_args["credentials"] = credentials
        try:
            aiplatform.init(**init_args)
            logger.debug("Vertex AI initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Vertex AI: %s", str(e))
            raise
        try:
            # Format the index path properly using the configured index_id
            index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
            logger.debug("Initializing index with path: %s", index_path)
            self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
            logger.debug("Index initialized successfully")
            # Format the endpoint name properly
            endpoint_name = self.endpoint_id
            logger.debug("Initializing endpoint with name: %s", endpoint_name)
            self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
            logger.debug("Endpoint initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Matching Engine components: %s", str(e))
@@ -119,47 +115,36 @@ class GoogleMatchingEngine(VectorStoreBase):
    def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
        """Create a restriction object for the Matching Engine index.
        Args:
            key: The namespace/key for the restriction
            value: The value to restrict on
        Returns:
            Restriction object for the index
        """
        str_value = str(value) if value is not None else ""
        return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
    def _create_datapoint(
        self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
    ) -> aiplatform_v1.types.index.IndexDatapoint:
        """Create a datapoint object for the Matching Engine index.
        Args:
            vector_id: The ID for the datapoint
            vector: The vector to store
            payload: Optional metadata to store with the vector
        Returns:
            IndexDatapoint object
        """
        restrictions = []
        if payload:
            restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
        return aiplatform_v1.types.index.IndexDatapoint(
            datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
        )
    def insert(
@@ -169,41 +154,41 @@ class GoogleMatchingEngine(VectorStoreBase):
        ids: Optional[List[str]] = None,
    ) -> None:
        """Insert vectors into the Matching Engine index.
        Args:
            vectors: List of vectors to insert
            payloads: Optional list of metadata dictionaries
            ids: Optional list of IDs for the vectors
        Raises:
            ValueError: If vectors is empty or lengths don't match
            GoogleAPIError: If the API call fails
        """
        if not vectors:
            raise ValueError("No vectors provided for insertion")
        if payloads and len(payloads) != len(vectors):
            raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
        if ids and len(ids) != len(vectors):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
        logger.debug("Starting insert of %d vectors", len(vectors))
        try:
            datapoints = [
                self._create_datapoint(
                    vector_id=ids[i] if ids else str(uuid.uuid4()),
                    vector=vector,
                    payload=payloads[i] if payloads and i < len(payloads) else None,
                )
                for i, vector in enumerate(vectors)
            ]
            logger.debug("Created %d datapoints", len(datapoints))
            self.index.upsert_datapoints(datapoints=datapoints)
            logger.debug("Successfully inserted datapoints")
        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("Failed to insert vectors: %s", str(e))
            raise
@@ -212,21 +197,22 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc()) logger.error("Stack trace: %s", traceback.format_exc())
raise raise
def search(
def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
""" """
Search for similar vectors. Search for similar vectors.
Args: Args:
query (List[float]): Query vector. query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5. limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns: Returns:
List[OutputData]: Search results (unwrapped) List[OutputData]: Search results (unwrapped)
""" """
logger.debug("Starting search") logger.debug("Starting search")
logger.debug("Query type: %s, length: %d", type(query), len(query))
logger.debug("Limit: %d, Filters: %s", limit, filters) logger.debug("Limit: %d, Filters: %s", limit, filters)
try: try:
filter_namespaces = [] filter_namespaces = []
if filters: if filters:
@@ -235,53 +221,42 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value)) logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
if isinstance(value, (str, int, float)): if isinstance(value, (str, int, float)):
logger.debug("Adding simple filter for %s", key) logger.debug("Adding simple filter for %s", key)
filter_namespaces.append( filter_namespaces.append(Namespace(key, [str(value)], []))
Namespace(key, [str(value)], [])
)
elif isinstance(value, dict): elif isinstance(value, dict):
logger.debug("Adding complex filter for %s", key) logger.debug("Adding complex filter for %s", key)
includes = value.get('include', []) includes = value.get("include", [])
excludes = value.get('exclude', []) excludes = value.get("exclude", [])
filter_namespaces.append( filter_namespaces.append(Namespace(key, includes, excludes))
Namespace(key, includes, excludes)
)
logger.debug("Final filter_namespaces: %s", filter_namespaces) logger.debug("Final filter_namespaces: %s", filter_namespaces)
response = self.index_endpoint.find_neighbors( response = self.index_endpoint.find_neighbors(
deployed_index_id=self.deployment_index_id, deployed_index_id=self.deployment_index_id,
queries=[query], queries=[vectors],
num_neighbors=limit, num_neighbors=limit,
filter=filter_namespaces if filter_namespaces else None, filter=filter_namespaces if filter_namespaces else None,
return_full_datapoint=True return_full_datapoint=True,
) )
if not response or len(response) == 0 or len(response[0]) == 0: if not response or len(response) == 0 or len(response[0]) == 0:
logger.debug("No results found") logger.debug("No results found")
return [] return []
results = [] results = []
for neighbor in response[0]: for neighbor in response[0]:
logger.debug("Processing neighbor - id: %s, distance: %s", logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
neighbor.id, neighbor.distance)
payload = {} payload = {}
if hasattr(neighbor, 'restricts'): if hasattr(neighbor, "restricts"):
logger.debug("Processing restricts") logger.debug("Processing restricts")
for restrict in neighbor.restricts: for restrict in neighbor.restricts:
if (hasattr(restrict, 'name') and if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
hasattr(restrict, 'allow_tokens') and
restrict.allow_tokens):
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0]) logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
payload[restrict.name] = restrict.allow_tokens[0] payload[restrict.name] = restrict.allow_tokens[0]
output_data = OutputData( output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
id=neighbor.id,
score=neighbor.distance,
payload=payload
)
results.append(output_data) results.append(output_data)
logger.debug("Returning %d results", len(results)) logger.debug("Returning %d results", len(results))
return results return results
@@ -291,7 +266,6 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc()) logger.error("Stack trace: %s", traceback.format_exc())
raise raise
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool: def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
""" """
Delete vectors from the Matching Engine index. Delete vectors from the Matching Engine index.
@@ -326,14 +300,13 @@ class GoogleMatchingEngine(VectorStoreBase):
        except google.api_core.exceptions.InvalidArgument as e:
            logger.error("Invalid argument: %s", str(e))
            return False
        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            return False
    def update(
        self,
        vector_id: str,
@@ -341,42 +314,40 @@ class GoogleMatchingEngine(VectorStoreBase):
        payload: Optional[Dict] = None,
    ) -> bool:
        """Update a vector and its payload.
        Args:
            vector_id: ID of the vector to update
            vector: Optional new vector values
            payload: Optional new metadata payload
        Returns:
            bool: True if update was successful
        Raises:
            ValueError: If neither vector nor payload is provided
            GoogleAPIError: If the API call fails
        """
        logger.debug("Starting update for vector_id: %s", vector_id)
        if vector is None and payload is None:
            raise ValueError("Either vector or payload must be provided for update")
        # First check if the vector exists
        try:
            existing = self.get(vector_id)
            if existing is None:
                logger.error("Vector ID not found: %s", vector_id)
                return False
            datapoint = self._create_datapoint(
                vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
            )
            logger.debug("Upserting datapoint: %s", datapoint)
            self.index.upsert_datapoints(datapoints=[datapoint])
            logger.debug("Update completed successfully")
            return True
        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("API error during update: %s", str(e))
            return False
@@ -385,7 +356,6 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc()) logger.error("Stack trace: %s", traceback.format_exc())
raise raise
def get(self, vector_id: str) -> Optional[OutputData]: def get(self, vector_id: str) -> Optional[OutputData]:
""" """
Retrieve a vector by ID. Retrieve a vector by ID.
@@ -395,24 +365,17 @@ class GoogleMatchingEngine(VectorStoreBase):
            Optional[OutputData]: Retrieved vector or None if not found.
        """
        logger.debug("Starting get for vector_id: %s", vector_id)
        try:
            if not self.vector_search_api_endpoint:
                raise ValueError("vector_search_api_endpoint is required for get operation")
            vector_search_client = aiplatform_v1.MatchServiceClient(
                client_options={"api_endpoint": self.vector_search_api_endpoint},
            )
            datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
            query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
            request = aiplatform_v1.FindNeighborsRequest(
                index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
                deployed_index_id=self.deployment_index_id,
@@ -423,41 +386,36 @@ class GoogleMatchingEngine(VectorStoreBase):
            try:
                response = vector_search_client.find_neighbors(request)
                logger.debug("Got response")
                if response and response.nearest_neighbors:
                    nearest = response.nearest_neighbors[0]
                    if nearest.neighbors:
                        neighbor = nearest.neighbors[0]
                        payload = {}
                        if hasattr(neighbor.datapoint, "restricts"):
                            for restrict in neighbor.datapoint.restricts:
                                if restrict.allow_list:
                                    payload[restrict.namespace] = restrict.allow_list[0]
                        return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
                logger.debug("No results found")
                return None
            except google.api_core.exceptions.NotFound:
                logger.debug("Datapoint not found")
                return None
            except google.api_core.exceptions.PermissionDenied as e:
                logger.error("Permission denied: %s", str(e))
                return None
        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise
    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).
@@ -466,7 +424,6 @@ class GoogleMatchingEngine(VectorStoreBase):
""" """
return [self.deployment_index_id] return [self.deployment_index_id]
def delete_col(self): def delete_col(self):
""" """
Delete a collection (index). Delete a collection (index).
@@ -475,7 +432,6 @@ class GoogleMatchingEngine(VectorStoreBase):
        logger.warning("Delete collection operation is not supported for Google Matching Engine")
        pass
    def col_info(self) -> Dict:
        """
        Get information about a collection (index).
@@ -486,17 +442,16 @@ class GoogleMatchingEngine(VectorStoreBase):
"index_id": self.index_id, "index_id": self.index_id,
"endpoint_id": self.endpoint_id, "endpoint_id": self.endpoint_id,
"project_id": self.project_id, "project_id": self.project_id,
"region": self.region "region": self.region,
} }
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]: def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
"""List vectors matching the given filters. """List vectors matching the given filters.
Args: Args:
filters: Optional filters to apply filters: Optional filters to apply
limit: Optional maximum number of results to return limit: Optional maximum number of results to return
Returns: Returns:
List[List[OutputData]]: List of matching vectors wrapped in an extra array List[List[OutputData]]: List of matching vectors wrapped in an extra array
to match the interface to match the interface
@@ -504,36 +459,31 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Starting list operation") logger.debug("Starting list operation")
logger.debug("Filters: %s", filters) logger.debug("Filters: %s", filters)
logger.debug("Limit: %s", limit) logger.debug("Limit: %s", limit)
try: try:
# Use a zero vector for the search # Use a zero vector for the search
dimension = 768 # This should be configurable based on the model dimension = 768 # This should be configurable based on the model
zero_vector = [0.0] * dimension zero_vector = [0.0] * dimension
# Use a large limit if none specified # Use a large limit if none specified
search_limit = limit if limit is not None else 10000 search_limit = limit if limit is not None else 10000
results = self.search( results = self.search(query=zero_vector, limit=search_limit, filters=filters)
query=zero_vector,
limit=search_limit,
filters=filters
)
logger.debug("Found %d results", len(results)) logger.debug("Found %d results", len(results))
return [results] # Wrap in extra array to match interface return [results] # Wrap in extra array to match interface
except Exception as e: except Exception as e:
logger.error("Error in list operation: %s", str(e)) logger.error("Error in list operation: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc()) logger.error("Stack trace: %s", traceback.format_exc())
raise raise
def create_col(self, name=None, vector_size=None, distance=None): def create_col(self, name=None, vector_size=None, distance=None):
""" """
Create a new collection. For Google Matching Engine, collections (indexes) Create a new collection. For Google Matching Engine, collections (indexes)
are created through the Google Cloud Console or API separately. are created through the Google Cloud Console or API separately.
This method is a no-op since indexes are pre-created. This method is a no-op since indexes are pre-created.
Args: Args:
name: Ignored for Google Matching Engine name: Ignored for Google Matching Engine
vector_size: Ignored for Google Matching Engine vector_size: Ignored for Google Matching Engine
@@ -543,41 +493,35 @@ class GoogleMatchingEngine(VectorStoreBase):
        # This method is included only to satisfy the abstract base class
        pass
    def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
        logger.debug("Starting add operation")
        logger.debug("Text: %s", text)
        logger.debug("Metadata: %s", metadata)
        logger.debug("User ID: %s", user_id)
        try:
            # Generate a unique ID for this entry
            vector_id = str(uuid.uuid4())
            # Create the payload with all necessary fields
            payload = {
                "data": text,  # Store the text in the data field
                "user_id": user_id,
                **(metadata or {}),
            }
            # Get the embedding
            vector = self.embedder.embed_query(text)
            # Insert using the insert method
            self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
            return vector_id
        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            raise
    def add_texts(
        self,
        texts: List[str],
@@ -585,47 +529,45 @@ class GoogleMatchingEngine(VectorStoreBase):
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add texts to the vector store.
        Args:
            texts: List of texts to add
            metadatas: Optional list of metadata dicts
            ids: Optional list of IDs to use
        Returns:
            List[str]: List of IDs of the added texts
        Raises:
            ValueError: If texts is empty or lengths don't match
        """
        if not texts:
            raise ValueError("No texts provided")
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
            )
        if ids and len(ids) != len(texts):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
        logger.debug("Starting add_texts operation")
        logger.debug("Number of texts: %d", len(texts))
        logger.debug("Has metadatas: %s", metadatas is not None)
        logger.debug("Has ids: %s", ids is not None)
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        try:
            # Get embeddings
            embeddings = self.embedder.embed_documents(texts)
            # Add to store
            self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
            return ids
        except Exception as e:
            logger.error("Error in add_texts: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
@@ -657,18 +599,12 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Query: %s", query) logger.debug("Query: %s", query)
logger.debug("k: %d", k) logger.debug("k: %d", k)
logger.debug("Filter: %s", filter) logger.debug("Filter: %s", filter)
embedding = self.embedder.embed_query(query) embedding = self.embedder.embed_query(query)
results = self.search(query=embedding, limit=k, filters=filter) results = self.search(query=embedding, limit=k, filters=filter)
docs_and_scores = [ docs_and_scores = [
( (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
Document(
page_content=result.payload.get("text", ""),
metadata=result.payload
),
result.score
)
for result in results for result in results
] ]
logger.debug("Found %d results", len(docs_and_scores)) logger.debug("Found %d results", len(docs_and_scores))
@@ -684,4 +620,3 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Starting similarity search") logger.debug("Starting similarity search")
docs_and_scores = self.similarity_search_with_score(query, k, filter) docs_and_scores = self.similarity_search_with_score(query, k, filter)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
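These hunks only collapse multi-line calls; behavior is unchanged. For orientation, here is a minimal usage sketch of the wrapper methods shown above, assuming an already-constructed `GoogleMatchingEngine` instance (construction is not part of this diff, and the variable names are illustrative):

```python
# `store` is a hypothetical, pre-configured GoogleMatchingEngine instance.

# add() embeds one text, builds the payload, and inserts it under a fresh UUID.
vector_id = store.add("Alice prefers window seats", metadata={"topic": "travel"}, user_id="alice")

# add_texts() validates that metadatas/ids line up with texts before any
# embedding work happens; mismatched lengths raise ValueError.
ids = store.add_texts(
    texts=["likes sushi", "works remotely"],
    metadatas=[{"topic": "food"}, {"topic": "work"}],
)

# similarity_search_with_score() returns (Document, score) pairs built from payloads.
for doc, score in store.similarity_search_with_score("food preferences", 2):
    print(score, doc.page_content)
```

One inconsistency visible in the diff: `add()` stores the text under the `"data"` payload key, while `similarity_search_with_score()` reads `payload.get("text", "")`, so documents added via `add()` come back with empty `page_content` unless their payloads also carry a `"text"` field.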

View File

@@ -154,7 +154,9 @@ class Weaviate(VectorStoreBase):
batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)
- def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""
Search for similar vectors.
"""
@@ -167,7 +169,7 @@ class Weaviate(VectorStoreBase):
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
response = collection.query.hybrid(
query="",
- vector=query,
+ vector=vectors,
limit=limit,
filters=combined_filter,
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
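This is the signature change that every test update below tracks: `search` now takes the raw text `query` plus a separate `vectors` argument. Note that in this hunk the Weaviate hybrid call still passes `query=""`, so only the vector leg of the hybrid query is exercised. A minimal caller sketch under the new signature, assuming a configured store and embedder (both hypothetical here):

```python
# Hypothetical, pre-configured objects.
text = "What does Alice like?"
embedding = embedder.embed_query(text)  # hypothetical embedder instance

# Old convention (pre-commit): store.search(query=embedding, limit=5)
# New convention: the text and the embedding travel separately.
results = store.search(query=text, vectors=embedding, limit=5)
for hit in results:
    print(hit.id, hit.score, hit.payload)
```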

View File

@@ -35,12 +35,11 @@ def test_search_vectors(chromadb_instance, mock_chromadb_client):
}
chromadb_instance.collection.query.return_value = mock_result
- query = [[0.1, 0.2, 0.3]]
+ vectors = [[0.1, 0.2, 0.3]]
- results = chromadb_instance.search(query=query, limit=2)
+ results = chromadb_instance.search(query="", vectors=vectors, limit=2)
- chromadb_instance.collection.query.assert_called_once_with(query_embeddings=query, where=None, n_results=2)
+ chromadb_instance.collection.query.assert_called_once_with(query_embeddings=vectors, where=None, n_results=2)
- print(results, type(results))
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].score == 0.1

View File

@@ -196,8 +196,8 @@ class TestElasticsearchDB(unittest.TestCase):
self.client_mock.search.return_value = mock_response
# Perform search
- query_vector = [0.1] * 1536
+ vectors = [[0.1] * 1536]
- results = self.es_db.search(query=query_vector, limit=5)
+ results = self.es_db.search(query="", vectors=vectors, limit=5)
# Verify search call
self.client_mock.search.assert_called_once()
@@ -210,7 +210,7 @@ class TestElasticsearchDB(unittest.TestCase):
# Verify KNN query structure
self.assertIn("knn", body)
self.assertEqual(body["knn"]["field"], "vector")
- self.assertEqual(body["knn"]["query_vector"], query_vector)
+ self.assertEqual(body["knn"]["query_vector"], vectors)
self.assertEqual(body["knn"]["k"], 5)
self.assertEqual(body["knn"]["num_candidates"], 10)
@@ -226,13 +226,13 @@ class TestElasticsearchDB(unittest.TestCase):
self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"}
# Perform search
- query_vector = [0.1] * 1536
+ vectors = [[0.1] * 1536]
limit = 5
filters = {"key1": "value1"}
- self.es_db.search(query=query_vector, limit=limit, filters=filters)
+ self.es_db.search(query="", vectors=vectors, limit=limit, filters=filters)
# Verify custom search query function was called
- self.es_db.custom_search_query.assert_called_once_with(query_vector, limit, filters)
+ self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)
# Verify custom search query was used
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
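The second test exercises the `custom_search_query` escape hatch: when set, it is called with the vectors, limit, and filters, and whatever dict it returns is sent verbatim as the Elasticsearch request body. A minimal sketch of such a hook, assuming the default `knn` body shape verified in the first test (the filter translation is an assumption and must match your index mapping):

```python
from typing import Dict, List, Optional

def my_custom_search_query(vectors: List[List[float]], limit: int, filters: Optional[Dict]) -> Dict:
    """Build an Elasticsearch request body; the dict is passed to client.search(body=...)."""
    body = {
        "knn": {
            "field": "vector",        # field name asserted in the test above
            "query_vector": vectors,
            "k": limit,
            "num_candidates": limit * 2,
        }
    }
    if filters:
        # Hypothetical filter translation; adjust term/keyword handling to your mapping.
        body["knn"]["filter"] = {"bool": {"must": [{"term": {k: v}} for k, v in filters.items()]}}
    return body
```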

View File

@@ -126,15 +126,15 @@ class TestOpenSearchDB(unittest.TestCase):
def test_search(self):
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}]}}
self.client_mock.search.return_value = mock_response
- query_vector = [0.1] * 1536
+ vectors = [[0.1] * 1536]
- results = self.os_db.search(query=query_vector, limit=5)
+ results = self.os_db.search(query="", vectors=vectors, limit=5)
self.client_mock.search.assert_called_once()
search_args = self.client_mock.search.call_args[1]
self.assertEqual(search_args["index"], "test_collection")
body = search_args["body"]
self.assertIn("knn", body["query"])
self.assertIn("vector", body["query"]["knn"])
- self.assertEqual(body["query"]["knn"]["vector"]["vector"], query_vector)
+ self.assertEqual(body["query"]["knn"]["vector"]["vector"], vectors)
self.assertEqual(body["query"]["knn"]["vector"]["k"], 5)
self.assertEqual(len(results), 1)
self.assertEqual(results[0].id, "id1")

View File

@@ -50,15 +50,15 @@ class TestQdrant(unittest.TestCase):
self.assertEqual(points[0].payload, payloads[0])
def test_search(self):
- query_vector = [0.1, 0.2]
+ vectors = [[0.1, 0.2]]
mock_point = MagicMock(id=str(uuid.uuid4()), score=0.95, payload={"key": "value"})
self.client_mock.query_points.return_value = MagicMock(points=[mock_point])
- results = self.qdrant.search(query=query_vector, limit=1)
+ results = self.qdrant.search(query="", vectors=vectors, limit=1)
self.client_mock.query_points.assert_called_once_with(
collection_name="test_collection",
- query=query_vector,
+ query=vectors,
query_filter=None,
limit=1,
)

View File

@@ -77,12 +77,12 @@ def test_search_vectors(supabase_instance, mock_collection):
]
mock_collection.query.return_value = mock_results
- query = [0.1, 0.2, 0.3]
+ vectors = [[0.1, 0.2, 0.3]]
filters = {"category": "test"}
- results = supabase_instance.search(query=query, limit=2, filters=filters)
+ results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)
mock_collection.query.assert_called_once_with(
- data=query,
+ data=vectors,
limit=2,
filters={"category": {"$eq": "test"}},
include_metadata=True,

View File

@@ -73,12 +73,12 @@ def test_insert_vectors(vector_store, mock_vertex_ai):
def test_search_vectors(vector_store, mock_vertex_ai):
"""Test searching vectors with filters"""
- query = [0.1, 0.2, 0.3]
+ vectors = [[0.1, 0.2, 0.3]]
filters = {"user_id": "test_user"}
mock_datapoint = Mock()
mock_datapoint.datapoint_id = "test-id"
- mock_datapoint.feature_vector = query
+ mock_datapoint.feature_vector = vectors
mock_restrict = Mock()
mock_restrict.namespace = "user_id"
@@ -96,11 +96,11 @@ def test_search_vectors(vector_store, mock_vertex_ai):
mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
- results = vector_store.search(query=query, filters=filters, limit=1)
+ results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)
mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
deployed_index_id=vector_store.deployment_index_id,
- queries=[query],
+ queries=[vectors],
num_neighbors=1,
filter=[Namespace("user_id", ["test_user"], [])],
return_full_datapoint=True

View File

@@ -147,8 +147,8 @@ class TestWeaviateDB(unittest.TestCase):
self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
mock_hybrid.return_value = mock_response
- query_vector = [0.1] * 1536
+ vectors = [[0.1] * 1536]
- results = self.weaviate_db.search(query=query_vector, limit=5)
+ results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
mock_hybrid.assert_called_once()
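Every test in this commit applies the same mechanical migration, so downstream callers of the affected vector stores need to as well. A before/after sketch (`store` and the vector literal are illustrative):

```python
embedding = [0.1, 0.2, 0.3]  # illustrative embedding

# Before this commit, the embedding itself was passed as `query`:
# results = store.search(query=embedding, limit=5)

# After, `query` carries the text (empty string for pure vector search)
# and the embedding moves to the new `vectors` argument:
results = store.search(query="", vectors=embedding, limit=5)
```

One wrinkle worth noting: the updated tests wrap the embedding in an extra list (e.g. `[[0.1] * 1536]`) even though the new Weaviate annotation reads `vectors: List[float]`, so check the concrete store's signature before migrating callers.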