From e33008e3a4ca2da6f46395f354b75094b7f7cdcb Mon Sep 17 00:00:00 2001 From: Parshva Daftari <89991302+parshvadaftari@users.noreply.github.com> Date: Thu, 20 Mar 2025 12:57:32 +0530 Subject: [PATCH] Add: Pinecone integration (#2395) --- Makefile | 2 +- docs/components/vectordbs/dbs/pinecone.mdx | 88 +++++ docs/components/vectordbs/overview.mdx | 1 + docs/docs.json | 1 + mem0/configs/vector_stores/pinecone.py | 56 ++++ mem0/utils/factory.py | 1 + mem0/vector_stores/configs.py | 1 + mem0/vector_stores/pinecone.py | 368 +++++++++++++++++++++ tests/vector_stores/test_pinecone.py | 120 +++++++ 9 files changed, 637 insertions(+), 1 deletion(-) create mode 100644 docs/components/vectordbs/dbs/pinecone.mdx create mode 100644 mem0/configs/vector_stores/pinecone.py create mode 100644 mem0/vector_stores/pinecone.py create mode 100644 tests/vector_stores/test_pinecone.py diff --git a/Makefile b/Makefile index 2d3763d2..3fb584f3 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ install: install_all: poetry install poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \ - google-generativeai elasticsearch opensearch-py vecs + google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text # Format code with ruff format: diff --git a/docs/components/vectordbs/dbs/pinecone.mdx b/docs/components/vectordbs/dbs/pinecone.mdx new file mode 100644 index 00000000..98debefe --- /dev/null +++ b/docs/components/vectordbs/dbs/pinecone.mdx @@ -0,0 +1,88 @@ +# Pinecone + +[Pinecone](https://www.pinecone.io/) is a fully managed vector database designed for machine learning applications, offering high performance vector search with low latency at scale. It's particularly well-suited for semantic search, recommendation systems, and other AI-powered applications. 
+ +### Usage + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" +os.environ["PINECONE_API_KEY"] = "your-api-key" + +config = { + "vector_store": { + "provider": "pinecone", + "config": { + "collection_name": "memory_index", + "embedding_model_dims": 1536, + "environment": "us-west1-gcp", + "metric": "cosine" + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + +### Config + +Here are the parameters available for configuring Pinecone: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `collection_name` | Name of the index/collection | `"mem0"` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `client` | Existing Pinecone client instance | `None` | +| `api_key` | API key for Pinecone | Environment variable: `PINECONE_API_KEY` | +| `environment` | Pinecone environment | `None` | +| `serverless_config` | Configuration for serverless deployment | `None` | +| `pod_config` | Configuration for pod-based deployment | `None` | +| `hybrid_search` | Whether to enable hybrid search | `False` | +| `metric` | Distance metric for vector similarity | `"cosine"` | +| `batch_size` | Batch size for operations | `100` | + +#### Serverless Config Example + +```python +config = { + "vector_store": { + "provider": "pinecone", + "config": { + "collection_name": "memory_index", + "embedding_model_dims": 1536, + "serverless_config": { + "cloud": "aws", + "region": "us-west-2" + } + } + } +} +``` + +#### Pod Config 
class PineconeConfig(BaseModel):
    """Configuration for the Pinecone vector database.

    The three ``mode="before"`` validators run on the raw input mapping and
    enforce that:

    * credentials are available (``api_key``, ``client`` or a non-empty
      ``PINECONE_API_KEY`` environment variable),
    * at most one of ``pod_config`` / ``serverless_config`` is given,
    * no keys outside the declared fields are passed.
    """

    collection_name: str = Field("mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(None, description="API key for Pinecone")
    environment: Optional[str] = Field(None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
    metric: str = Field("cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject configurations that provide no usable Pinecone credentials.

        Raises:
            ValueError: If neither ``api_key`` nor ``client`` is set and the
                ``PINECONE_API_KEY`` environment variable is missing or empty.
        """
        api_key, client = values.get("api_key"), values.get("client")
        # Use truthiness rather than key presence: an empty PINECONE_API_KEY is
        # as unusable as a missing one, so fail here instead of later when the
        # store tries to build a client from it.
        if not api_key and not client and not os.environ.get("PINECONE_API_KEY"):
            raise ValueError(
                "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure only one deployment option is configured.

        Raises:
            ValueError: If both ``pod_config`` and ``serverless_config`` are truthy.
        """
        pod_config, serverless_config = values.get("pod_config"), values.get("serverless_config")
        if pod_config and serverless_config:
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject keys that are not declared fields of this model.

        Raises:
            ValueError: If the input contains any unknown key.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            # Sort both sets so the error message is stable across runs
            # (set iteration order is not deterministic for strings).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
class OutputData(BaseModel):
    """Normalized record returned by the store's read operations."""

    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class PineconeDB(VectorStoreBase):
    """Vector store implementation backed by a Pinecone index."""

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: Optional["Pinecone"] = None,
        api_key: Optional[str] = None,
        environment: Optional[str] = None,
        serverless_config: Optional[Dict[str, Any]] = None,
        pod_config: Optional[Dict[str, Any]] = None,
        hybrid_search: bool = False,
        metric: str = "cosine",
        batch_size: int = 100,
        extra_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Pinecone vector store and make sure the index exists.

        Args:
            collection_name (str): Name of the index/collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (Pinecone, optional): Existing Pinecone client instance. Defaults to None.
            api_key (str, optional): API key for Pinecone. Falls back to the
                PINECONE_API_KEY environment variable. Defaults to None.
            environment (str, optional): Pinecone environment. Stored on the instance only.
                Defaults to None.
            serverless_config (Dict, optional): Keyword arguments for ``ServerlessSpec``. Defaults to None.
            pod_config (Dict, optional): Keyword arguments for ``PodSpec``. Defaults to None.
            hybrid_search (bool, optional): Whether to enable hybrid search. Defaults to False.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
            batch_size (int, optional): Number of records sent per upsert call. Defaults to 100.
            extra_params (Dict, optional): Additional keyword arguments for the Pinecone client.
                Defaults to None.

        Raises:
            ValueError: If no client is given and no API key can be resolved.
        """
        if client:
            self.client = client
        else:
            api_key = api_key or os.environ.get("PINECONE_API_KEY")
            if not api_key:
                raise ValueError(
                    "Pinecone API key must be provided either as a parameter or as an environment variable"
                )

            params = extra_params or {}
            self.client = Pinecone(api_key=api_key, **params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.environment = environment
        self.serverless_config = serverless_config
        self.pod_config = pod_config
        self.hybrid_search = hybrid_search
        self.metric = metric
        self.batch_size = batch_size

        self.sparse_encoder = None
        if self.hybrid_search:
            try:
                # Optional dependency: only needed when hybrid search is requested.
                from pinecone_text.sparse import BM25Encoder

                logger.info("Initializing BM25Encoder for sparse vectors...")
                self.sparse_encoder = BM25Encoder.default()
            except ImportError:
                logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
                self.hybrid_search = False

        self.create_col(embedding_model_dims, metric)

    def create_col(self, vector_size: int, metric: str = "cosine"):
        """
        Create the index if it does not exist and bind ``self.index`` to it.

        Args:
            vector_size (int): Size of the vectors to be stored.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        existing_indexes = self.list_cols().names()

        if self.collection_name in existing_indexes:
            # Use the module logger (not the root logger) like the rest of the class.
            logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
            self.index = self.client.Index(self.collection_name)
            return

        if self.serverless_config:
            spec = ServerlessSpec(**self.serverless_config)
        elif self.pod_config:
            spec = PodSpec(**self.pod_config)
        else:
            # No deployment option configured: fall back to a serverless index.
            spec = ServerlessSpec(cloud="aws", region="us-west-2")

        self.client.create_index(
            name=self.collection_name,
            dimension=vector_size,
            metric=metric,
            spec=spec,
        )

        self.index = self.client.Index(self.collection_name)

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[Union[str, int]]] = None,
    ):
        """
        Upsert vectors into the index in batches of ``self.batch_size``.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When omitted, the
                position of each vector is used as its ID. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
        items = []

        for idx, vector in enumerate(vectors):
            item_id = str(ids[idx]) if ids is not None else str(idx)
            payload = payloads[idx] if payloads else {}

            vector_record = {"id": item_id, "values": vector, "metadata": payload}

            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                vector_record["sparse_values"] = sparse_vector

            items.append(vector_record)

            # Flush a full batch as soon as it is reached.
            if len(items) >= self.batch_size:
                self.index.upsert(vectors=items)
                items = []

        # Flush the final partial batch, if any.
        if items:
            self.index.upsert(vectors=items)

    def _parse_output(self, data: Any) -> Union[OutputData, List[OutputData]]:
        """
        Convert Pinecone results into ``OutputData``.

        Args:
            data: Either a single fetched ``Vector`` or an iterable of query
                matches exposing ``.get("id" | "score" | "metadata")``.

        Returns:
            OutputData for a single ``Vector`` (score fixed at 0.0, since a fetch
            has no similarity score); otherwise a list of OutputData.
        """
        if isinstance(data, Vector):
            return OutputData(
                id=data.id,
                score=0.0,
                payload=data.metadata,
            )

        return [
            OutputData(
                id=match.get("id"),
                score=match.get("score"),
                payload=match.get("metadata"),
            )
            for match in data
        ]

    def _create_filter(self, filters: Optional[Dict]) -> Dict:
        """
        Translate a flat filter mapping into a Pinecone metadata filter.

        A value that is a dict containing both "gte" and "lte" becomes a range
        condition; every other value becomes an equality condition.

        Args:
            filters (dict, optional): Field name to value (or range dict).

        Returns:
            dict: Pinecone filter; empty when ``filters`` is falsy.
        """
        if not filters:
            return {}

        pinecone_filter = {}

        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
            else:
                pinecone_filter[key] = {"$eq": value}

        return pinecone_filter

    def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. With hybrid search
                enabled, a "text" entry is additionally encoded as the sparse query.
                Defaults to None.

        Returns:
            list: Search results.
        """
        filter_dict = self._create_filter(filters) if filters else None

        query_params = {
            "vector": query,
            "top_k": limit,
            "include_metadata": True,
            "include_values": False,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        # `filters` defaults to None, so it must be checked before looking up
        # "text"; otherwise hybrid search without filters raises TypeError.
        query_text = filters.get("text") if filters else None
        if self.hybrid_search and self.sparse_encoder and query_text:
            sparse_vector = self.sparse_encoder.encode_queries(query_text)
            query_params["sparse_vector"] = sparse_vector

        response = self.index.query(**query_params)

        return self._parse_output(response.matches)

    def delete(self, vector_id: Union[str, int]):
        """
        Delete a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to delete.
        """
        self.index.delete(ids=[str(vector_id)])

    def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
        """
        Update a vector and/or its payload.

        When a vector is supplied the record is upserted (replacing values and,
        if given, metadata). When only a payload is supplied the metadata is
        updated in place, because an upsert without dense values is rejected by
        Pinecone. When neither is supplied nothing is sent.

        Args:
            vector_id (Union[str, int]): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        record_id = str(vector_id)

        if vector is None:
            if payload is None:
                logger.warning(f"Nothing to update for vector {record_id}: no vector or payload given")
                return
            # NOTE(review): sparse values are left untouched on this path;
            # confirm whether they should be re-encoded for metadata-only edits.
            self.index.update(id=record_id, set_metadata=payload)
            return

        item = {"id": record_id, "values": vector}

        if payload is not None:
            item["metadata"] = payload
            # Only inspect the payload once we know it is not None.
            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                item["sparse_values"] = sparse_vector

        self.index.upsert(vectors=[item])

    def get(self, vector_id: Union[str, int]) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to retrieve.

        Returns:
            OutputData for the vector, or None if it is not found or the fetch fails
            (failures are logged, not raised).
        """
        record_id = str(vector_id)
        try:
            response = self.index.fetch(ids=[record_id])
            if record_id in response.vectors:
                return self._parse_output(response.vectors[record_id])
            return None
        except Exception as e:
            # Deliberate best-effort: report the failure and signal "not found".
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            return None

    def list_cols(self):
        """
        List all indexes/collections.

        Returns:
            The client's ``list_indexes()`` result.
        """
        return self.client.list_indexes()

    def delete_col(self):
        """Delete the index/collection; failures are logged, not raised."""
        try:
            self.client.delete_index(self.collection_name)
            logger.info(f"Index {self.collection_name} deleted successfully")
        except Exception as e:
            logger.error(f"Error deleting index {self.collection_name}: {e}")

    def col_info(self) -> Dict:
        """
        Get information about the index/collection.

        Returns:
            The client's ``describe_index()`` result for this index.
        """
        return self.client.describe_index(self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in the index with optional filtering.

        Pinecone is queried with an all-zero vector of the index dimension so
        that up to ``limit`` records matching the filter are returned.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            A one-element list wrapping the list of results. On query failure the
            same shape is returned with an empty inner list, so indexing element 0
            works in both cases.
        """
        filter_dict = self._create_filter(filters) if filters else None

        stats = self.index.describe_index_stats()
        dimension = stats.dimension

        zero_vector = [0.0] * dimension

        query_params = {
            "vector": zero_vector,
            "top_k": limit,
            "include_metadata": True,
            "include_values": True,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        try:
            response = self.index.query(**query_params)
            response = response.to_dict()
            results = self._parse_output(response["matches"])
            return [results]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            # Keep the success shape (list wrapping a list) instead of a dict.
            return [[]]

    def count(self) -> int:
        """
        Count number of vectors in the index.

        Returns:
            int: Total number of vectors.
        """
        stats = self.index.describe_index_stats()
        return stats.total_vector_count

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)
def _build_store(client):
    """Construct a PineconeDB wired to the given (mocked) client."""
    return PineconeDB(
        collection_name="test_index",
        embedding_model_dims=128,
        client=client,
        api_key="fake_api_key",
        environment="us-west1-gcp",
        serverless_config=None,
        pod_config=None,
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )


@pytest.fixture
def mock_pinecone_client():
    """Mocked Pinecone client that reports no existing indexes."""
    fake_client = MagicMock()
    fake_client.Index.return_value = MagicMock()
    fake_client.list_indexes.return_value.names.return_value = []
    return fake_client


@pytest.fixture
def pinecone_db(mock_pinecone_client):
    """PineconeDB instance built on top of the mocked client."""
    return _build_store(mock_pinecone_client)


def test_create_col_existing_index(mock_pinecone_client):
    """create_col must not create an index that is already listed."""
    # The index has to be reported as existing before the store is constructed,
    # because the constructor itself calls create_col.
    mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]
    store = _build_store(mock_pinecone_client)

    # Start from a clean call history so only the explicit call below counts.
    mock_pinecone_client.create_index.reset_mock()

    store.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_not_called()


def test_create_col_new_index(pinecone_db, mock_pinecone_client):
    """create_col creates the index when it is not listed."""
    mock_pinecone_client.list_indexes.return_value.names.return_value = []
    pinecone_db.create_col(128, "cosine")
    mock_pinecone_client.create_index.assert_called()


def test_insert_vectors(pinecone_db):
    """insert forwards records to the index via upsert."""
    pinecone_db.insert(
        [[0.1] * 128, [0.2] * 128],
        [{"name": "vector1"}, {"name": "vector2"}],
        ["id1", "id2"],
    )
    pinecone_db.index.upsert.assert_called()


def test_search_vectors(pinecone_db):
    """search parses query matches into OutputData entries."""
    pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
    found = pinecone_db.search([0.1] * 128, limit=1)
    assert len(found) == 1
    first = found[0]
    assert first.id == "id1"
    assert first.score == 0.9


def test_update_vector(pinecone_db):
    """update with a vector and a payload goes through upsert."""
    pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
    pinecone_db.index.upsert.assert_called()


def test_get_vector_found(pinecone_db):
    """get returns the parsed record when fetch contains the ID."""
    # _parse_output handles either a Vector object or a list of dict-like
    # matches, so the fetch response must carry a real Vector.
    from pinecone.data.dataclasses.vector import Vector

    stored_vector = Vector(
        id="id1",
        values=[0.1] * 128,
        metadata={"name": "vector1"},
    )

    fetch_response = MagicMock()
    fetch_response.vectors = {"id1": stored_vector}
    pinecone_db.index.fetch.return_value = fetch_response

    fetched = pinecone_db.get("id1")
    assert fetched is not None
    assert fetched.id == "id1"
    assert fetched.payload == {"name": "vector1"}


def test_delete_vector(pinecone_db):
    """delete passes the stringified ID to the index."""
    pinecone_db.delete("id1")
    pinecone_db.index.delete.assert_called_with(ids=["id1"])


def test_get_vector_not_found(pinecone_db):
    """get returns None when fetch does not contain the ID."""
    pinecone_db.index.fetch.return_value.vectors = {}
    assert pinecone_db.get("id1") is None


def test_list_cols(pinecone_db):
    """list_cols delegates to the client's list_indexes."""
    pinecone_db.list_cols()
    pinecone_db.client.list_indexes.assert_called()


def test_delete_col(pinecone_db):
    """delete_col deletes the configured index."""
    pinecone_db.delete_col()
    pinecone_db.client.delete_index.assert_called_with("test_index")


def test_col_info(pinecone_db):
    """col_info describes the configured index."""
    pinecone_db.col_info()
    pinecone_db.client.describe_index.assert_called_with("test_index")