Support for hybrid search in Azure AI vector store (#2408)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Dev Khant
Date: 2025-03-20 22:57:00 +05:30
Committed by: GitHub
Parent: 8b9a8e5825
Commit: 8e6a08aa83
24 changed files with 275 additions and 294 deletions

View File

@@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase):
collection_name,
api_key,
embedding_model_dims,
compression_type: Optional[str] = None,
use_float16: bool = False,
+ hybrid_search: bool = False,
+ vector_filter_mode: Optional[str] = None,
):
"""
Initialize the Azure AI Search vector store.
@@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase):
Allowed values are None (no quantization), "scalar", or "binary".
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
(Note: This flag is preserved from the initial implementation per feedback.)
+ hybrid_search (bool): Whether to use hybrid search. Default is False.
+ vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
"""
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
# If compression_type is None, treat it as "none".
self.compression_type = (compression_type or "none").lower()
self.use_float16 = use_float16
+ self.hybrid_search = hybrid_search
+ self.vector_filter_mode = vector_filter_mode
self.search_client = SearchClient(
endpoint=f"https://{service_name}.search.windows.net",
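The two new constructor options above are the core of this change. A minimal sketch of enabling them (values are illustrative; service_name is assumed to be a parameter declared above this hunk, and "postFilter" is Azure AI Search's other documented filter mode):

```python
# Hypothetical construction; only hybrid_search and vector_filter_mode are new.
store = AzureAISearch(
    service_name="my-search-service",  # assumed; used in the endpoint URL above
    collection_name="memories",
    api_key="my-api-key",
    embedding_model_dims=1536,
    compression_type=None,             # no quantization
    use_float16=False,
    hybrid_search=True,                # combine full-text and vector ranking
    vector_filter_mode="preFilter",    # or "postFilter"
)
```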
@@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase):
)
]
# If no compression is desired, compression_configurations remains empty.
fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
@@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase):
SearchField(
name="vector",
type=vector_type,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
- SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
+ SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
]
vector_search = VectorSearch(
@@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase):
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
- compression_name=compression_name if self.compression_type != "none" else None
+ compression_name=compression_name if self.compression_type != "none" else None,
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
@@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase):
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [
- self._generate_document(vector, payload, id)
- for id, vector, payload in zip(ids, vectors, payloads)
+ self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
]
response = self.search_client.upload_documents(documents)
for doc in response:
@@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase):
filter_expression = " and ".join(filter_conditions)
return filter_expression
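For reference, _build_filter_expression joins one condition per key with " and ". Assuming the usual OData equality form for Azure AI Search (the per-key condition builder sits above this hunk and is not shown), a filters dict would translate roughly like this:

```python
# Assumed translation; exact quoting depends on the omitted condition builder.
filters = {"user_id": "alice", "agent_id": "travel-bot"}
# -> "user_id eq 'alice' and agent_id eq 'travel-bot'"
```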
- def search(self, query, limit=5, filters=None):
+ def search(self, query, vectors, limit=5, filters=None):
"""
Search for similar vectors.
Args:
- query (List[float]): Query vector.
+ query (str): Query.
+ vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
- vector_query = VectorizedQuery(
-     vector=query, k_nearest_neighbors=limit, fields="vector"
- )
- search_results = self.search_client.search(
-     vector_queries=[vector_query],
-     filter=filter_expression,
-     top=limit
- )
+ vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
+ if self.hybrid_search:
+     search_results = self.search_client.search(
+         search_text=query,
+         vector_queries=[vector_query],
+         filter=filter_expression,
+         top=limit,
+         vector_filter_mode=self.vector_filter_mode,
+         search_fields=["payload"],
+     )
+ else:
+     search_results = self.search_client.search(
+         vector_queries=[vector_query],
+         filter=filter_expression,
+         top=limit,
+         vector_filter_mode=self.vector_filter_mode,
+     )
results = []
for result in search_results:
payload = json.loads(result["payload"])
- results.append(
-     OutputData(
-         id=result["id"], score=result["@search.score"], payload=payload
-     )
- )
+ results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results
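Callers now pass both the raw query text and its embedding; the text participates in ranking only when hybrid_search is enabled. A hedged sketch of a call site (the embedder helper is assumed, and its output dimensionality must match embedding_model_dims):

```python
query = "what languages does the user prefer?"
embedding = embedder.embed(query)  # assumed embedding helper
hits = store.search(query, embedding, limit=5, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload)  # OutputData fields used above
```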
def delete(self, vector_id):
@@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase):
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
- return OutputData(
-     id=result["id"], score=None, payload=json.loads(result["payload"])
- )
+ return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
def list_cols(self) -> List[str]:
"""
@@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
- search_results = self.search_client.search(
-     search_text="*", filter=filter_expression, top=limit
- )
+ search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
results = []
for result in search_results:
payload = json.loads(result["payload"])
- results.append(
-     OutputData(
-         id=result["id"], score=result["@search.score"], payload=payload
-     )
- )
+ results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results]
def __del__(self):

View File

@@ -13,7 +13,7 @@ class VectorStoreBase(ABC):
pass
@abstractmethod
- def search(self, query, limit=5, filters=None):
+ def search(self, query, vectors, limit=5, filters=None):
"""Search for similar vectors."""
pass
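Every backend in this commit is updated to that two-argument contract. A minimal conforming stub, for illustration only (the other abstract methods of VectorStoreBase are omitted here):

```python
class MyStore(VectorStoreBase):
    def search(self, query, vectors, limit=5, filters=None):
        # query: raw text, used only by keyword/hybrid-capable backends.
        # vectors: the query embedding, used for similarity ranking.
        raise NotImplementedError
```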

View File

@@ -127,19 +127,22 @@ class ChromaDB(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
- def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""
Search for similar vectors.
Args:
- query (List[list]): Query vector.
+ query (str): Query.
+ vectors (List[list]): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results.
"""
- results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
+ results = self.collection.query(query_embeddings=vectors, where=filters, n_results=limit)
final_results = self._parse_output(results)
return final_results

View File

@@ -45,7 +45,7 @@ class ElasticsearchDB(VectorStoreBase):
# Create index only if auto_create_index is True
if config.auto_create_index:
self.create_index()
if config.custom_search_query:
self.custom_search_query = config.custom_search_query
else:
@@ -121,16 +121,20 @@ class ElasticsearchDB(VectorStoreBase):
)
return results
- def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""
Search with two options:
1. Use custom search query if provided
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
"""
if self.custom_search_query:
- search_query = self.custom_search_query(query, limit, filters)
+ search_query = self.custom_search_query(vectors, limit, filters)
else:
- search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}}
+ search_query = {
+     "knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2}
+ }
if filters:
filter_conditions = []
for key, value in filters.items():
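Note the contract change for custom_search_query: the callable now receives the query embedding rather than the query text. A hedged sketch of a compatible callable (the name is illustrative; it is passed in via the config field shown above):

```python
def my_custom_search_query(vectors, limit, filters):
    # Mirror the default KNN body; filter handling is left out of this sketch.
    return {
        "knn": {
            "field": "vector",
            "query_vector": vectors,
            "k": limit,
            "num_candidates": limit * 2,
        }
    }
```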

View File

@@ -134,12 +134,13 @@ class MilvusDB(VectorStoreBase):
return memory
- def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
+ def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
"""
Search for similar vectors.
Args:
- query (List[float]): Query vector.
+ query (str): Query.
+ vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -149,7 +150,7 @@ class MilvusDB(VectorStoreBase):
query_filter = self._create_filter(filters) if filters else None
hits = self.client.search(
collection_name=self.collection_name,
- data=[query],
+ data=[vectors],
limit=limit,
filter=query_filter,
output_fields=["*"],

View File

@@ -28,10 +28,12 @@ class OpenSearchDB(VectorStoreBase):
# Initialize OpenSearch client
self.client = OpenSearch(
hosts=[{"host": config.host, "port": config.port or 9200}],
- http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
+ http_auth=config.http_auth
+ if config.http_auth
+ else ((config.user, config.password) if (config.user and config.password) else None),
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
- connection_class=RequestsHttpConnection
+ connection_class=RequestsHttpConnection,
)
self.collection_name = config.collection_name
@@ -115,14 +117,16 @@ class OpenSearchDB(VectorStoreBase):
results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
return results
- def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
search_query = {
"size": limit,
"query": {
"knn": {
"vector": {
"vector": query,
"vector": vectors,
"k": limit,
}
}

View File

@@ -120,12 +120,13 @@ class PGVector(VectorStoreBase):
)
self.conn.commit()
- def search(self, query, limit=5, filters=None):
+ def search(self, query, vectors, limit=5, filters=None):
"""
Search for similar vectors.
Args:
- query (List[float]): Query vector.
+ query (str): Query.
+ vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -150,7 +151,7 @@ class PGVector(VectorStoreBase):
ORDER BY distance
LIMIT %s
""",
- (query, *filter_params, limit),
+ (vectors, *filter_params, limit),
)
results = self.cur.fetchall()

View File

@@ -127,12 +127,13 @@ class Qdrant(VectorStoreBase):
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
return Filter(must=conditions) if conditions else None
- def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
+ def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
"""
Search for similar vectors.
Args:
- query (list): Query vector.
+ query (str): Query.
+ vectors (list): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
@@ -142,7 +143,7 @@ class Qdrant(VectorStoreBase):
query_filter = self._create_filter(filters) if filters else None
hits = self.client.query_points(
collection_name=self.collection_name,
- query=query,
+ query=vectors,
query_filter=query_filter,
limit=limit,
)

View File

@@ -101,12 +101,12 @@ class RedisDB(VectorStoreBase):
data.append(entry)
self.index.load(data, id_field="memory_id")
- def search(self, query: list, limit: int = 5, filters: dict = None):
+ def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
filter = reduce(lambda x, y: x & y, conditions)
v = VectorQuery(
- vector=np.array(query, dtype=np.float32).tobytes(),
+ vector=np.array(vectors, dtype=np.float32).tobytes(),
vector_field_name="embedding",
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
filter_expression=filter,

View File

@@ -112,16 +112,18 @@ class Supabase(VectorStoreBase):
payloads = [{} for _ in vectors]
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
print(records)
self.collection.upsert(records)
- def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
+ ) -> List[OutputData]:
"""
Search for similar vectors.
Args:
- query (List[float]): Query vector
+ query (str): Query.
+ vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -129,11 +131,9 @@ class Supabase(VectorStoreBase):
List[OutputData]: Search results
"""
filters = self._preprocess_filters(filters)
- print(filters)
results = self.collection.query(
- data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
+ data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
)
- print(results)
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]

View File

@@ -32,19 +32,19 @@ class GoogleMatchingEngine(VectorStoreBase):
def __init__(self, **kwargs):
"""Initialize Google Matching Engine client."""
logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
# If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
- if 'collection_name' in kwargs and 'deployment_index_id' not in kwargs:
-     kwargs['deployment_index_id'] = kwargs['collection_name']
-     logger.debug("Using collection_name as deployment_index_id: %s", kwargs['deployment_index_id'])
- elif 'deployment_index_id' in kwargs and 'collection_name' not in kwargs:
-     kwargs['collection_name'] = kwargs['deployment_index_id']
-     logger.debug("Using deployment_index_id as collection_name: %s", kwargs['collection_name'])
+ if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
+     kwargs["deployment_index_id"] = kwargs["collection_name"]
+     logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
+ elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
+     kwargs["collection_name"] = kwargs["deployment_index_id"]
+     logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
try:
config = GoogleMatchingEngineConfig(**kwargs)
logger.debug("Config created: %s", config.model_dump())
- logger.debug("Config collection_name: %s", getattr(config, 'collection_name', None))
+ logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
except Exception as e:
logger.error("Failed to validate config: %s", str(e))
raise
@@ -57,41 +57,37 @@ class GoogleMatchingEngine(VectorStoreBase):
self.deployment_index_id = config.deployment_index_id # The deployment-specific ID
self.collection_name = config.collection_name
self.vector_search_api_endpoint = config.vector_search_api_endpoint
logger.debug("Using project=%s, location=%s", self.project_id, self.region)
# Initialize Vertex AI with credentials if provided
init_args = {
"project": self.project_id,
"location": self.region,
}
- if hasattr(config, 'credentials_path') and config.credentials_path:
+ if hasattr(config, "credentials_path") and config.credentials_path:
logger.debug("Using credentials from: %s", config.credentials_path)
- credentials = service_account.Credentials.from_service_account_file(
-     config.credentials_path
- )
+ credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
init_args["credentials"] = credentials
try:
aiplatform.init(**init_args)
logger.debug("Vertex AI initialized successfully")
except Exception as e:
logger.error("Failed to initialize Vertex AI: %s", str(e))
raise
try:
# Format the index path properly using the configured index_id
index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
logger.debug("Initializing index with path: %s", index_path)
self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
logger.debug("Index initialized successfully")
# Format the endpoint name properly
endpoint_name = self.endpoint_id
logger.debug("Initializing endpoint with name: %s", endpoint_name)
- self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
-     index_endpoint_name=endpoint_name
- )
+ self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
logger.debug("Endpoint initialized successfully")
except Exception as e:
logger.error("Failed to initialize Matching Engine components: %s", str(e))
@@ -119,47 +115,36 @@ class GoogleMatchingEngine(VectorStoreBase):
def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
"""Create a restriction object for the Matching Engine index.
Args:
key: The namespace/key for the restriction
value: The value to restrict on
Returns:
Restriction object for the index
"""
str_value = str(value) if value is not None else ""
- return aiplatform_v1.types.index.IndexDatapoint.Restriction(
-     namespace=key,
-     allow_list=[str_value]
- )
+ return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
def _create_datapoint(
- self,
- vector_id: str,
- vector: List[float],
- payload: Optional[Dict] = None
+ self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
) -> aiplatform_v1.types.index.IndexDatapoint:
"""Create a datapoint object for the Matching Engine index.
Args:
vector_id: The ID for the datapoint
vector: The vector to store
payload: Optional metadata to store with the vector
Returns:
IndexDatapoint object
"""
restrictions = []
if payload:
- restrictions = [
-     self._create_restriction(key, value)
-     for key, value in payload.items()
- ]
+ restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
return aiplatform_v1.types.index.IndexDatapoint(
- datapoint_id=vector_id,
- feature_vector=vector,
- restricts=restrictions
+ datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
)
def insert(
@@ -169,41 +154,41 @@ class GoogleMatchingEngine(VectorStoreBase):
ids: Optional[List[str]] = None,
) -> None:
"""Insert vectors into the Matching Engine index.
Args:
vectors: List of vectors to insert
payloads: Optional list of metadata dictionaries
ids: Optional list of IDs for the vectors
Raises:
ValueError: If vectors is empty or lengths don't match
GoogleAPIError: If the API call fails
"""
if not vectors:
raise ValueError("No vectors provided for insertion")
if payloads and len(payloads) != len(vectors):
raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
if ids and len(ids) != len(vectors):
raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
logger.debug("Starting insert of %d vectors", len(vectors))
try:
datapoints = [
self._create_datapoint(
vector_id=ids[i] if ids else str(uuid.uuid4()),
vector=vector,
- payload=payloads[i] if payloads and i < len(payloads) else None
+ payload=payloads[i] if payloads and i < len(payloads) else None,
)
for i, vector in enumerate(vectors)
]
logger.debug("Created %d datapoints", len(datapoints))
self.index.upsert_datapoints(datapoints=datapoints)
logger.debug("Successfully inserted datapoints")
except google.api_core.exceptions.GoogleAPIError as e:
logger.error("Failed to insert vectors: %s", str(e))
raise
@@ -212,21 +197,22 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc())
raise
- def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""
Search for similar vectors.
Args:
- query (List[float]): Query vector.
+ query (str): Query.
+ vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results (unwrapped)
"""
logger.debug("Starting search")
logger.debug("Query type: %s, length: %d", type(query), len(query))
logger.debug("Limit: %d, Filters: %s", limit, filters)
try:
filter_namespaces = []
if filters:
@@ -235,53 +221,42 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
if isinstance(value, (str, int, float)):
logger.debug("Adding simple filter for %s", key)
- filter_namespaces.append(
-     Namespace(key, [str(value)], [])
- )
+ filter_namespaces.append(Namespace(key, [str(value)], []))
elif isinstance(value, dict):
logger.debug("Adding complex filter for %s", key)
- includes = value.get('include', [])
- excludes = value.get('exclude', [])
- filter_namespaces.append(
-     Namespace(key, includes, excludes)
- )
+ includes = value.get("include", [])
+ excludes = value.get("exclude", [])
+ filter_namespaces.append(Namespace(key, includes, excludes))
logger.debug("Final filter_namespaces: %s", filter_namespaces)
response = self.index_endpoint.find_neighbors(
deployed_index_id=self.deployment_index_id,
- queries=[query],
+ queries=[vectors],
num_neighbors=limit,
filter=filter_namespaces if filter_namespaces else None,
- return_full_datapoint=True
+ return_full_datapoint=True,
)
if not response or len(response) == 0 or len(response[0]) == 0:
logger.debug("No results found")
return []
results = []
for neighbor in response[0]:
- logger.debug("Processing neighbor - id: %s, distance: %s",
-     neighbor.id, neighbor.distance)
+ logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
payload = {}
- if hasattr(neighbor, 'restricts'):
+ if hasattr(neighbor, "restricts"):
logger.debug("Processing restricts")
for restrict in neighbor.restricts:
- if (hasattr(restrict, 'name') and
-     hasattr(restrict, 'allow_tokens') and
-     restrict.allow_tokens):
+ if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
payload[restrict.name] = restrict.allow_tokens[0]
- output_data = OutputData(
-     id=neighbor.id,
-     score=neighbor.distance,
-     payload=payload
- )
+ output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
results.append(output_data)
logger.debug("Returning %d results", len(results))
return results
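The loop above accepts two filter shapes per key, both mapped onto Matching Engine Namespace restrictions. An illustrative filters argument:

```python
filters = {
    "user_id": "alice",                                    # -> Namespace("user_id", ["alice"], [])
    "category": {"include": ["work"], "exclude": ["old"]}, # -> Namespace("category", ["work"], ["old"])
}
```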
@@ -291,7 +266,6 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc())
raise
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
"""
Delete vectors from the Matching Engine index.
@@ -326,14 +300,13 @@ class GoogleMatchingEngine(VectorStoreBase):
except google.api_core.exceptions.InvalidArgument as e:
logger.error("Invalid argument: %s", str(e))
return False
except Exception as e:
logger.error("Error occurred: %s", str(e))
logger.error("Error type: %s", type(e))
logger.error("Stack trace: %s", traceback.format_exc())
return False
def update(
self,
vector_id: str,
@@ -341,42 +314,40 @@ class GoogleMatchingEngine(VectorStoreBase):
payload: Optional[Dict] = None,
) -> bool:
"""Update a vector and its payload.
Args:
vector_id: ID of the vector to update
vector: Optional new vector values
payload: Optional new metadata payload
Returns:
bool: True if update was successful
Raises:
ValueError: If neither vector nor payload is provided
GoogleAPIError: If the API call fails
"""
logger.debug("Starting update for vector_id: %s", vector_id)
if vector is None and payload is None:
raise ValueError("Either vector or payload must be provided for update")
# First check if the vector exists
try:
existing = self.get(vector_id)
if existing is None:
logger.error("Vector ID not found: %s", vector_id)
return False
datapoint = self._create_datapoint(
- vector_id=vector_id,
- vector=vector if vector is not None else [],
- payload=payload
+ vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
)
logger.debug("Upserting datapoint: %s", datapoint)
self.index.upsert_datapoints(datapoints=[datapoint])
logger.debug("Update completed successfully")
return True
except google.api_core.exceptions.GoogleAPIError as e:
logger.error("API error during update: %s", str(e))
return False
@@ -385,7 +356,6 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.error("Stack trace: %s", traceback.format_exc())
raise
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
@@ -395,24 +365,17 @@ class GoogleMatchingEngine(VectorStoreBase):
Optional[OutputData]: Retrieved vector or None if not found.
"""
logger.debug("Starting get for vector_id: %s", vector_id)
try:
if not self.vector_search_api_endpoint:
raise ValueError("vector_search_api_endpoint is required for get operation")
vector_search_client = aiplatform_v1.MatchServiceClient(
- client_options={
-     "api_endpoint": self.vector_search_api_endpoint
- },
- )
- datapoint = aiplatform_v1.IndexDatapoint(
-     datapoint_id=vector_id
+ client_options={"api_endpoint": self.vector_search_api_endpoint},
)
+ datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
- query = aiplatform_v1.FindNeighborsRequest.Query(
-     datapoint=datapoint,
-     neighbor_count=1
- )
+ query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
request = aiplatform_v1.FindNeighborsRequest(
index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
deployed_index_id=self.deployment_index_id,
@@ -423,41 +386,36 @@ class GoogleMatchingEngine(VectorStoreBase):
try:
response = vector_search_client.find_neighbors(request)
logger.debug("Got response")
if response and response.nearest_neighbors:
nearest = response.nearest_neighbors[0]
if nearest.neighbors:
neighbor = nearest.neighbors[0]
payload = {}
- if hasattr(neighbor.datapoint, 'restricts'):
+ if hasattr(neighbor.datapoint, "restricts"):
for restrict in neighbor.datapoint.restricts:
if restrict.allow_list:
payload[restrict.namespace] = restrict.allow_list[0]
- return OutputData(
-     id=neighbor.datapoint.datapoint_id,
-     score=neighbor.distance,
-     payload=payload
- )
+ return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
logger.debug("No results found")
return None
except google.api_core.exceptions.NotFound:
logger.debug("Datapoint not found")
return None
except google.api_core.exceptions.PermissionDenied as e:
logger.error("Permission denied: %s", str(e))
return None
except Exception as e:
logger.error("Error occurred: %s", str(e))
logger.error("Error type: %s", type(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def list_cols(self) -> List[str]:
"""
List all collections (indexes).
@@ -466,7 +424,6 @@ class GoogleMatchingEngine(VectorStoreBase):
"""
return [self.deployment_index_id]
def delete_col(self):
"""
Delete a collection (index).
@@ -475,7 +432,6 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.warning("Delete collection operation is not supported for Google Matching Engine")
pass
def col_info(self) -> Dict:
"""
Get information about a collection (index).
@@ -486,17 +442,16 @@ class GoogleMatchingEngine(VectorStoreBase):
"index_id": self.index_id,
"endpoint_id": self.endpoint_id,
"project_id": self.project_id,
"region": self.region
"region": self.region,
}
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
"""List vectors matching the given filters.
Args:
filters: Optional filters to apply
limit: Optional maximum number of results to return
Returns:
List[List[OutputData]]: List of matching vectors wrapped in an extra array
to match the interface
@@ -504,36 +459,31 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Starting list operation")
logger.debug("Filters: %s", filters)
logger.debug("Limit: %s", limit)
try:
# Use a zero vector for the search
dimension = 768 # This should be configurable based on the model
zero_vector = [0.0] * dimension
# Use a large limit if none specified
search_limit = limit if limit is not None else 10000
- results = self.search(
-     query=zero_vector,
-     limit=search_limit,
-     filters=filters
- )
+ results = self.search(query=zero_vector, limit=search_limit, filters=filters)
logger.debug("Found %d results", len(results))
return [results] # Wrap in extra array to match interface
except Exception as e:
logger.error("Error in list operation: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection. For Google Matching Engine, collections (indexes)
are created through the Google Cloud Console or API separately.
This method is a no-op since indexes are pre-created.
Args:
name: Ignored for Google Matching Engine
vector_size: Ignored for Google Matching Engine
@@ -543,41 +493,35 @@ class GoogleMatchingEngine(VectorStoreBase):
# This method is included only to satisfy the abstract base class
pass
def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
logger.debug("Starting add operation")
logger.debug("Text: %s", text)
logger.debug("Metadata: %s", metadata)
logger.debug("User ID: %s", user_id)
try:
# Generate a unique ID for this entry
vector_id = str(uuid.uuid4())
# Create the payload with all necessary fields
payload = {
"data": text, # Store the text in the data field
"user_id": user_id,
- **(metadata or {})
+ **(metadata or {}),
}
# Get the embedding
vector = self.embedder.embed_query(text)
# Insert using the insert method
- self.insert(
-     vectors=[vector],
-     payloads=[payload],
-     ids=[vector_id]
- )
+ self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
return vector_id
except Exception as e:
logger.error("Error occurred: %s", str(e))
raise
def add_texts(
self,
texts: List[str],
@@ -585,47 +529,45 @@ class GoogleMatchingEngine(VectorStoreBase):
ids: Optional[List[str]] = None,
) -> List[str]:
"""Add texts to the vector store.
Args:
texts: List of texts to add
metadatas: Optional list of metadata dicts
ids: Optional list of IDs to use
Returns:
List[str]: List of IDs of the added texts
Raises:
ValueError: If texts is empty or lengths don't match
"""
if not texts:
raise ValueError("No texts provided")
if metadatas and len(metadatas) != len(texts):
- raise ValueError(f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})")
+ raise ValueError(
+     f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
+ )
if ids and len(ids) != len(texts):
raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
logger.debug("Starting add_texts operation")
logger.debug("Number of texts: %d", len(texts))
logger.debug("Has metadatas: %s", metadatas is not None)
logger.debug("Has ids: %s", ids is not None)
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
try:
# Get embeddings
embeddings = self.embedder.embed_documents(texts)
# Add to store
- self.insert(
-     vectors=embeddings,
-     payloads=metadatas if metadatas else [{}] * len(texts),
-     ids=ids
- )
+ self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
return ids
except Exception as e:
logger.error("Error in add_texts: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
@@ -657,18 +599,12 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Query: %s", query)
logger.debug("k: %d", k)
logger.debug("Filter: %s", filter)
embedding = self.embedder.embed_query(query)
results = self.search(query=embedding, limit=k, filters=filter)
docs_and_scores = [
- (
-     Document(
-         page_content=result.payload.get("text", ""),
-         metadata=result.payload
-     ),
-     result.score
- )
+ (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
for result in results
]
logger.debug("Found %d results", len(docs_and_scores))
@@ -684,4 +620,3 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Starting similarity search")
docs_and_scores = self.similarity_search_with_score(query, k, filter)
return [doc for doc, _ in docs_and_scores]

View File

@@ -154,7 +154,9 @@ class Weaviate(VectorStoreBase):
batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)
- def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+ def search(
+     self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+ ) -> List[OutputData]:
"""
Search for similar vectors.
"""
@@ -167,7 +169,7 @@ class Weaviate(VectorStoreBase):
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
response = collection.query.hybrid(
query="",
- vector=query,
+ vector=vectors,
limit=limit,
filters=combined_filter,
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
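Weaviate's hybrid endpoint blends BM25 and vector scores, and because query="" is passed above, only the vector component contributes here. For comparison, a hedged sketch of a call that lets the new text argument drive the keyword side (weaviate-client v4 API):

```python
response = collection.query.hybrid(
    query=query,    # raw text enables the BM25 component; "" disables it
    vector=vectors,
    limit=limit,
    filters=combined_filter,
)
```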