[feature]: Improve pinecone db integration (#806)
This commit is contained in:
@@ -2,8 +2,9 @@ from string import Template
|
||||
|
||||
from embedchain.apps.app import App
|
||||
from embedchain.apps.open_source_app import OpenSourceApp
|
||||
from embedchain.config import BaseLlmConfig, AppConfig
|
||||
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY)
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDbConfig(BaseVectorDbConfig):
|
||||
class PineconeDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
dimension: Optional[int] = 1536,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
**extra_params: Dict[str, any],
|
||||
):
|
||||
self.dimension = dimension
|
||||
self.metric = metric
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -403,6 +403,8 @@ class EmbedChain(JSONSerializable):
|
||||
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
||||
)
|
||||
count_new_chunks = self.db.count() - chunks_before_addition
|
||||
|
||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas, ids, count_new_chunks
|
||||
|
||||
def _format_result(self, results):
|
||||
|
||||
@@ -69,11 +69,13 @@ class VectorDBFactory:
|
||||
"chroma": "embedchain.vectordb.chroma.ChromaDB",
|
||||
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
|
||||
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
||||
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -6,38 +5,38 @@ try:
|
||||
import pinecone
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDbConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDb(BaseVectorDB):
|
||||
BATCH_SIZE = 100
|
||||
|
||||
class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
Pinecone as vector database
|
||||
"""
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PineconeDbConfig] = None,
|
||||
config: Optional[PineconeDBConfig] = None,
|
||||
):
|
||||
"""Pinecone as vector database.
|
||||
|
||||
:param config: Pinecone database config, defaults to None
|
||||
:type config: PineconeDbConfig, optional
|
||||
:type config: PineconeDBConfig, optional
|
||||
:raises ValueError: No config provided
|
||||
"""
|
||||
if config is None:
|
||||
self.config = PineconeDbConfig()
|
||||
self.config = PineconeDBConfig()
|
||||
else:
|
||||
if not isinstance(config, PineconeDbConfig):
|
||||
if not isinstance(config, PineconeDBConfig):
|
||||
raise TypeError(
|
||||
"config is not a `PineconeDbConfig` instance. "
|
||||
"config is not a `PineconeDBConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
**self.config.extra_params,
|
||||
)
|
||||
self.index_name = self._get_index_name()
|
||||
indexes = pinecone.list_indexes()
|
||||
if indexes is None or self.index_name not in indexes:
|
||||
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
|
||||
pinecone.create_index(
|
||||
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||
)
|
||||
return pinecone.Index(self.index_name)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
batch_existing_ids = list(result.get("vectors").keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
|
||||
return {"ids": existing_ids}
|
||||
|
||||
def add(
|
||||
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
|
||||
:type ids: List[str]
|
||||
"""
|
||||
docs = []
|
||||
if embeddings is None:
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
print("Adding documents to Pinecone...")
|
||||
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
|
||||
metadata["text"] = text
|
||||
docs.append(
|
||||
{
|
||||
"id": id,
|
||||
"values": embedding,
|
||||
"metadata": copy.deepcopy(metadata),
|
||||
"metadata": {**metadata, "text": text},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -120,13 +121,14 @@ class PineconeDb(BaseVectorDB):
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
@@ -177,4 +179,4 @@ class PineconeDb(BaseVectorDB):
|
||||
:return: Pinecone index
|
||||
:rtype: str
|
||||
"""
|
||||
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
|
||||
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
|
||||
Reference in New Issue
Block a user