[feature]: Improve pinecone db integration (#806)

Deshraj Yadav
2023-10-15 02:26:35 -07:00
committed by GitHub
parent a7a61fae1d
commit 636bc0a99d
14 changed files with 85 additions and 46 deletions

View File

@@ -2,8 +2,9 @@ from string import Template
 from embedchain.apps.app import App
 from embedchain.apps.open_source_app import OpenSourceApp
-from embedchain.config import BaseLlmConfig, AppConfig
-from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
+from embedchain.config import AppConfig, BaseLlmConfig
+from embedchain.config.llm.base import (DEFAULT_PROMPT,
+                                        DEFAULT_PROMPT_WITH_HISTORY)
 from embedchain.helper.json_serializable import register_deserializable

View File

@@ -1,18 +1,20 @@
-from typing import Optional
+from typing import Dict, Optional
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helper.json_serializable import register_deserializable
 @register_deserializable
-class PineconeDbConfig(BaseVectorDbConfig):
+class PineconeDBConfig(BaseVectorDbConfig):
     def __init__(
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        dimension: Optional[int] = 1536,
+        vector_dimension: int = 1536,
         metric: Optional[str] = "cosine",
+        **extra_params: Dict[str, any],
     ):
-        self.dimension = dimension
         self.metric = metric
+        self.vector_dimension = vector_dimension
+        self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
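Taken together, the renamed config is constructed as in the minimal sketch below (the collection name is illustrative; 1536 matches the default output size of OpenAI's text-embedding-ada-002 embeddings):

from embedchain.config.vectordb.pinecone import PineconeDBConfig

# vector_dimension must match the embedding model's output size.
# Any keyword argument beyond the named ones is collected into config.extra_params
# and later forwarded to pinecone.init() (see the vectordb changes further down).
config = PineconeDBConfig(
    collection_name="my-collection",
    vector_dimension=1536,
    metric="cosine",
)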

View File

@@ -403,6 +403,8 @@ class EmbedChain(JSONSerializable):
             skip_embedding=(chunker.data_type == DataType.IMAGES),
         )
         count_new_chunks = self.db.count() - chunks_before_addition
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
     def _format_result(self, results):

View File

@@ -69,11 +69,13 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
}
@classmethod
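With both mappings registered, the provider string "pinecone" can be resolved to the implementation and config classes by their dotted paths. A rough sketch of that lookup-and-import pattern (the load_class helper is illustrative, not the factory's actual API; importing the vectordb module also requires the pinecone extra to be installed):

import importlib

def load_class(dotted_path: str):
    # e.g. "embedchain.vectordb.pinecone.PineconeDB" -> the PineconeDB class object
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

vectordb_cls = load_class("embedchain.vectordb.pinecone.PineconeDB")
config_cls = load_class("embedchain.config.vectordb.pinecone.PineconeDBConfig")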

View File

@@ -1,4 +1,3 @@
-import copy
 import os
 from typing import Dict, List, Optional
@@ -6,38 +5,38 @@ try:
     import pinecone
 except ImportError:
     raise ImportError(
-        "Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
+        "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
     ) from None
-from embedchain.config.vectordb.pinecone import PineconeDbConfig
+from embedchain.config.vectordb.pinecone import PineconeDBConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 @register_deserializable
-class PineconeDb(BaseVectorDB):
-    BATCH_SIZE = 100
+class PineconeDB(BaseVectorDB):
     """
     Pinecone as vector database
     """
+    BATCH_SIZE = 100
     def __init__(
         self,
-        config: Optional[PineconeDbConfig] = None,
+        config: Optional[PineconeDBConfig] = None,
     ):
         """Pinecone as vector database.
         :param config: Pinecone database config, defaults to None
-        :type config: PineconeDbConfig, optional
+        :type config: PineconeDBConfig, optional
         :raises ValueError: No config provided
         """
         if config is None:
-            self.config = PineconeDbConfig()
+            self.config = PineconeDBConfig()
         else:
-            if not isinstance(config, PineconeDbConfig):
+            if not isinstance(config, PineconeDBConfig):
                 raise TypeError(
-                    "config is not a `PineconeDbConfig` instance. "
+                    "config is not a `PineconeDBConfig` instance. "
                     "Please make sure the type is right and that you are passing an instance."
                 )
             self.config = config
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
         pinecone.init(
             api_key=os.environ.get("PINECONE_API_KEY"),
             environment=os.environ.get("PINECONE_ENV"),
+            **self.config.extra_params,
         )
         self.index_name = self._get_index_name()
         indexes = pinecone.list_indexes()
         if indexes is None or self.index_name not in indexes:
-            pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
+            pinecone.create_index(
+                name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
+            )
         return pinecone.Index(self.index_name)
     def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
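To illustrate the new extra_params passthrough: any keyword argument given to PineconeDBConfig beyond the named parameters is forwarded verbatim to pinecone.init(). A minimal sketch, assuming the v2 pinecone client and that PINECONE_API_KEY and PINECONE_ENV are set in the environment; project_name is only an example of such a passthrough argument, not something this diff requires:

import os
import pinecone
from embedchain.config.vectordb.pinecone import PineconeDBConfig

# "project_name" is not a named PineconeDBConfig parameter, so it lands in config.extra_params.
config = PineconeDBConfig(collection_name="docs", project_name="example-project")

# Effectively what the initialization code above ends up calling:
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY"),
    environment=os.environ.get("PINECONE_ENV"),
    **config.extra_params,  # forwards project_name="example-project"
)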
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
result = self.client.fetch(ids=ids[i : i + 1000])
batch_existing_ids = list(result.get("vectors").keys())
existing_ids.extend(batch_existing_ids)
return {"ids": existing_ids}
def add(
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
         :type ids: List[str]
         """
         docs = []
-        if embeddings is None:
-            embeddings = self.embedder.embedding_fn(documents)
+        print("Adding documents to Pinecone...")
+        embeddings = self.embedder.embedding_fn(documents)
         for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
-            metadata["text"] = text
             docs.append(
                 {
                     "id": id,
                     "values": embedding,
-                    "metadata": copy.deepcopy(metadata),
+                    "metadata": {**metadata, "text": text},
                 }
             )
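One motivation for the {**metadata, "text": text} form: it builds a fresh metadata dict for each record, so the caller's metadata is never mutated and the copy.deepcopy import becomes unnecessary. A tiny illustration of the difference:

metadata = {"url": "https://example.com"}
record_metadata = {**metadata, "text": "chunk contents"}  # new dict with the chunk text added
assert "text" not in metadata  # the original metadata dict is left untouched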
@@ -120,13 +121,14 @@
     def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :type input_query: List[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
         :param skip_embedding: Optional. if True, input_query is already embedded
         :type skip_embedding: bool
         :return: Database contents that are the result of the query
         :rtype: List[str]
         """
@@ -177,4 +179,4 @@
         :return: Pinecone index
         :rtype: str
         """
-        return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
+        return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")