From 8e6a08aa837a36d97da9e76db165b5c2ac50cfdc Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 20 Mar 2025 22:57:00 +0530 Subject: [PATCH] Support for hybrid search in Azure AI vector store (#2408) Co-authored-by: Deshraj Yadav --- .../vectordbs/dbs/azure_ai_search.mdx | 24 ++ mem0/configs/vector_stores/azure_ai_search.py | 27 +- mem0/configs/vector_stores/elasticsearch.py | 3 +- .../vector_stores/vertex_ai_vector_search.py | 6 +- mem0/memory/main.py | 14 +- mem0/vector_stores/azure_ai_search.py | 71 ++--- mem0/vector_stores/base.py | 2 +- mem0/vector_stores/chroma.py | 9 +- mem0/vector_stores/elasticsearch.py | 12 +- mem0/vector_stores/milvus.py | 7 +- mem0/vector_stores/opensearch.py | 12 +- mem0/vector_stores/pgvector.py | 7 +- mem0/vector_stores/qdrant.py | 7 +- mem0/vector_stores/redis.py | 4 +- mem0/vector_stores/supabase.py | 12 +- mem0/vector_stores/vertex_ai_vector_search.py | 297 +++++++----------- mem0/vector_stores/weaviate.py | 6 +- tests/vector_stores/test_chroma.py | 7 +- tests/vector_stores/test_elasticsearch.py | 12 +- tests/vector_stores/test_opensearch.py | 6 +- tests/vector_stores/test_qdrant.py | 6 +- tests/vector_stores/test_supabase.py | 6 +- .../test_vertex_ai_vector_search.py | 8 +- tests/vector_stores/test_weaviate.py | 4 +- 24 files changed, 275 insertions(+), 294 deletions(-) diff --git a/docs/components/vectordbs/dbs/azure_ai_search.mdx b/docs/components/vectordbs/dbs/azure_ai_search.mdx index 4a846314..4243cc3f 100644 --- a/docs/components/vectordbs/dbs/azure_ai_search.mdx +++ b/docs/components/vectordbs/dbs/azure_ai_search.mdx @@ -50,6 +50,24 @@ config = { } ``` +## Using hybrid search + +```python +config = { + "vector_store": { + "provider": "azure_ai_search", + "config": { + "service_name": "ai-search-test", + "api_key": "*****", + "collection_name": "mem0", + "embedding_model_dims": 1536, + "hybrid_search": True, + "vector_filter_mode": "postFilter" + } + } +} +``` + ## Configuration Parameters | Parameter | Description | Default Value | Options | @@ -60,6 +78,8 @@ config = { | `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value | | `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` | | `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` | +| `vector_filter_mode` | Vector filter mode to use | `preFilter` | `postFilter`, `preFilter` | +| `hybrid_search` | Use hybrid search | `False` | `True`, `False` | ## Notes on Configuration Options @@ -68,6 +88,10 @@ config = { - `scalar`: Scalar quantization with reasonable balance of speed and accuracy - `binary`: Binary quantization for maximum compression with some accuracy trade-off +- **vector_filter_mode**: + - `preFilter`: Applies filters before vector search (faster) + - `postFilter`: Applies filters after vector search (may provide better relevance) + - **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections. - **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering. \ No newline at end of file diff --git a/mem0/configs/vector_stores/azure_ai_search.py b/mem0/configs/vector_stores/azure_ai_search.py index b256e139..a248c9cc 100644 --- a/mem0/configs/vector_stores/azure_ai_search.py +++ b/mem0/configs/vector_stores/azure_ai_search.py @@ -8,21 +8,26 @@ class AzureAISearchConfig(BaseModel): api_key: str = Field(None, description="API key for the Azure AI Search service") embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") compression_type: Optional[str] = Field( - None, - description="Type of vector compression to use. Options: 'scalar', 'binary', or None" + None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None" ) use_float16: bool = Field( - False, - description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)" + False, + description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)", ) - + hybrid_search: bool = Field( + False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'" + ) + vector_filter_mode: Optional[str] = Field( + "preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'" + ) + @model_validator(mode="before") @classmethod def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: allowed_fields = set(cls.model_fields.keys()) input_fields = set(values.keys()) extra_fields = input_fields - allowed_fields - + # Check for use_compression to provide a helpful error if "use_compression" in extra_fields: raise ValueError( @@ -30,13 +35,13 @@ class AzureAISearchConfig(BaseModel): "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' " "or 'compression_type=None' instead of 'use_compression=False'." ) - + if extra_fields: raise ValueError( f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Please input only the following fields: {', '.join(allowed_fields)}" ) - + # Validate compression_type values if "compression_type" in values and values["compression_type"] is not None: valid_types = ["scalar", "binary"] @@ -45,9 +50,9 @@ class AzureAISearchConfig(BaseModel): f"Invalid compression_type: {values['compression_type']}. " f"Must be one of: {', '.join(valid_types)}, or None" ) - + return values - + model_config = { "arbitrary_types_allowed": True, - } \ No newline at end of file + } diff --git a/mem0/configs/vector_stores/elasticsearch.py b/mem0/configs/vector_stores/elasticsearch.py index b0d3ef29..7f76c238 100644 --- a/mem0/configs/vector_stores/elasticsearch.py +++ b/mem0/configs/vector_stores/elasticsearch.py @@ -17,8 +17,7 @@ class ElasticsearchConfig(BaseModel): use_ssl: bool = Field(True, description="Use SSL for connection") auto_create_index: bool = Field(True, description="Automatically create index during initialization") custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field( - None, - description="Custom search query function. Parameters: (query, limit, filters) -> Dict" + None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict" ) @model_validator(mode="before") diff --git a/mem0/configs/vector_stores/vertex_ai_vector_search.py b/mem0/configs/vector_stores/vertex_ai_vector_search.py index 9cc1cdbb..e2153290 100644 --- a/mem0/configs/vector_stores/vertex_ai_vector_search.py +++ b/mem0/configs/vector_stores/vertex_ai_vector_search.py @@ -14,9 +14,7 @@ class GoogleMatchingEngineConfig(BaseModel): credentials_path: Optional[str] = Field(None, description="Path to service account credentials file") vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint") - model_config = { - "extra": "forbid" - } + model_config = {"extra": "forbid"} def __init__(self, **kwargs): super().__init__(**kwargs) @@ -26,4 +24,4 @@ class GoogleMatchingEngineConfig(BaseModel): def model_post_init(self, _context) -> None: """Set collection_name to index_id if not provided""" if self.collection_name is None: - self.collection_name = self.index_id \ No newline at end of file + self.collection_name = self.index_id diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 2e1279b1..0665d874 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -71,13 +71,14 @@ class Memory(MemoryBase): if "vector_store" not in config_dict and "embedder" in config_dict: config_dict["vector_store"] = {} config_dict["vector_store"]["config"] = {} - config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"]["embedding_dims"] + config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][ + "embedding_dims" + ] try: return config_dict except ValidationError as e: logger.error(f"Configuration validation error: {e}") raise - def add( self, @@ -204,7 +205,8 @@ class Memory(MemoryBase): messages_embeddings = self.embedding_model.embed(new_mem, "add") new_message_embeddings[new_mem] = messages_embeddings existing_memories = self.vector_store.search( - query=messages_embeddings, + query=new_mem, + vectors=messages_embeddings, limit=5, filters=filters, ) @@ -222,7 +224,9 @@ class Memory(MemoryBase): temp_uuid_mapping[str(idx)] = item["id"] retrieved_old_memory[idx]["id"] = str(idx) - function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt) + function_calling_prompt = get_update_memory_messages( + retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt + ) try: new_memories_with_actions = self.llm.generate_response( @@ -479,7 +483,7 @@ class Memory(MemoryBase): def _search_vector_store(self, query, filters, limit): embeddings = self.embedding_model.embed(query, "search") - memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters) + memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters) excluded_keys = { "user_id", diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index ad3728c9..4207302c 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase): collection_name, api_key, embedding_model_dims, - compression_type: Optional[str] = None, + compression_type: Optional[str] = None, use_float16: bool = False, + hybrid_search: bool = False, + vector_filter_mode: Optional[str] = None, ): """ Initialize the Azure AI Search vector store. @@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase): Allowed values are None (no quantization), "scalar", or "binary". use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single). (Note: This flag is preserved from the initial implementation per feedback.) + hybrid_search (bool): Whether to use hybrid search. Default is False. + vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter". """ self.index_name = collection_name self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims # If compression_type is None, treat it as "none". - self.compression_type = (compression_type or "none").lower() + self.compression_type = (compression_type or "none").lower() self.use_float16 = use_float16 + self.hybrid_search = hybrid_search + self.vector_filter_mode = vector_filter_mode self.search_client = SearchClient( endpoint=f"https://{service_name}.search.windows.net", @@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase): ) ] # If no compression is desired, compression_configurations remains empty. - - fields = [ SimpleField(name="id", type=SearchFieldDataType.String, key=True), SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True), @@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase): SearchField( name="vector", type=vector_type, - searchable=True, + searchable=True, vector_search_dimensions=self.embedding_model_dims, vector_search_profile_name="my-vector-config", ), - SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True), + SearchField(name="payload", type=SearchFieldDataType.String, searchable=True), ] vector_search = VectorSearch( @@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase): VectorSearchProfile( name="my-vector-config", algorithm_configuration_name="my-algorithms-config", - compression_name=compression_name if self.compression_type != "none" else None + compression_name=compression_name if self.compression_type != "none" else None, ) ], algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")], @@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase): """ logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") documents = [ - self._generate_document(vector, payload, id) - for id, vector, payload in zip(ids, vectors, payloads) + self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads) ] response = self.search_client.upload_documents(documents) for doc in response: @@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase): filter_expression = " and ".join(filter_conditions) return filter_expression - def search(self, query, limit=5, filters=None): + def search(self, query, vectors, limit=5, filters=None): """ Search for similar vectors. Args: - query (List[float]): Query vector. + query (str): Query. + vectors (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. @@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase): if filters: filter_expression = self._build_filter_expression(filters) - vector_query = VectorizedQuery( - vector=query, k_nearest_neighbors=limit, fields="vector" - ) - search_results = self.search_client.search( - vector_queries=[vector_query], - filter=filter_expression, - top=limit - ) + vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector") + if self.hybrid_search: + search_results = self.search_client.search( + search_text=query, + vector_queries=[vector_query], + filter=filter_expression, + top=limit, + vector_filter_mode=self.vector_filter_mode, + search_fields=["payload"], + ) + else: + search_results = self.search_client.search( + vector_queries=[vector_query], + filter=filter_expression, + top=limit, + vector_filter_mode=self.vector_filter_mode, + ) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append( - OutputData( - id=result["id"], score=result["@search.score"], payload=payload - ) - ) + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) return results def delete(self, vector_id): @@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase): result = self.search_client.get_document(key=vector_id) except ResourceNotFoundError: return None - return OutputData( - id=result["id"], score=None, payload=json.loads(result["payload"]) - ) + return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) def list_cols(self) -> List[str]: """ @@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase): if filters: filter_expression = self._build_filter_expression(filters) - search_results = self.search_client.search( - search_text="*", filter=filter_expression, top=limit - ) + search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append( - OutputData( - id=result["id"], score=result["@search.score"], payload=payload - ) - ) + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) return [results] def __del__(self): diff --git a/mem0/vector_stores/base.py b/mem0/vector_stores/base.py index db62c572..4f55d109 100644 --- a/mem0/vector_stores/base.py +++ b/mem0/vector_stores/base.py @@ -13,7 +13,7 @@ class VectorStoreBase(ABC): pass @abstractmethod - def search(self, query, limit=5, filters=None): + def search(self, query, vectors, limit=5, filters=None): """Search for similar vectors.""" pass diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index fc11ad20..696a3047 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -127,19 +127,22 @@ class ChromaDB(VectorStoreBase): logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) - def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: """ Search for similar vectors. Args: - query (List[list]): Query vector. + query (str): Query. + vectors (List[list]): List of vectors to search. limit (int, optional): Number of results to return. Defaults to 5. filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. Returns: List[OutputData]: Search results. """ - results = self.collection.query(query_embeddings=query, where=filters, n_results=limit) + results = self.collection.query(query_embeddings=vectors, where=filters, n_results=limit) final_results = self._parse_output(results) return final_results diff --git a/mem0/vector_stores/elasticsearch.py b/mem0/vector_stores/elasticsearch.py index 529afe77..8a535429 100644 --- a/mem0/vector_stores/elasticsearch.py +++ b/mem0/vector_stores/elasticsearch.py @@ -45,7 +45,7 @@ class ElasticsearchDB(VectorStoreBase): # Create index only if auto_create_index is True if config.auto_create_index: self.create_index() - + if config.custom_search_query: self.custom_search_query = config.custom_search_query else: @@ -121,16 +121,20 @@ class ElasticsearchDB(VectorStoreBase): ) return results - def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: """ Search with two options: 1. Use custom search query if provided 2. Use KNN search on vectors with pre-filtering if no custom search query is provided """ if self.custom_search_query: - search_query = self.custom_search_query(query, limit, filters) + search_query = self.custom_search_query(vectors, limit, filters) else: - search_query = {"knn": {"field": "vector", "query_vector": query, "k": limit, "num_candidates": limit * 2}} + search_query = { + "knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2} + } if filters: filter_conditions = [] for key, value in filters.items(): diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py index 013fc0e3..ff48e306 100644 --- a/mem0/vector_stores/milvus.py +++ b/mem0/vector_stores/milvus.py @@ -134,12 +134,13 @@ class MilvusDB(VectorStoreBase): return memory - def search(self, query: list, limit: int = 5, filters: dict = None) -> list: + def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list: """ Search for similar vectors. Args: - query (List[float]): Query vector. + query (str): Query. + vectors (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. @@ -149,7 +150,7 @@ class MilvusDB(VectorStoreBase): query_filter = self._create_filter(filters) if filters else None hits = self.client.search( collection_name=self.collection_name, - data=[query], + data=[vectors], limit=limit, filter=query_filter, output_fields=["*"], diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py index 2a58ac45..dca5b287 100644 --- a/mem0/vector_stores/opensearch.py +++ b/mem0/vector_stores/opensearch.py @@ -28,10 +28,12 @@ class OpenSearchDB(VectorStoreBase): # Initialize OpenSearch client self.client = OpenSearch( hosts=[{"host": config.host, "port": config.port or 9200}], - http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None), + http_auth=config.http_auth + if config.http_auth + else ((config.user, config.password) if (config.user and config.password) else None), use_ssl=config.use_ssl, verify_certs=config.verify_certs, - connection_class=RequestsHttpConnection + connection_class=RequestsHttpConnection, ) self.collection_name = config.collection_name @@ -115,14 +117,16 @@ class OpenSearchDB(VectorStoreBase): results.append(OutputData(id=id_, score=1.0, payload=payloads[i])) return results - def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: """Search for similar vectors using OpenSearch k-NN search with pre-filtering.""" search_query = { "size": limit, "query": { "knn": { "vector": { - "vector": query, + "vector": vectors, "k": limit, } } diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index 989b60df..095bff22 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -120,12 +120,13 @@ class PGVector(VectorStoreBase): ) self.conn.commit() - def search(self, query, limit=5, filters=None): + def search(self, query, vectors, limit=5, filters=None): """ Search for similar vectors. Args: - query (List[float]): Query vector. + query (str): Query. + vectors (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. @@ -150,7 +151,7 @@ class PGVector(VectorStoreBase): ORDER BY distance LIMIT %s """, - (query, *filter_params, limit), + (vectors, *filter_params, limit), ) results = self.cur.fetchall() diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index 73d6d0dd..7e16089f 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -127,12 +127,13 @@ class Qdrant(VectorStoreBase): conditions.append(FieldCondition(key=key, match=MatchValue(value=value))) return Filter(must=conditions) if conditions else None - def search(self, query: list, limit: int = 5, filters: dict = None) -> list: + def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list: """ Search for similar vectors. Args: - query (list): Query vector. + query (str): Query. + vectors (list): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (dict, optional): Filters to apply to the search. Defaults to None. @@ -142,7 +143,7 @@ class Qdrant(VectorStoreBase): query_filter = self._create_filter(filters) if filters else None hits = self.client.query_points( collection_name=self.collection_name, - query=query, + query=vectors, query_filter=query_filter, limit=limit, ) diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py index 0f553f32..8947beeb 100644 --- a/mem0/vector_stores/redis.py +++ b/mem0/vector_stores/redis.py @@ -101,12 +101,12 @@ class RedisDB(VectorStoreBase): data.append(entry) self.index.load(data, id_field="memory_id") - def search(self, query: list, limit: int = 5, filters: dict = None): + def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None): conditions = [Tag(key) == value for key, value in filters.items() if value is not None] filter = reduce(lambda x, y: x & y, conditions) v = VectorQuery( - vector=np.array(query, dtype=np.float32).tobytes(), + vector=np.array(vectors, dtype=np.float32).tobytes(), vector_field_name="embedding", return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"], filter_expression=filter, diff --git a/mem0/vector_stores/supabase.py b/mem0/vector_stores/supabase.py index bd14d668..1e297cb8 100644 --- a/mem0/vector_stores/supabase.py +++ b/mem0/vector_stores/supabase.py @@ -112,16 +112,18 @@ class Supabase(VectorStoreBase): payloads = [{} for _ in vectors] records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)] - print(records) self.collection.upsert(records) - def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None + ) -> List[OutputData]: """ Search for similar vectors. Args: - query (List[float]): Query vector + query (str): Query. + vectors (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. @@ -129,11 +131,9 @@ class Supabase(VectorStoreBase): List[OutputData]: Search results """ filters = self._preprocess_filters(filters) - print(filters) results = self.collection.query( - data=query, limit=limit, filters=filters, include_metadata=True, include_value=True + data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True ) - print(results) return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results] diff --git a/mem0/vector_stores/vertex_ai_vector_search.py b/mem0/vector_stores/vertex_ai_vector_search.py index a7163499..6f526584 100644 --- a/mem0/vector_stores/vertex_ai_vector_search.py +++ b/mem0/vector_stores/vertex_ai_vector_search.py @@ -32,19 +32,19 @@ class GoogleMatchingEngine(VectorStoreBase): def __init__(self, **kwargs): """Initialize Google Matching Engine client.""" logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs) - + # If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided - if 'collection_name' in kwargs and 'deployment_index_id' not in kwargs: - kwargs['deployment_index_id'] = kwargs['collection_name'] - logger.debug("Using collection_name as deployment_index_id: %s", kwargs['deployment_index_id']) - elif 'deployment_index_id' in kwargs and 'collection_name' not in kwargs: - kwargs['collection_name'] = kwargs['deployment_index_id'] - logger.debug("Using deployment_index_id as collection_name: %s", kwargs['collection_name']) - + if "collection_name" in kwargs and "deployment_index_id" not in kwargs: + kwargs["deployment_index_id"] = kwargs["collection_name"] + logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"]) + elif "deployment_index_id" in kwargs and "collection_name" not in kwargs: + kwargs["collection_name"] = kwargs["deployment_index_id"] + logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"]) + try: config = GoogleMatchingEngineConfig(**kwargs) logger.debug("Config created: %s", config.model_dump()) - logger.debug("Config collection_name: %s", getattr(config, 'collection_name', None)) + logger.debug("Config collection_name: %s", getattr(config, "collection_name", None)) except Exception as e: logger.error("Failed to validate config: %s", str(e)) raise @@ -57,41 +57,37 @@ class GoogleMatchingEngine(VectorStoreBase): self.deployment_index_id = config.deployment_index_id # The deployment-specific ID self.collection_name = config.collection_name self.vector_search_api_endpoint = config.vector_search_api_endpoint - + logger.debug("Using project=%s, location=%s", self.project_id, self.region) - + # Initialize Vertex AI with credentials if provided init_args = { "project": self.project_id, "location": self.region, } - if hasattr(config, 'credentials_path') and config.credentials_path: + if hasattr(config, "credentials_path") and config.credentials_path: logger.debug("Using credentials from: %s", config.credentials_path) - credentials = service_account.Credentials.from_service_account_file( - config.credentials_path - ) + credentials = service_account.Credentials.from_service_account_file(config.credentials_path) init_args["credentials"] = credentials - + try: aiplatform.init(**init_args) logger.debug("Vertex AI initialized successfully") except Exception as e: logger.error("Failed to initialize Vertex AI: %s", str(e)) raise - + try: # Format the index path properly using the configured index_id index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}" logger.debug("Initializing index with path: %s", index_path) self.index = aiplatform.MatchingEngineIndex(index_name=index_path) logger.debug("Index initialized successfully") - + # Format the endpoint name properly endpoint_name = self.endpoint_id logger.debug("Initializing endpoint with name: %s", endpoint_name) - self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint( - index_endpoint_name=endpoint_name - ) + self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name) logger.debug("Endpoint initialized successfully") except Exception as e: logger.error("Failed to initialize Matching Engine components: %s", str(e)) @@ -119,47 +115,36 @@ class GoogleMatchingEngine(VectorStoreBase): def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction: """Create a restriction object for the Matching Engine index. - + Args: key: The namespace/key for the restriction value: The value to restrict on - + Returns: Restriction object for the index """ str_value = str(value) if value is not None else "" - return aiplatform_v1.types.index.IndexDatapoint.Restriction( - namespace=key, - allow_list=[str_value] - ) + return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value]) def _create_datapoint( - self, - vector_id: str, - vector: List[float], - payload: Optional[Dict] = None + self, vector_id: str, vector: List[float], payload: Optional[Dict] = None ) -> aiplatform_v1.types.index.IndexDatapoint: """Create a datapoint object for the Matching Engine index. - + Args: vector_id: The ID for the datapoint vector: The vector to store payload: Optional metadata to store with the vector - + Returns: IndexDatapoint object """ restrictions = [] if payload: - restrictions = [ - self._create_restriction(key, value) - for key, value in payload.items() - ] - + restrictions = [self._create_restriction(key, value) for key, value in payload.items()] + return aiplatform_v1.types.index.IndexDatapoint( - datapoint_id=vector_id, - feature_vector=vector, - restricts=restrictions + datapoint_id=vector_id, feature_vector=vector, restricts=restrictions ) def insert( @@ -169,41 +154,41 @@ class GoogleMatchingEngine(VectorStoreBase): ids: Optional[List[str]] = None, ) -> None: """Insert vectors into the Matching Engine index. - + Args: vectors: List of vectors to insert payloads: Optional list of metadata dictionaries ids: Optional list of IDs for the vectors - + Raises: ValueError: If vectors is empty or lengths don't match GoogleAPIError: If the API call fails """ if not vectors: raise ValueError("No vectors provided for insertion") - + if payloads and len(payloads) != len(vectors): raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})") - + if ids and len(ids) != len(vectors): raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})") - + logger.debug("Starting insert of %d vectors", len(vectors)) - + try: datapoints = [ self._create_datapoint( vector_id=ids[i] if ids else str(uuid.uuid4()), vector=vector, - payload=payloads[i] if payloads and i < len(payloads) else None + payload=payloads[i] if payloads and i < len(payloads) else None, ) for i, vector in enumerate(vectors) ] - + logger.debug("Created %d datapoints", len(datapoints)) self.index.upsert_datapoints(datapoints=datapoints) logger.debug("Successfully inserted datapoints") - + except google.api_core.exceptions.GoogleAPIError as e: logger.error("Failed to insert vectors: %s", str(e)) raise @@ -212,21 +197,22 @@ class GoogleMatchingEngine(VectorStoreBase): logger.error("Stack trace: %s", traceback.format_exc()) raise - - def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: """ Search for similar vectors. Args: - query (List[float]): Query vector. + query (str): Query. + vectors (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. Returns: List[OutputData]: Search results (unwrapped) """ logger.debug("Starting search") - logger.debug("Query type: %s, length: %d", type(query), len(query)) logger.debug("Limit: %d, Filters: %s", limit, filters) - + try: filter_namespaces = [] if filters: @@ -235,53 +221,42 @@ class GoogleMatchingEngine(VectorStoreBase): logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value)) if isinstance(value, (str, int, float)): logger.debug("Adding simple filter for %s", key) - filter_namespaces.append( - Namespace(key, [str(value)], []) - ) + filter_namespaces.append(Namespace(key, [str(value)], [])) elif isinstance(value, dict): logger.debug("Adding complex filter for %s", key) - includes = value.get('include', []) - excludes = value.get('exclude', []) - filter_namespaces.append( - Namespace(key, includes, excludes) - ) - + includes = value.get("include", []) + excludes = value.get("exclude", []) + filter_namespaces.append(Namespace(key, includes, excludes)) + logger.debug("Final filter_namespaces: %s", filter_namespaces) - + response = self.index_endpoint.find_neighbors( deployed_index_id=self.deployment_index_id, - queries=[query], + queries=[vectors], num_neighbors=limit, filter=filter_namespaces if filter_namespaces else None, - return_full_datapoint=True + return_full_datapoint=True, ) - + if not response or len(response) == 0 or len(response[0]) == 0: logger.debug("No results found") return [] - + results = [] for neighbor in response[0]: - logger.debug("Processing neighbor - id: %s, distance: %s", - neighbor.id, neighbor.distance) - + logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance) + payload = {} - if hasattr(neighbor, 'restricts'): + if hasattr(neighbor, "restricts"): logger.debug("Processing restricts") for restrict in neighbor.restricts: - if (hasattr(restrict, 'name') and - hasattr(restrict, 'allow_tokens') and - restrict.allow_tokens): + if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens: logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0]) payload[restrict.name] = restrict.allow_tokens[0] - - output_data = OutputData( - id=neighbor.id, - score=neighbor.distance, - payload=payload - ) + + output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload) results.append(output_data) - + logger.debug("Returning %d results", len(results)) return results @@ -291,7 +266,6 @@ class GoogleMatchingEngine(VectorStoreBase): logger.error("Stack trace: %s", traceback.format_exc()) raise - def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool: """ Delete vectors from the Matching Engine index. @@ -326,14 +300,13 @@ class GoogleMatchingEngine(VectorStoreBase): except google.api_core.exceptions.InvalidArgument as e: logger.error("Invalid argument: %s", str(e)) return False - + except Exception as e: logger.error("Error occurred: %s", str(e)) logger.error("Error type: %s", type(e)) logger.error("Stack trace: %s", traceback.format_exc()) return False - def update( self, vector_id: str, @@ -341,42 +314,40 @@ class GoogleMatchingEngine(VectorStoreBase): payload: Optional[Dict] = None, ) -> bool: """Update a vector and its payload. - + Args: vector_id: ID of the vector to update vector: Optional new vector values payload: Optional new metadata payload - + Returns: bool: True if update was successful - + Raises: ValueError: If neither vector nor payload is provided GoogleAPIError: If the API call fails """ logger.debug("Starting update for vector_id: %s", vector_id) - + if vector is None and payload is None: raise ValueError("Either vector or payload must be provided for update") - + # First check if the vector exists try: existing = self.get(vector_id) if existing is None: logger.error("Vector ID not found: %s", vector_id) return False - + datapoint = self._create_datapoint( - vector_id=vector_id, - vector=vector if vector is not None else [], - payload=payload + vector_id=vector_id, vector=vector if vector is not None else [], payload=payload ) - + logger.debug("Upserting datapoint: %s", datapoint) self.index.upsert_datapoints(datapoints=[datapoint]) logger.debug("Update completed successfully") return True - + except google.api_core.exceptions.GoogleAPIError as e: logger.error("API error during update: %s", str(e)) return False @@ -385,7 +356,6 @@ class GoogleMatchingEngine(VectorStoreBase): logger.error("Stack trace: %s", traceback.format_exc()) raise - def get(self, vector_id: str) -> Optional[OutputData]: """ Retrieve a vector by ID. @@ -395,24 +365,17 @@ class GoogleMatchingEngine(VectorStoreBase): Optional[OutputData]: Retrieved vector or None if not found. """ logger.debug("Starting get for vector_id: %s", vector_id) - + try: if not self.vector_search_api_endpoint: raise ValueError("vector_search_api_endpoint is required for get operation") vector_search_client = aiplatform_v1.MatchServiceClient( - client_options={ - "api_endpoint": self.vector_search_api_endpoint - }, - ) - datapoint = aiplatform_v1.IndexDatapoint( - datapoint_id=vector_id + client_options={"api_endpoint": self.vector_search_api_endpoint}, ) + datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id) - query = aiplatform_v1.FindNeighborsRequest.Query( - datapoint=datapoint, - neighbor_count=1 - ) + query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1) request = aiplatform_v1.FindNeighborsRequest( index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}", deployed_index_id=self.deployment_index_id, @@ -423,41 +386,36 @@ class GoogleMatchingEngine(VectorStoreBase): try: response = vector_search_client.find_neighbors(request) logger.debug("Got response") - + if response and response.nearest_neighbors: nearest = response.nearest_neighbors[0] if nearest.neighbors: neighbor = nearest.neighbors[0] - + payload = {} - if hasattr(neighbor.datapoint, 'restricts'): + if hasattr(neighbor.datapoint, "restricts"): for restrict in neighbor.datapoint.restricts: if restrict.allow_list: payload[restrict.namespace] = restrict.allow_list[0] - - return OutputData( - id=neighbor.datapoint.datapoint_id, - score=neighbor.distance, - payload=payload - ) - + + return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload) + logger.debug("No results found") return None - + except google.api_core.exceptions.NotFound: logger.debug("Datapoint not found") return None except google.api_core.exceptions.PermissionDenied as e: logger.error("Permission denied: %s", str(e)) return None - + except Exception as e: logger.error("Error occurred: %s", str(e)) logger.error("Error type: %s", type(e)) logger.error("Stack trace: %s", traceback.format_exc()) raise - def list_cols(self) -> List[str]: """ List all collections (indexes). @@ -466,7 +424,6 @@ class GoogleMatchingEngine(VectorStoreBase): """ return [self.deployment_index_id] - def delete_col(self): """ Delete a collection (index). @@ -475,7 +432,6 @@ class GoogleMatchingEngine(VectorStoreBase): logger.warning("Delete collection operation is not supported for Google Matching Engine") pass - def col_info(self) -> Dict: """ Get information about a collection (index). @@ -486,17 +442,16 @@ class GoogleMatchingEngine(VectorStoreBase): "index_id": self.index_id, "endpoint_id": self.endpoint_id, "project_id": self.project_id, - "region": self.region + "region": self.region, } - def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]: """List vectors matching the given filters. - + Args: filters: Optional filters to apply limit: Optional maximum number of results to return - + Returns: List[List[OutputData]]: List of matching vectors wrapped in an extra array to match the interface @@ -504,36 +459,31 @@ class GoogleMatchingEngine(VectorStoreBase): logger.debug("Starting list operation") logger.debug("Filters: %s", filters) logger.debug("Limit: %s", limit) - + try: # Use a zero vector for the search dimension = 768 # This should be configurable based on the model zero_vector = [0.0] * dimension - + # Use a large limit if none specified search_limit = limit if limit is not None else 10000 - - results = self.search( - query=zero_vector, - limit=search_limit, - filters=filters - ) - + + results = self.search(query=zero_vector, limit=search_limit, filters=filters) + logger.debug("Found %d results", len(results)) return [results] # Wrap in extra array to match interface - + except Exception as e: logger.error("Error in list operation: %s", str(e)) logger.error("Stack trace: %s", traceback.format_exc()) raise - def create_col(self, name=None, vector_size=None, distance=None): """ - Create a new collection. For Google Matching Engine, collections (indexes) + Create a new collection. For Google Matching Engine, collections (indexes) are created through the Google Cloud Console or API separately. This method is a no-op since indexes are pre-created. - + Args: name: Ignored for Google Matching Engine vector_size: Ignored for Google Matching Engine @@ -543,41 +493,35 @@ class GoogleMatchingEngine(VectorStoreBase): # This method is included only to satisfy the abstract base class pass - def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str: logger.debug("Starting add operation") logger.debug("Text: %s", text) logger.debug("Metadata: %s", metadata) logger.debug("User ID: %s", user_id) - + try: # Generate a unique ID for this entry vector_id = str(uuid.uuid4()) - + # Create the payload with all necessary fields payload = { "data": text, # Store the text in the data field "user_id": user_id, - **(metadata or {}) + **(metadata or {}), } - + # Get the embedding vector = self.embedder.embed_query(text) - + # Insert using the insert method - self.insert( - vectors=[vector], - payloads=[payload], - ids=[vector_id] - ) - + self.insert(vectors=[vector], payloads=[payload], ids=[vector_id]) + return vector_id - + except Exception as e: logger.error("Error occurred: %s", str(e)) raise - def add_texts( self, texts: List[str], @@ -585,47 +529,45 @@ class GoogleMatchingEngine(VectorStoreBase): ids: Optional[List[str]] = None, ) -> List[str]: """Add texts to the vector store. - + Args: texts: List of texts to add metadatas: Optional list of metadata dicts ids: Optional list of IDs to use - + Returns: List[str]: List of IDs of the added texts - + Raises: ValueError: If texts is empty or lengths don't match """ if not texts: raise ValueError("No texts provided") - + if metadatas and len(metadatas) != len(texts): - raise ValueError(f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})") - + raise ValueError( + f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})" + ) + if ids and len(ids) != len(texts): raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})") - + logger.debug("Starting add_texts operation") logger.debug("Number of texts: %d", len(texts)) logger.debug("Has metadatas: %s", metadatas is not None) logger.debug("Has ids: %s", ids is not None) - + if ids is None: ids = [str(uuid.uuid4()) for _ in texts] - + try: # Get embeddings embeddings = self.embedder.embed_documents(texts) - + # Add to store - self.insert( - vectors=embeddings, - payloads=metadatas if metadatas else [{}] * len(texts), - ids=ids - ) + self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids) return ids - + except Exception as e: logger.error("Error in add_texts: %s", str(e)) logger.error("Stack trace: %s", traceback.format_exc()) @@ -657,18 +599,12 @@ class GoogleMatchingEngine(VectorStoreBase): logger.debug("Query: %s", query) logger.debug("k: %d", k) logger.debug("Filter: %s", filter) - + embedding = self.embedder.embed_query(query) results = self.search(query=embedding, limit=k, filters=filter) - + docs_and_scores = [ - ( - Document( - page_content=result.payload.get("text", ""), - metadata=result.payload - ), - result.score - ) + (Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score) for result in results ] logger.debug("Found %d results", len(docs_and_scores)) @@ -684,4 +620,3 @@ class GoogleMatchingEngine(VectorStoreBase): logger.debug("Starting similarity search") docs_and_scores = self.similarity_search_with_score(query, k, filter) return [doc for doc, _ in docs_and_scores] - diff --git a/mem0/vector_stores/weaviate.py b/mem0/vector_stores/weaviate.py index bfad8db8..0988b103 100644 --- a/mem0/vector_stores/weaviate.py +++ b/mem0/vector_stores/weaviate.py @@ -154,7 +154,9 @@ class Weaviate(VectorStoreBase): batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector) - def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: + def search( + self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: """ Search for similar vectors. """ @@ -167,7 +169,7 @@ class Weaviate(VectorStoreBase): combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None response = collection.query.hybrid( query="", - vector=query, + vector=vectors, limit=limit, filters=combined_filter, return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"], diff --git a/tests/vector_stores/test_chroma.py b/tests/vector_stores/test_chroma.py index 6995217f..b038db1c 100644 --- a/tests/vector_stores/test_chroma.py +++ b/tests/vector_stores/test_chroma.py @@ -35,12 +35,11 @@ def test_search_vectors(chromadb_instance, mock_chromadb_client): } chromadb_instance.collection.query.return_value = mock_result - query = [[0.1, 0.2, 0.3]] - results = chromadb_instance.search(query=query, limit=2) + vectors = [[0.1, 0.2, 0.3]] + results = chromadb_instance.search(query="", vectors=vectors, limit=2) - chromadb_instance.collection.query.assert_called_once_with(query_embeddings=query, where=None, n_results=2) + chromadb_instance.collection.query.assert_called_once_with(query_embeddings=vectors, where=None, n_results=2) - print(results, type(results)) assert len(results) == 2 assert results[0].id == "id1" assert results[0].score == 0.1 diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py index be8134e6..3107cf7c 100644 --- a/tests/vector_stores/test_elasticsearch.py +++ b/tests/vector_stores/test_elasticsearch.py @@ -196,8 +196,8 @@ class TestElasticsearchDB(unittest.TestCase): self.client_mock.search.return_value = mock_response # Perform search - query_vector = [0.1] * 1536 - results = self.es_db.search(query=query_vector, limit=5) + vectors = [[0.1] * 1536] + results = self.es_db.search(query="", vectors=vectors, limit=5) # Verify search call self.client_mock.search.assert_called_once() @@ -210,7 +210,7 @@ class TestElasticsearchDB(unittest.TestCase): # Verify KNN query structure self.assertIn("knn", body) self.assertEqual(body["knn"]["field"], "vector") - self.assertEqual(body["knn"]["query_vector"], query_vector) + self.assertEqual(body["knn"]["query_vector"], vectors) self.assertEqual(body["knn"]["k"], 5) self.assertEqual(body["knn"]["num_candidates"], 10) @@ -226,13 +226,13 @@ class TestElasticsearchDB(unittest.TestCase): self.es_db.custom_search_query.return_value = {"custom_key": "custom_value"} # Perform search - query_vector = [0.1] * 1536 + vectors = [[0.1] * 1536] limit = 5 filters = {"key1": "value1"} - self.es_db.search(query=query_vector, limit=limit, filters=filters) + self.es_db.search(query="", vectors=vectors, limit=limit, filters=filters) # Verify custom search query function was called - self.es_db.custom_search_query.assert_called_once_with(query_vector, limit, filters) + self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters) # Verify custom search query was used self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"}) diff --git a/tests/vector_stores/test_opensearch.py b/tests/vector_stores/test_opensearch.py index 912b660d..9d6daf53 100644 --- a/tests/vector_stores/test_opensearch.py +++ b/tests/vector_stores/test_opensearch.py @@ -126,15 +126,15 @@ class TestOpenSearchDB(unittest.TestCase): def test_search(self): mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}]}} self.client_mock.search.return_value = mock_response - query_vector = [0.1] * 1536 - results = self.os_db.search(query=query_vector, limit=5) + vectors = [[0.1] * 1536] + results = self.os_db.search(query="", vectors=vectors, limit=5) self.client_mock.search.assert_called_once() search_args = self.client_mock.search.call_args[1] self.assertEqual(search_args["index"], "test_collection") body = search_args["body"] self.assertIn("knn", body["query"]) self.assertIn("vector", body["query"]["knn"]) - self.assertEqual(body["query"]["knn"]["vector"]["vector"], query_vector) + self.assertEqual(body["query"]["knn"]["vector"]["vector"], vectors) self.assertEqual(body["query"]["knn"]["vector"]["k"], 5) self.assertEqual(len(results), 1) self.assertEqual(results[0].id, "id1") diff --git a/tests/vector_stores/test_qdrant.py b/tests/vector_stores/test_qdrant.py index cf6e63a5..34648527 100644 --- a/tests/vector_stores/test_qdrant.py +++ b/tests/vector_stores/test_qdrant.py @@ -50,15 +50,15 @@ class TestQdrant(unittest.TestCase): self.assertEqual(points[0].payload, payloads[0]) def test_search(self): - query_vector = [0.1, 0.2] + vectors = [[0.1, 0.2]] mock_point = MagicMock(id=str(uuid.uuid4()), score=0.95, payload={"key": "value"}) self.client_mock.query_points.return_value = MagicMock(points=[mock_point]) - results = self.qdrant.search(query=query_vector, limit=1) + results = self.qdrant.search(query="", vectors=vectors, limit=1) self.client_mock.query_points.assert_called_once_with( collection_name="test_collection", - query=query_vector, + query=vectors, query_filter=None, limit=1, ) diff --git a/tests/vector_stores/test_supabase.py b/tests/vector_stores/test_supabase.py index bbdb468f..f94c203e 100644 --- a/tests/vector_stores/test_supabase.py +++ b/tests/vector_stores/test_supabase.py @@ -77,12 +77,12 @@ def test_search_vectors(supabase_instance, mock_collection): ] mock_collection.query.return_value = mock_results - query = [0.1, 0.2, 0.3] + vectors = [[0.1, 0.2, 0.3]] filters = {"category": "test"} - results = supabase_instance.search(query=query, limit=2, filters=filters) + results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters) mock_collection.query.assert_called_once_with( - data=query, + data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, diff --git a/tests/vector_stores/test_vertex_ai_vector_search.py b/tests/vector_stores/test_vertex_ai_vector_search.py index 36233d20..3a1ab50d 100644 --- a/tests/vector_stores/test_vertex_ai_vector_search.py +++ b/tests/vector_stores/test_vertex_ai_vector_search.py @@ -73,12 +73,12 @@ def test_insert_vectors(vector_store, mock_vertex_ai): def test_search_vectors(vector_store, mock_vertex_ai): """Test searching vectors with filters""" - query = [0.1, 0.2, 0.3] + vectors = [[0.1, 0.2, 0.3]] filters = {"user_id": "test_user"} mock_datapoint = Mock() mock_datapoint.datapoint_id = "test-id" - mock_datapoint.feature_vector = query + mock_datapoint.feature_vector = vectors mock_restrict = Mock() mock_restrict.namespace = "user_id" @@ -96,11 +96,11 @@ def test_search_vectors(vector_store, mock_vertex_ai): mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]] - results = vector_store.search(query=query, filters=filters, limit=1) + results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1) mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with( deployed_index_id=vector_store.deployment_index_id, - queries=[query], + queries=[vectors], num_neighbors=1, filter=[Namespace("user_id", ["test_user"], [])], return_full_datapoint=True diff --git a/tests/vector_stores/test_weaviate.py b/tests/vector_stores/test_weaviate.py index 231e2bff..15fb1b3b 100644 --- a/tests/vector_stores/test_weaviate.py +++ b/tests/vector_stores/test_weaviate.py @@ -147,8 +147,8 @@ class TestWeaviateDB(unittest.TestCase): self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid mock_hybrid.return_value = mock_response - query_vector = [0.1] * 1536 - results = self.weaviate_db.search(query=query_vector, limit=5) + vectors = [[0.1] * 1536] + results = self.weaviate_db.search(query="", vectors=vectors, limit=5) mock_hybrid.assert_called_once()