[feature]: Improve pinecone db integration (#806)

Deshraj Yadav
2023-10-15 02:26:35 -07:00
committed by GitHub
parent a7a61fae1d
commit 636bc0a99d
14 changed files with 85 additions and 46 deletions

configs/pinecone.yaml

@@ -0,0 +1,6 @@
vectordb:
  provider: pinecone
  config:
    metric: cosine
    vector_dimension: 1536
    collection_name: my-pinecone-index
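
A minimal sketch of loading this new config file into an app (assuming the file is kept at `configs/pinecone.yaml` and that `PINECONE_API_KEY` / `PINECONE_ENV` are already exported, as described in the docs change below):

```python
import os

from embedchain import App

# Illustration only: the Pinecone client reads these two variables at init time.
assert os.environ.get("PINECONE_API_KEY") and os.environ.get("PINECONE_ENV")

# Build an app whose vector database settings come from the new YAML file.
app = App.from_config(yaml_path="configs/pinecone.yaml")
```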


@@ -147,7 +147,23 @@ _Coming soon_
## Pinecone
_Coming soon_
In order to use Pinecone as a vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV`, which you can find on the [Pinecone dashboard](https://app.pinecone.io/).
```python main.py
from embedchain import App
# load pinecone configuration from yaml file
app = App.from_config(yaml_path="config.yaml")
```
```yaml config.yaml
vectordb:
  provider: pinecone
  config:
    metric: cosine
    vector_dimension: 1536
    collection_name: my-pinecone-index
```
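
Continuing from the example above, usage is the same as with any other embedchain vector database; a hedged sketch (the source URL and question are placeholders, not part of this change):

```python
# Ingest a source; its chunks are embedded and upserted into the Pinecone index.
app.add("https://www.forbes.com/profile/elon-musk")

# Ask a question against the Pinecone-backed index.
print(app.query("What is the net worth of Elon Musk?"))
```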
## Qdrant


@@ -2,8 +2,9 @@ from string import Template
from embedchain.apps.app import App
from embedchain.apps.open_source_app import OpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper.json_serializable import register_deserializable


@@ -1,18 +1,20 @@
from typing import Optional
from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class PineconeDbConfig(BaseVectorDbConfig):
class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
dimension: Optional[int] = 1536,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
**extra_params: Dict[str, any],
):
self.dimension = dimension
self.metric = metric
self.vector_dimension = vector_dimension
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)
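
For programmatic setup, the renamed config class can be constructed directly and handed to `PineconeDB`; a sketch under the assumption that Pinecone credentials are already in the environment (`project_name` is only an illustrative `extra_params` entry, forwarded to `pinecone.init` by the database class later in this commit):

```python
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.vectordb.pinecone import PineconeDB

# Sketch only: requires PINECONE_API_KEY and PINECONE_ENV in the environment.
config = PineconeDBConfig(
    collection_name="my-pinecone-index",
    vector_dimension=1536,
    metric="cosine",
    project_name="my-project",  # illustrative extra_params entry
)
db = PineconeDB(config=config)
```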


@@ -403,6 +403,8 @@ class EmbedChain(JSONSerializable):
skip_embedding=(chunker.data_type == DataType.IMAGES),
)
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):


@@ -69,11 +69,13 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
}
@classmethod
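
The factory's create method is not part of this hunk, but the two maps imply that a provider string such as `"pinecone"` is resolved to its class via the dotted path. A rough sketch of that resolution with `importlib` (the `load_class` helper is hypothetical, not the factory's actual implementation):

```python
import importlib

def load_class(dotted_path: str):
    # "embedchain.vectordb.pinecone.PineconeDB" -> import the module, return the class.
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# Values copied from the maps above.
db_class = load_class("embedchain.vectordb.pinecone.PineconeDB")
config_class = load_class("embedchain.config.vectordb.pinecone.PineconeDBConfig")

# Instantiation assumes the pinecone extra is installed and credentials are set.
db = db_class(config=config_class(collection_name="my-pinecone-index"))
```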


@@ -1,4 +1,3 @@
import copy
import os
from typing import Dict, List, Optional
@@ -6,38 +5,38 @@ try:
import pinecone
except ImportError:
raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
) from None
from embedchain.config.vectordb.pinecone import PineconeDbConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
@register_deserializable
class PineconeDb(BaseVectorDB):
BATCH_SIZE = 100
class PineconeDB(BaseVectorDB):
"""
Pinecone as vector database
"""
BATCH_SIZE = 100
def __init__(
self,
config: Optional[PineconeDbConfig] = None,
config: Optional[PineconeDBConfig] = None,
):
"""Pinecone as vector database.
:param config: Pinecone database config, defaults to None
:type config: PineconeDbConfig, optional
:type config: PineconeDBConfig, optional
:raises ValueError: No config provided
"""
if config is None:
self.config = PineconeDbConfig()
self.config = PineconeDBConfig()
else:
if not isinstance(config, PineconeDbConfig):
if not isinstance(config, PineconeDBConfig):
raise TypeError(
"config is not a `PineconeDbConfig` instance. "
"config is not a `PineconeDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
**self.config.extra_params,
)
self.index_name = self._get_index_name()
indexes = pinecone.list_indexes()
if indexes is None or self.index_name not in indexes:
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
pinecone.create_index(
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
)
return pinecone.Index(self.index_name)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
result = self.client.fetch(ids=ids[i : i + 1000])
batch_existing_ids = list(result.get("vectors").keys())
existing_ids.extend(batch_existing_ids)
return {"ids": existing_ids}
def add(
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
:type ids: List[str]
"""
docs = []
if embeddings is None:
embeddings = self.embedder.embedding_fn(documents)
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
metadata["text"] = text
docs.append(
{
"id": id,
"values": embedding,
"metadata": copy.deepcopy(metadata),
"metadata": {**metadata, "text": text},
}
)
@@ -120,13 +121,14 @@ class PineconeDb(BaseVectorDB):
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:param skip_embedding: Optional. if True, input_query is already embedded
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
"""
@@ -177,4 +179,4 @@ class PineconeDb(BaseVectorDB):
:return: Pinecone index
:rtype: str
"""
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")


@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.0.70"
version = "0.0.71"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = ["Taranjeet Singh, Deshraj Yadav"]
license = "Apache License"
@@ -132,6 +132,7 @@ click = "^8.1.3"
isort = "^5.12.0"
pytest-cov = "^4.1.0"
responses = "^0.23.3"
mock = "^5.1.0"
[tool.poetry.extras]
streamlit = ["streamlit"]


@@ -1,9 +1,11 @@
import os
import pytest
import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig


@@ -1,7 +1,8 @@
import pytest
from embedchain.apps.app import App
from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import DEFAULT_PROMPT


@@ -1,8 +1,9 @@
import argparse
import pytest
from fastapi_poe.types import ProtocolMessage, QueryRequest
from embedchain.bots.poe import PoeBot, start_command
from fastapi_poe.types import QueryRequest, ProtocolMessage
@pytest.fixture


@@ -1,5 +1,7 @@
import os
import pytest
from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType


@@ -1,4 +1,5 @@
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig


@@ -4,30 +4,30 @@ from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pineconedb import PineconeDb
from embedchain.vectordb.pinecone import PineconeDB
class TestPineconeDb:
@patch("embedchain.vectordb.pineconedb.pinecone")
class TestPinecone:
@patch("embedchain.vectordb.pinecone.pinecone")
def test_init(self, pinecone_mock):
"""Test that the PineconeDb can be initialized."""
# Create a PineconeDb instance
PineconeDb()
"""Test that the PineconeDB can be initialized."""
# Create a PineconeDB instance
PineconeDB()
# Assert that the Pinecone client was initialized
pinecone_mock.init.assert_called_once()
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.Index.assert_called_once()
@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_set_embedder(self, pinecone_mock):
"""Test that the embedder can be set."""
# Set the embedder
embedder = BaseEmbedder()
# Create a PineconeDb instance
db = PineconeDb()
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
@@ -35,7 +35,7 @@ class TestPineconeDb:
assert db.embedder == embedder
pinecone_mock.init.assert_called_once()
@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_add_documents(self, pinecone_mock):
"""Test that documents can be added to the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
@@ -46,7 +46,7 @@ class TestPineconeDb:
vectors = [[0, 0, 0], [1, 1, 1]]
embedding_function.return_value = vectors
# Create a PineconeDb instance
db = PineconeDb()
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
@@ -63,7 +63,7 @@ class TestPineconeDb:
# Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock):
"""Test that documents can be queried from the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
@@ -73,8 +73,8 @@ class TestPineconeDb:
base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0]]
embedding_function.return_value = vectors
# Create a PineconeDb instance
db = PineconeDb()
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
@@ -88,11 +88,11 @@ class TestPineconeDb:
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
)
@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_reset(self, pinecone_mock):
"""Test that the database can be reset."""
# Create a PineconeDb instance
db = PineconeDb()
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=BaseEmbedder())