[Improvement] update pinecone client v3 (#1200)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
@@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
pod_config: Optional[dict[str, any]] = None,
|
||||
serverless_config: Optional[dict[str, any]] = None,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.metric = metric
|
||||
self.api_key = api_key
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
if pod_config is None and serverless_config is None:
|
||||
# If no config is provided, use the default pod spec config
|
||||
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
|
||||
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
|
||||
else:
|
||||
self.pod_config = pod_config
|
||||
self.serverless_config = serverless_config
|
||||
|
||||
if self.pod_config and self.serverless_config:
|
||||
raise ValueError("Only one of pod_config or serverless_config can be provided.")
|
||||
|
||||
super().__init__(collection_name=collection_name, dir=None)
|
||||
|
||||
@@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB):
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
self.client = self._setup_pinecone_index()
|
||||
self._setup_pinecone_index()
|
||||
# Call parent init here because embedder is needed
|
||||
super().__init__(config=self.config)
|
||||
|
||||
@@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
Loads the Pinecone index or creates it if not present.
|
||||
"""
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
**self.config.extra_params,
|
||||
)
|
||||
indexes = pinecone.list_indexes()
|
||||
api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
|
||||
self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
|
||||
indexes = self.client.list_indexes().names()
|
||||
if indexes is None or self.config.index_name not in indexes:
|
||||
pinecone.create_index(
|
||||
name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||
if self.config.pod_config:
|
||||
spec = pinecone.PodSpec(**self.config.pod_config)
|
||||
elif self.config.serverless_config:
|
||||
spec = pinecone.ServerlessSpec(**self.config.serverless_config)
|
||||
else:
|
||||
raise ValueError("No pod_config or serverless_config found.")
|
||||
|
||||
self.client.create_index(
|
||||
name=self.config.index_name,
|
||||
metric=self.config.metric,
|
||||
dimension=self.config.vector_dimension,
|
||||
spec=spec,
|
||||
)
|
||||
return pinecone.Index(self.config.index_name)
|
||||
self.pinecone_index = self.client.Index(self.config.index_name)
|
||||
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
@@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB):
|
||||
|
||||
if ids is not None:
|
||||
for i in range(0, len(ids), 1000):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
|
||||
vectors = result.get("vectors")
|
||||
batch_existing_ids = list(vectors.keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
@@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB):
|
||||
)
|
||||
|
||||
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
|
||||
self.client.upsert(chunk, **kwargs)
|
||||
self.pinecone_index.upsert(chunk, **kwargs)
|
||||
|
||||
def query(
|
||||
self,
|
||||
@@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
query_filter = self._generate_filter(where)
|
||||
data = self.client.query(
|
||||
vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
|
||||
data = self.pinecone_index.query(
|
||||
vector=query_vector,
|
||||
filter=query_filter,
|
||||
top_k=n_results,
|
||||
include_metadata=True,
|
||||
**kwargs,
|
||||
)
|
||||
contexts = []
|
||||
for doc in data["matches"]:
|
||||
metadata = doc["metadata"]
|
||||
context = metadata["text"]
|
||||
for doc in data.get("matches", []):
|
||||
metadata = doc.get("metadata", {})
|
||||
context = metadata.get("text")
|
||||
if citations:
|
||||
metadata["score"] = doc["score"]
|
||||
metadata["score"] = doc.get("score")
|
||||
contexts.append(tuple((context, metadata)))
|
||||
else:
|
||||
contexts.append(context)
|
||||
@@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB):
|
||||
:return: number of documents
|
||||
:rtype: int
|
||||
"""
|
||||
return self.client.describe_index_stats()["total_vector_count"]
|
||||
data = self.pinecone_index.describe_index_stats()
|
||||
return data["total_vector_count"]
|
||||
|
||||
def _get_or_create_db(self):
|
||||
"""Called during initialization"""
|
||||
@@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB):
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
"""
|
||||
# Delete all data from the database
|
||||
pinecone.delete_index(self.config.index_name)
|
||||
self.client.delete_index(self.config.index_name)
|
||||
self._setup_pinecone_index()
|
||||
|
||||
@staticmethod
|
||||
@@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB):
|
||||
# Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
|
||||
db_filter = self._generate_filter(where)
|
||||
try:
|
||||
self.client.delete(filter=db_filter)
|
||||
self.pinecone_index.delete(filter=db_filter)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete from Pinecone: {e}")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user