[feature]: Improve pinecone db integration (#806)
This commit is contained in:
6
configs/pinecone.yaml
Normal file
6
configs/pinecone.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
vectordb:
|
||||
provider: pinecone
|
||||
config:
|
||||
metric: cosine
|
||||
vector_dimension: 1536
|
||||
collection_name: my-pinecone-index
|
||||
@@ -147,7 +147,23 @@ _Coming soon_
|
||||
|
||||
## Pinecone
|
||||
|
||||
_Coming soon_
|
||||
To use Pinecone as the vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV`, both of which you can find on the [Pinecone dashboard](https://app.pinecone.io/).
|
||||
|
||||
```python main.py
|
||||
from embedchain import App
|
||||
|
||||
# Load the Pinecone configuration from a YAML file
|
||||
app = App.from_config(yaml_path="config.yaml")
|
||||
```
|
||||
|
||||
```yaml config.yaml
|
||||
vectordb:
|
||||
provider: pinecone
|
||||
config:
|
||||
metric: cosine
|
||||
vector_dimension: 1536
|
||||
collection_name: my-pinecone-index
|
||||
```
|
||||
|
||||
## Qdrant
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@ from string import Template
|
||||
|
||||
from embedchain.apps.app import App
|
||||
from embedchain.apps.open_source_app import OpenSourceApp
|
||||
from embedchain.config import BaseLlmConfig, AppConfig
|
||||
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY)
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDbConfig(BaseVectorDbConfig):
|
||||
class PineconeDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
dimension: Optional[int] = 1536,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
**extra_params: Dict[str, any],
|
||||
):
|
||||
self.dimension = dimension
|
||||
self.metric = metric
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -403,6 +403,8 @@ class EmbedChain(JSONSerializable):
|
||||
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
||||
)
|
||||
count_new_chunks = self.db.count() - chunks_before_addition
|
||||
|
||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas, ids, count_new_chunks
|
||||
|
||||
def _format_result(self, results):
|
||||
|
||||
@@ -69,11 +69,13 @@ class VectorDBFactory:
|
||||
"chroma": "embedchain.vectordb.chroma.ChromaDB",
|
||||
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
|
||||
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
||||
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -6,38 +5,38 @@ try:
|
||||
import pinecone
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDbConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDb(BaseVectorDB):
|
||||
BATCH_SIZE = 100
|
||||
|
||||
class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
Pinecone as vector database
|
||||
"""
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PineconeDbConfig] = None,
|
||||
config: Optional[PineconeDBConfig] = None,
|
||||
):
|
||||
"""Pinecone as vector database.
|
||||
|
||||
:param config: Pinecone database config, defaults to None
|
||||
:type config: PineconeDbConfig, optional
|
||||
:type config: PineconeDBConfig, optional
|
||||
:raises ValueError: No config provided
|
||||
"""
|
||||
if config is None:
|
||||
self.config = PineconeDbConfig()
|
||||
self.config = PineconeDBConfig()
|
||||
else:
|
||||
if not isinstance(config, PineconeDbConfig):
|
||||
if not isinstance(config, PineconeDBConfig):
|
||||
raise TypeError(
|
||||
"config is not a `PineconeDbConfig` instance. "
|
||||
"config is not a `PineconeDBConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
@@ -57,11 +56,14 @@ class PineconeDb(BaseVectorDB):
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
**self.config.extra_params,
|
||||
)
|
||||
self.index_name = self._get_index_name()
|
||||
indexes = pinecone.list_indexes()
|
||||
if indexes is None or self.index_name not in indexes:
|
||||
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
|
||||
pinecone.create_index(
|
||||
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||
)
|
||||
return pinecone.Index(self.index_name)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
@@ -81,7 +83,6 @@ class PineconeDb(BaseVectorDB):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
batch_existing_ids = list(result.get("vectors").keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
|
||||
return {"ids": existing_ids}
|
||||
|
||||
def add(
|
||||
@@ -102,15 +103,15 @@ class PineconeDb(BaseVectorDB):
|
||||
:type ids: List[str]
|
||||
"""
|
||||
docs = []
|
||||
if embeddings is None:
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
print("Adding documents to Pinecone...")
|
||||
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
|
||||
metadata["text"] = text
|
||||
docs.append(
|
||||
{
|
||||
"id": id,
|
||||
"values": embedding,
|
||||
"metadata": copy.deepcopy(metadata),
|
||||
"metadata": {**metadata, "text": text},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -120,13 +121,14 @@ class PineconeDb(BaseVectorDB):
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
@@ -177,4 +179,4 @@ class PineconeDb(BaseVectorDB):
|
||||
:return: Pinecone index
|
||||
:rtype: str
|
||||
"""
|
||||
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
|
||||
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "embedchain"
|
||||
version = "0.0.70"
|
||||
version = "0.0.71"
|
||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||
authors = ["Taranjeet Singh, Deshraj Yadav"]
|
||||
license = "Apache License"
|
||||
@@ -132,6 +132,7 @@ click = "^8.1.3"
|
||||
isort = "^5.12.0"
|
||||
pytest-cov = "^4.1.0"
|
||||
responses = "^0.23.3"
|
||||
mock = "^5.1.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
streamlit = ["streamlit"]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
|
||||
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.apps.app import App
|
||||
from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
|
||||
from embedchain.config import BaseLlmConfig, AppConfig
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.config.llm.base import DEFAULT_PROMPT
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import argparse
|
||||
|
||||
import pytest
|
||||
from fastapi_poe.types import ProtocolMessage, QueryRequest
|
||||
|
||||
from embedchain.bots.poe import PoeBot, start_command
|
||||
from fastapi_poe.types import QueryRequest, ProtocolMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.base import BaseLlm, BaseLlmConfig
|
||||
|
||||
|
||||
|
||||
@@ -4,30 +4,30 @@ from unittest.mock import patch
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.pineconedb import PineconeDb
|
||||
from embedchain.vectordb.pinecone import PineconeDB
|
||||
|
||||
|
||||
class TestPineconeDb:
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
class TestPinecone:
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_init(self, pinecone_mock):
|
||||
"""Test that the PineconeDb can be initialized."""
|
||||
# Create a PineconeDb instance
|
||||
PineconeDb()
|
||||
"""Test that the PineconeDB can be initialized."""
|
||||
# Create a PineconeDB instance
|
||||
PineconeDB()
|
||||
|
||||
# Assert that the Pinecone client was initialized
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_set_embedder(self, pinecone_mock):
|
||||
"""Test that the embedder can be set."""
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestPineconeDb:
|
||||
assert db.embedder == embedder
|
||||
pinecone_mock.init.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_add_documents(self, pinecone_mock):
|
||||
"""Test that documents can be added to the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
@@ -46,7 +46,7 @@ class TestPineconeDb:
|
||||
vectors = [[0, 0, 0], [1, 1, 1]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestPineconeDb:
|
||||
# Assert that the Pinecone client was called to upsert the documents
|
||||
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_query_documents(self, pinecone_mock):
|
||||
"""Test that documents can be queried from the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
@@ -73,8 +73,8 @@ class TestPineconeDb:
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
vectors = [[0, 0, 0]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
|
||||
@@ -88,11 +88,11 @@ class TestPineconeDb:
|
||||
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_reset(self, pinecone_mock):
|
||||
"""Test that the database can be reset."""
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=BaseEmbedder())
|
||||
|
||||
Reference in New Issue
Block a user