[feature]: Improve pinecone db integration (#806)

This commit is contained in:
Deshraj Yadav
2023-10-15 02:26:35 -07:00
committed by GitHub
parent a7a61fae1d
commit 636bc0a99d
14 changed files with 85 additions and 46 deletions

6
configs/pinecone.yaml Normal file
View File

@@ -0,0 +1,6 @@
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index

View File

@@ -147,7 +147,23 @@ _Coming soon_
## Pinecone ## Pinecone
_Coming soon_ In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
```python main.py
from embedchain import App
# load pinecone configuration from yaml file
app = App.from_config(yaml_path="config.yaml")
```
```yaml config.yaml
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
```
## Qdrant ## Qdrant

View File

@@ -2,8 +2,9 @@ from string import Template
from embedchain.apps.app import App from embedchain.apps.app import App
from embedchain.apps.open_source_app import OpenSourceApp from embedchain.apps.open_source_app import OpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable

View File

@@ -1,18 +1,20 @@
from typing import Optional from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
@register_deserializable @register_deserializable
class PineconeDbConfig(BaseVectorDbConfig): class PineconeDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
dimension: Optional[int] = 1536, vector_dimension: int = 1536,
metric: Optional[str] = "cosine", metric: Optional[str] = "cosine",
**extra_params: Dict[str, any],
): ):
self.dimension = dimension
self.metric = metric self.metric = metric
self.vector_dimension = vector_dimension
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -403,6 +403,8 @@ class EmbedChain(JSONSerializable):
skip_embedding=(chunker.data_type == DataType.IMAGES), skip_embedding=(chunker.data_type == DataType.IMAGES),
) )
count_new_chunks = self.db.count() - chunks_before_addition count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results): def _format_result(self, results):

View File

@@ -69,11 +69,13 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB", "chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
} }
provider_to_config_class = { provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
} }
@classmethod @classmethod

View File

@@ -1,4 +1,3 @@
import copy
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -6,38 +5,38 @@ try:
import pinecone import pinecone
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`" "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
) from None ) from None
from embedchain.config.vectordb.pinecone import PineconeDbConfig from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@register_deserializable @register_deserializable
class PineconeDb(BaseVectorDB): class PineconeDB(BaseVectorDB):
BATCH_SIZE = 100
""" """
Pinecone as vector database Pinecone as vector database
""" """
BATCH_SIZE = 100
def __init__( def __init__(
self, self,
config: Optional[PineconeDbConfig] = None, config: Optional[PineconeDBConfig] = None,
): ):
"""Pinecone as vector database. """Pinecone as vector database.
:param config: Pinecone database config, defaults to None :param config: Pinecone database config, defaults to None
:type config: PineconeDbConfig, optional :type config: PineconeDBConfig, optional
:raises ValueError: No config provided :raises ValueError: No config provided
""" """
if config is None: if config is None:
self.config = PineconeDbConfig() self.config = PineconeDBConfig()
else: else:
if not isinstance(config, PineconeDbConfig): if not isinstance(config, PineconeDBConfig):
raise TypeError( raise TypeError(
"config is not a `PineconeDbConfig` instance. " "config is not a `PineconeDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance." "Please make sure the type is right and that you are passing an instance."
) )
self.config = config self.config = config
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
pinecone.init( pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"), api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"), environment=os.environ.get("PINECONE_ENV"),
**self.config.extra_params,
) )
self.index_name = self._get_index_name() self.index_name = self._get_index_name()
indexes = pinecone.list_indexes() indexes = pinecone.list_indexes()
if indexes is None or self.index_name not in indexes: if indexes is None or self.index_name not in indexes:
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension) pinecone.create_index(
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
)
return pinecone.Index(self.index_name) return pinecone.Index(self.index_name)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
result = self.client.fetch(ids=ids[i : i + 1000]) result = self.client.fetch(ids=ids[i : i + 1000])
batch_existing_ids = list(result.get("vectors").keys()) batch_existing_ids = list(result.get("vectors").keys())
existing_ids.extend(batch_existing_ids) existing_ids.extend(batch_existing_ids)
return {"ids": existing_ids} return {"ids": existing_ids}
def add( def add(
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
:type ids: List[str] :type ids: List[str]
""" """
docs = [] docs = []
if embeddings is None: print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
metadata["text"] = text
docs.append( docs.append(
{ {
"id": id, "id": id,
"values": embedding, "values": embedding,
"metadata": copy.deepcopy(metadata), "metadata": {**metadata, "text": text},
} }
) )
@@ -120,13 +121,14 @@ class PineconeDb(BaseVectorDB):
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: List[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. if True, input_query is already embedded
:type skip_embedding: bool
:return: Database contents that are the result of the query :return: Database contents that are the result of the query
:rtype: List[str] :rtype: List[str]
""" """
@@ -177,4 +179,4 @@ class PineconeDb(BaseVectorDB):
:return: Pinecone index :return: Pinecone index
:rtype: str :rtype: str
""" """
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-") return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.0.70" version = "0.0.71"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = ["Taranjeet Singh, Deshraj Yadav"] authors = ["Taranjeet Singh, Deshraj Yadav"]
license = "Apache License" license = "Apache License"
@@ -132,6 +132,7 @@ click = "^8.1.3"
isort = "^5.12.0" isort = "^5.12.0"
pytest-cov = "^4.1.0" pytest-cov = "^4.1.0"
responses = "^0.23.3" responses = "^0.23.3"
mock = "^5.1.0"
[tool.poetry.extras] [tool.poetry.extras]
streamlit = ["streamlit"] streamlit = ["streamlit"]

View File

@@ -1,9 +1,11 @@
import os import os
import pytest import pytest
import yaml import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig

View File

@@ -1,7 +1,8 @@
import pytest import pytest
from embedchain.apps.app import App from embedchain.apps.app import App
from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import DEFAULT_PROMPT from embedchain.config.llm.base import DEFAULT_PROMPT

View File

@@ -1,8 +1,9 @@
import argparse import argparse
import pytest import pytest
from fastapi_poe.types import ProtocolMessage, QueryRequest
from embedchain.bots.poe import PoeBot, start_command from embedchain.bots.poe import PoeBot, start_command
from fastapi_poe.types import QueryRequest, ProtocolMessage
@pytest.fixture @pytest.fixture

View File

@@ -1,5 +1,7 @@
import os import os
import pytest import pytest
from embedchain import App from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType

View File

@@ -1,4 +1,5 @@
import pytest import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig from embedchain.llm.base import BaseLlm, BaseLlmConfig

View File

@@ -4,30 +4,30 @@ from unittest.mock import patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pineconedb import PineconeDb from embedchain.vectordb.pinecone import PineconeDB
class TestPineconeDb: class TestPinecone:
@patch("embedchain.vectordb.pineconedb.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_init(self, pinecone_mock): def test_init(self, pinecone_mock):
"""Test that the PineconeDb can be initialized.""" """Test that the PineconeDB can be initialized."""
# Create a PineconeDb instance # Create a PineconeDB instance
PineconeDb() PineconeDB()
# Assert that the Pinecone client was initialized # Assert that the Pinecone client was initialized
pinecone_mock.init.assert_called_once() pinecone_mock.init.assert_called_once()
pinecone_mock.list_indexes.assert_called_once() pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.Index.assert_called_once() pinecone_mock.Index.assert_called_once()
@patch("embedchain.vectordb.pineconedb.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_set_embedder(self, pinecone_mock): def test_set_embedder(self, pinecone_mock):
"""Test that the embedder can be set.""" """Test that the embedder can be set."""
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
# Create a PineconeDb instance # Create a PineconeDB instance
db = PineconeDb() db = PineconeDB()
app_config = AppConfig(collect_metrics=False) app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder) App(config=app_config, db=db, embedder=embedder)
@@ -35,7 +35,7 @@ class TestPineconeDb:
assert db.embedder == embedder assert db.embedder == embedder
pinecone_mock.init.assert_called_once() pinecone_mock.init.assert_called_once()
@patch("embedchain.vectordb.pineconedb.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_add_documents(self, pinecone_mock): def test_add_documents(self, pinecone_mock):
"""Test that documents can be added to the database.""" """Test that documents can be added to the database."""
pinecone_client_mock = pinecone_mock.Index.return_value pinecone_client_mock = pinecone_mock.Index.return_value
@@ -46,7 +46,7 @@ class TestPineconeDb:
vectors = [[0, 0, 0], [1, 1, 1]] vectors = [[0, 0, 0], [1, 1, 1]]
embedding_function.return_value = vectors embedding_function.return_value = vectors
# Create a PineconeDb instance # Create a PineconeDb instance
db = PineconeDb() db = PineconeDB()
app_config = AppConfig(collect_metrics=False) app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder) App(config=app_config, db=db, embedder=base_embedder)
@@ -63,7 +63,7 @@ class TestPineconeDb:
# Assert that the Pinecone client was called to upsert the documents # Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args) pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
@patch("embedchain.vectordb.pineconedb.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock): def test_query_documents(self, pinecone_mock):
"""Test that documents can be queried from the database.""" """Test that documents can be queried from the database."""
pinecone_client_mock = pinecone_mock.Index.return_value pinecone_client_mock = pinecone_mock.Index.return_value
@@ -73,8 +73,8 @@ class TestPineconeDb:
base_embedder.set_embedding_fn(embedding_function) base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0]] vectors = [[0, 0, 0]]
embedding_function.return_value = vectors embedding_function.return_value = vectors
# Create a PineconeDb instance # Create a PineconeDB instance
db = PineconeDb() db = PineconeDB()
app_config = AppConfig(collect_metrics=False) app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder) App(config=app_config, db=db, embedder=base_embedder)
@@ -88,11 +88,11 @@ class TestPineconeDb:
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
) )
@patch("embedchain.vectordb.pineconedb.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_reset(self, pinecone_mock): def test_reset(self, pinecone_mock):
"""Test that the database can be reset.""" """Test that the database can be reset."""
# Create a PineconeDb instance # Create a PineconeDb instance
db = PineconeDb() db = PineconeDB()
app_config = AppConfig(collect_metrics=False) app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=BaseEmbedder()) App(config=app_config, db=db, embedder=BaseEmbedder())