[feature]: Improve pinecone db integration (#806)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -6,38 +5,38 @@ try:
|
||||
import pinecone
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDbConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDb(BaseVectorDB):
|
||||
BATCH_SIZE = 100
|
||||
|
||||
class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
Pinecone as vector database
|
||||
"""
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PineconeDbConfig] = None,
|
||||
config: Optional[PineconeDBConfig] = None,
|
||||
):
|
||||
"""Pinecone as vector database.
|
||||
|
||||
:param config: Pinecone database config, defaults to None
|
||||
:type config: PineconeDbConfig, optional
|
||||
:type config: PineconeDBConfig, optional
|
||||
:raises ValueError: No config provided
|
||||
"""
|
||||
if config is None:
|
||||
self.config = PineconeDbConfig()
|
||||
self.config = PineconeDBConfig()
|
||||
else:
|
||||
if not isinstance(config, PineconeDbConfig):
|
||||
if not isinstance(config, PineconeDBConfig):
|
||||
raise TypeError(
|
||||
"config is not a `PineconeDbConfig` instance. "
|
||||
"config is not a `PineconeDBConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
**self.config.extra_params,
|
||||
)
|
||||
self.index_name = self._get_index_name()
|
||||
indexes = pinecone.list_indexes()
|
||||
if indexes is None or self.index_name not in indexes:
|
||||
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
|
||||
pinecone.create_index(
|
||||
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||
)
|
||||
return pinecone.Index(self.index_name)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
batch_existing_ids = list(result.get("vectors").keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
|
||||
return {"ids": existing_ids}
|
||||
|
||||
def add(
|
||||
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
|
||||
:type ids: List[str]
|
||||
"""
|
||||
docs = []
|
||||
if embeddings is None:
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
print("Adding documents to Pinecone...")
|
||||
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
|
||||
metadata["text"] = text
|
||||
docs.append(
|
||||
{
|
||||
"id": id,
|
||||
"values": embedding,
|
||||
"metadata": copy.deepcopy(metadata),
|
||||
"metadata": {**metadata, "text": text},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -120,13 +121,14 @@ class PineconeDb(BaseVectorDB):
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
@@ -177,4 +179,4 @@ class PineconeDb(BaseVectorDB):
|
||||
:return: Pinecone index
|
||||
:rtype: str
|
||||
"""
|
||||
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
|
||||
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
|
||||
Reference in New Issue
Block a user