diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index 982789ad..e0e21558 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import Optional, Union
 
@@ -79,12 +80,20 @@ class PineconeDB(BaseVectorDB):
         :rtype: Set[str]
         """
         existing_ids = list()
+        metadatas = []
+
         if ids is not None:
             for i in range(0, len(ids), 1000):
                 result = self.client.fetch(ids=ids[i : i + 1000])
-                batch_existing_ids = list(result.get("vectors").keys())
+                vectors = result.get("vectors")
+                batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
-        return {"ids": existing_ids}
+                metadatas.extend([vectors.get(existing_id).get("metadata") for existing_id in batch_existing_ids])
+
+        if where is not None:
+            logging.warning("Filtering is not supported by Pinecone")
+
+        return {"ids": existing_ids, "metadatas": metadatas}
 
     def add(
         self,
@@ -114,7 +123,7 @@ class PineconeDB(BaseVectorDB):
                 }
             )
 
-        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
+        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
             self.client.upsert(chunk, **kwargs)
 
     def query(
@@ -140,7 +149,10 @@ class PineconeDB(BaseVectorDB):
         :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
-        data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
+        query_filter = self._generate_filter(where)
+        data = self.client.query(
+            vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
+        )
         contexts = []
         for doc in data["matches"]:
             metadata = doc["metadata"]
@@ -192,3 +204,34 @@ class PineconeDB(BaseVectorDB):
         :rtype: str
         """
         return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
+
+    @staticmethod
+    def _generate_filter(where: Optional[dict]) -> dict:
+        """Translate a flat ``where`` mapping into a Pinecone metadata filter.
+
+        Each ``key: value`` pair becomes ``key: {"$eq": value}``.
+
+        :param where: metadata key/value pairs to match, or None for no filter
+        :type where: Optional[dict]
+        :return: filter dict; empty when ``where`` is None
+        :rtype: dict
+        """
+        # `where` is optional for callers (see `get`), so None must not reach `.items()`.
+        if where is None:
+            return {}
+        return {k: {"$eq": v} for k, v in where.items()}
+
+    def delete(self, where: dict):
+        """Delete vectors whose metadata matches ``where``.
+
+        :param where: metadata key/value pairs selecting the vectors to delete
+        :type where: dict
+        """
+        # Deleting with filters is not supported for `starter` index type.
+        # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
+        db_filter = self._generate_filter(where)
+        try:
+            self.client.delete(filter=db_filter)
+        except Exception as e:
+            # Deletion is best-effort: report the failure instead of raising.
+            logging.error("Failed to delete from Pinecone: %s", e)