import logging
import os
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel

try:
    from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
    ) from None

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    """Normalized record returned by the vector store's read operations."""

    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class PineconeDB(VectorStoreBase):
    """Vector store backed by a Pinecone index."""

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: Optional["Pinecone"] = None,
        api_key: Optional[str] = None,
        environment: Optional[str] = None,
        serverless_config: Optional[Dict[str, Any]] = None,
        pod_config: Optional[Dict[str, Any]] = None,
        hybrid_search: bool = False,
        metric: str = "cosine",
        batch_size: int = 100,
        extra_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Pinecone vector store and make sure the index exists.

        Args:
            collection_name (str): Name of the index/collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (Pinecone, optional): Existing Pinecone client instance. Defaults to None.
            api_key (str, optional): API key for Pinecone. Falls back to the
                ``PINECONE_API_KEY`` environment variable. Defaults to None.
            environment (str, optional): Pinecone environment. Defaults to None.
            serverless_config (Dict, optional): Configuration for serverless deployment. Defaults to None.
            pod_config (Dict, optional): Configuration for pod-based deployment. Defaults to None.
            hybrid_search (bool, optional): Whether to enable hybrid search. Defaults to False.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
            batch_size (int, optional): Batch size for operations. Defaults to 100.
            extra_params (Dict, optional): Additional keyword arguments forwarded to the
                Pinecone client constructor (only used when ``client`` is not given).
                Defaults to None.

        Raises:
            ValueError: If no client is supplied and no API key can be found.
        """
        if client:
            self.client = client
        else:
            api_key = api_key or os.environ.get("PINECONE_API_KEY")
            if not api_key:
                raise ValueError(
                    "Pinecone API key must be provided either as a parameter or as an environment variable"
                )
            params = extra_params or {}
            self.client = Pinecone(api_key=api_key, **params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.environment = environment
        self.serverless_config = serverless_config
        self.pod_config = pod_config
        self.hybrid_search = hybrid_search
        self.metric = metric
        self.batch_size = batch_size

        self.sparse_encoder = None
        if self.hybrid_search:
            # pinecone-text is optional: degrade to dense-only search when it is missing.
            try:
                from pinecone_text.sparse import BM25Encoder

                logger.info("Initializing BM25Encoder for sparse vectors...")
                self.sparse_encoder = BM25Encoder.default()
            except ImportError:
                logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
                self.hybrid_search = False

        self.create_col(embedding_model_dims, metric)

    def create_col(self, vector_size: int, metric: str = "cosine"):
        """
        Create the index if it does not exist yet and bind ``self.index`` to it.

        Args:
            vector_size (int): Size of the vectors to be stored.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        existing_indexes = self.list_cols().names()

        if self.collection_name in existing_indexes:
            logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
            self.index = self.client.Index(self.collection_name)
            return

        # Spec precedence: explicit serverless config, then pod config, then a
        # serverless default.
        if self.serverless_config:
            spec = ServerlessSpec(**self.serverless_config)
        elif self.pod_config:
            spec = PodSpec(**self.pod_config)
        else:
            spec = ServerlessSpec(cloud="aws", region="us-west-2")

        self.client.create_index(
            name=self.collection_name,
            dimension=vector_size,
            metric=metric,
            spec=spec,
        )
        self.index = self.client.Index(self.collection_name)

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[Union[str, int]]] = None,
    ):
        """
        Insert vectors into the index, upserting in batches of ``self.batch_size``.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When omitted, the
                position of each vector in ``vectors`` is used as its ID. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
        items = []

        for idx, vector in enumerate(vectors):
            item_id = str(ids[idx]) if ids is not None else str(idx)
            payload = payloads[idx] if payloads else {}

            vector_record = {"id": item_id, "values": vector, "metadata": payload}

            # `payload` may be None when the caller passes a sparse payload list;
            # guard before the membership test.
            if self.hybrid_search and self.sparse_encoder and payload and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                vector_record["sparse_values"] = sparse_vector

            items.append(vector_record)

            if len(items) >= self.batch_size:
                self.index.upsert(vectors=items)
                items = []

        # Flush the final partial batch.
        if items:
            self.index.upsert(vectors=items)

    def _parse_output(self, data: Any) -> Union[OutputData, List[OutputData]]:
        """
        Convert Pinecone results into ``OutputData``.

        Args:
            data: Either a single ``Vector`` (as returned by fetch) or an iterable
                of matches exposing ``.get("id" / "score" / "metadata")``.

        Returns:
            OutputData | List[OutputData]: A single record for a ``Vector`` input
            (with score fixed at 0.0), otherwise one record per match.
        """
        if isinstance(data, Vector):
            return OutputData(
                id=data.id,
                score=0.0,
                payload=data.metadata,
            )

        return [
            OutputData(
                id=match.get("id"),
                score=match.get("score"),
                payload=match.get("metadata"),
            )
            for match in data
        ]

    def _create_filter(self, filters: Optional[Dict]) -> Dict:
        """
        Translate a simple filter mapping into Pinecone filter syntax.

        A value that is a dict containing both ``gte`` and ``lte`` becomes a
        ``$gte``/``$lte`` range; every other value becomes an ``$eq`` match.

        Args:
            filters (dict, optional): Field name to value (or range dict).

        Returns:
            dict: Pinecone filter, empty when ``filters`` is falsy.
        """
        if not filters:
            return {}

        pinecone_filter = {}

        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
            else:
                pinecone_filter[key] = {"$eq": value}

        return pinecone_filter

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query. Not used by this implementation.
            vectors (list): Dense query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. When hybrid search
                is enabled, a truthy ``text`` entry is also encoded as the sparse query
                vector. Defaults to None.

        Returns:
            list: Search results.
        """
        filter_dict = self._create_filter(filters) if filters else None

        query_params = {
            "vector": vectors,
            "top_k": limit,
            "include_metadata": True,
            "include_values": False,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        # `filters` defaults to None, so it must be checked before looking up "text".
        # NOTE(review): "text" is also passed through as an $eq metadata filter
        # above - confirm that is intended.
        query_text = filters.get("text") if filters else None
        if self.hybrid_search and self.sparse_encoder and query_text:
            sparse_vector = self.sparse_encoder.encode_queries(query_text)
            query_params["sparse_vector"] = sparse_vector

        response = self.index.query(**query_params)

        return self._parse_output(response.matches)

    def delete(self, vector_id: Union[str, int]):
        """
        Delete a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to delete.
        """
        self.index.delete(ids=[str(vector_id)])

    def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
        """
        Update a vector and its payload by upserting a record with the same ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        item = {
            "id": str(vector_id),
        }

        # NOTE(review): when `vector` is None the record is upserted without
        # "values" - confirm Pinecone accepts that.
        if vector is not None:
            item["values"] = vector

        if payload is not None:
            item["metadata"] = payload

            # Only reachable with a real payload, so the membership test cannot
            # hit None.
            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                item["sparse_values"] = sparse_vector

        self.index.upsert(vectors=[item])

    def get(self, vector_id: Union[str, int]) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to retrieve.

        Returns:
            OutputData | None: Retrieved vector, or None if it is not found or the
            fetch fails (the error is logged).
        """
        key = str(vector_id)
        try:
            response = self.index.fetch(ids=[key])
            if key in response.vectors:
                return self._parse_output(response.vectors[key])
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            return None

    def list_cols(self):
        """
        List all indexes/collections.

        Returns:
            The result of the client's ``list_indexes()`` call.
        """
        return self.client.list_indexes()

    def delete_col(self):
        """Delete the index/collection; failures are logged, not raised."""
        try:
            self.client.delete_index(self.collection_name)
            logger.info(f"Index {self.collection_name} deleted successfully")
        except Exception as e:
            logger.error(f"Error deleting index {self.collection_name}: {e}")

    def col_info(self) -> Dict:
        """
        Get information about the index/collection.

        Returns:
            dict: Index information as returned by ``describe_index``.
        """
        return self.client.describe_index(self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in the index with optional filtering.

        Implemented as a query with an all-zero vector of the index dimension.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A single-element list wrapping the list of records. On a query
            error the inner list is empty and the error is logged.
        """
        filter_dict = self._create_filter(filters) if filters else None

        stats = self.index.describe_index_stats()
        dimension = stats.dimension

        zero_vector = [0.0] * dimension

        query_params = {
            "vector": zero_vector,
            "top_k": limit,
            "include_metadata": True,
            "include_values": True,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        try:
            response = self.index.query(**query_params)
            response = response.to_dict()
            results = self._parse_output(response["matches"])
            return [results]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            # Keep the same shape as the success path so callers can index [0].
            return [[]]

    def count(self) -> int:
        """
        Count number of vectors in the index.

        Returns:
            int: Total number of vectors.
        """
        stats = self.index.describe_index_stats()
        return stats.total_vector_count

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)