Add batch_size in config for VectorDB (#1448)

Author: Dev Khant
Date: 2024-06-28 03:15:58 +05:30
Committed by: GitHub
Parent: edaeb78ccf
Commit: 0a78198bb5
10 changed files with 28 additions and 34 deletions

View File

@@ -217,6 +217,7 @@ Alright, let's dive into what each key means in the yaml config above:
     - `collection_name` (String): The initial collection name for the vectordb, set to 'full-stack-app'.
     - `dir` (String): The directory for the local database, set to 'db'.
     - `allow_reset` (Boolean): Indicates whether resetting the vectordb is allowed, set to true.
+    - `batch_size` (Integer): The batch size for docs insertion in vectordb, defaults to `100`
     <Note>We recommend you to checkout vectordb specific config [here](https://docs.embedchain.ai/components/vector-databases)</Note>
 4. `embedder` Section:
     - `provider` (String): The provider for the embedder, set to 'openai'. You can find the full list of embedding model providers in [our docs](/components/embedding-models).
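
For illustration, here is a minimal sketch of how the new key can be supplied from Python instead of yaml. The provider and surrounding keys mirror the documented example above; that `App.from_config` accepts an inline config dict is assumed from the embedchain docs rather than shown in this diff:

from embedchain import App

# Mirrors the yaml example documented above; batch_size=50 is an
# arbitrary illustrative override of the new default of 100.
app = App.from_config(
    config={
        "vectordb": {
            "provider": "chroma",
            "config": {
                "collection_name": "full-stack-app",
                "dir": "db",
                "allow_reset": True,
                "batch_size": 50,
            },
        }
    }
)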

View File

@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
         dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **kwargs,
     ):
         """
@@ -23,6 +24,8 @@ class BaseVectorDbConfig(BaseConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param kwargs: Additional keyword arguments
         :type kwargs: dict
         """
@@ -30,6 +33,7 @@ class BaseVectorDbConfig(BaseConfig):
         self.dir = dir
         self.host = host
         self.port = port
+        self.batch_size = batch_size
         # Assign additional keyword arguments
         if kwargs:
             for key, value in kwargs.items():
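
Since every vectordb config inherits from BaseVectorDbConfig, the new keyword should be accepted by any subclass. A minimal sketch, assuming the public import paths below (they are not shown in this diff):

from embedchain.config import ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB

# batch_size flows through BaseVectorDbConfig.__init__ shown above
# and lands on the instance's config object.
db = ChromaDB(config=ChromaDbConfig(collection_name="full-stack-app", dir="db", batch_size=200))
print(db.config.batch_size)  # 200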

View File

@@ -29,8 +29,6 @@ logger = logging.getLogger(__name__)
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
-    BATCH_SIZE = 100
-
     def __init__(self, config: Optional[ChromaDbConfig] = None):
         """Initialize a new ChromaDB instance
@@ -155,11 +153,11 @@ class ChromaDB(BaseVectorDB):
                 " Ids size: {}".format(len(documents), len(metadatas), len(ids))
             )
 
-        for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
+        for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
             self.collection.add(
-                documents=documents[i : i + self.BATCH_SIZE],
-                metadatas=metadatas[i : i + self.BATCH_SIZE],
-                ids=ids[i : i + self.BATCH_SIZE],
+                documents=documents[i : i + self.config.batch_size],
+                metadatas=metadatas[i : i + self.config.batch_size],
+                ids=ids[i : i + self.config.batch_size],
             )
 
     @staticmethod
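
The insertion loop above is plain stride-based slicing. Isolated as a self-contained sketch of the pattern (names here are illustrative, not from the diff):

def iter_batches(items: list, batch_size: int):
    # Step through the list in batch_size-sized windows; the last
    # window may be shorter, exactly as in the slicing above.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

assert list(iter_batches([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]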

View File

@@ -23,8 +23,6 @@ class ElasticsearchDB(BaseVectorDB):
     Elasticsearch as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[ElasticsearchDBConfig] = None,
@@ -140,7 +138,9 @@ class ElasticsearchDB(BaseVectorDB):
         embeddings = self.embedder.embedding_fn(documents)
         for chunk in chunks(
-            list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
+            list(zip(ids, documents, metadatas, embeddings)),
+            self.config.batch_size,
+            desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []
             for id, text, metadata, embedding in chunk:
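
The Elasticsearch path delegates batching to a `chunks` helper whose implementation is not part of this diff. A hedged sketch of what such a helper plausibly looks like; the real one presumably also drives a progress bar via `desc`:

from itertools import islice

def chunks(iterable, batch_size, desc=""):
    # Sketch only: yield successive batch_size-sized lists; `desc` is
    # accepted to match the call site above but unused here.
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch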

View File

@@ -18,8 +18,6 @@ class LanceDB(BaseVectorDB):
     LanceDB as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[LanceDBConfig] = None,

View File

@@ -28,8 +28,6 @@ class OpenSearchDB(BaseVectorDB):
     OpenSearch as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(self, config: OpenSearchDBConfig):
         """OpenSearch as vector database.
@@ -120,8 +118,10 @@ class OpenSearchDB(BaseVectorDB):
         """Adds documents to the opensearch index"""
         embeddings = self.embedder.embedding_fn(documents)
-        for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
-            batch_end = batch_start + self.BATCH_SIZE
+        for batch_start in tqdm(
+            range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
+        ):
+            batch_end = batch_start + self.config.batch_size
             batch_documents = documents[batch_start:batch_end]
             batch_embeddings = embeddings[batch_start:batch_end]

View File

@@ -25,8 +25,6 @@ class PineconeDB(BaseVectorDB):
    Pinecone as vector database
    """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[PineconeDBConfig] = None,
@@ -103,10 +101,9 @@ class PineconeDB(BaseVectorDB):
         existing_ids = list()
         metadatas = []
 
-        batch_size = 100
         if ids is not None:
-            for i in range(0, len(ids), batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + batch_size])
+            for i in range(0, len(ids), self.config.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -145,7 +142,7 @@ class PineconeDB(BaseVectorDB):
                 },
             )
 
-        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
             self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(

View File

@@ -21,8 +21,6 @@ class QdrantDB(BaseVectorDB):
    Qdrant as vector database
    """
 
-    BATCH_SIZE = 10
-
     def __init__(self, config: QdrantDBConfig = None):
         """
         Qdrant as vector database
@@ -116,7 +114,7 @@ class QdrantDB(BaseVectorDB):
                 collection_name=self.collection_name,
                 scroll_filter=models.Filter(must=qdrant_must_filters),
                 offset=offset,
-                limit=self.BATCH_SIZE,
+                limit=self.config.batch_size,
             )
             offset = response[1]
             for doc in response[0]:
@@ -148,13 +146,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
 
-        for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.BATCH_SIZE],
-                    payloads=payloads[i : i + self.BATCH_SIZE],
-                    vectors=embeddings[i : i + self.BATCH_SIZE],
+                    ids=qdrant_ids[i : i + self.config.batch_size],
+                    payloads=payloads[i : i + self.config.batch_size],
+                    vectors=embeddings[i : i + self.config.batch_size],
                 ),
                 **kwargs,
             )
@@ -251,4 +249,4 @@ class QdrantDB(BaseVectorDB):
 
     def delete(self, where: dict):
         db_filter = self._generate_query(where)
-        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
\ No newline at end of file
+        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
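
Note that Qdrant's removed constant was 10, not 100, so Qdrant upserts now default to larger batches than before this change. If the old behavior matters, it can be restored per instance; a sketch, assuming the import paths below and that QdrantDBConfig forwards the keyword to the base class:

from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.vectordb.qdrant import QdrantDB

# Restores the pre-change Qdrant batch size of 10 explicitly.
db = QdrantDB(config=QdrantDBConfig(batch_size=10))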

View File

@@ -20,8 +20,6 @@ class WeaviateDB(BaseVectorDB):
    Weaviate as vector database
    """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[WeaviateDBConfig] = None,
@@ -169,7 +167,7 @@ class WeaviateDB(BaseVectorDB):
                 )
                 .with_where(weaviate_where_clause)
                 .with_additional(["id"])
-                .with_limit(limit or self.BATCH_SIZE),
+                .with_limit(limit or self.config.batch_size),
                 offset,
             )
@@ -198,7 +196,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}
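
Unlike the stores above, Weaviate batches internally: the code only passes batch_size to client.batch.configure, and the context manager flushes as objects accumulate. A standalone sketch of that v3 weaviate-client pattern (the URL, class name, and data shape are illustrative, not taken from this diff):

import weaviate

client = weaviate.Client("http://localhost:8080")
client.batch.configure(batch_size=100, timeout_retries=3)
with client.batch as batch:
    for i, text in enumerate(["first doc", "second doc"]):
        # add_data_object queues the object; the client flushes a batch
        # whenever batch_size objects have accumulated.
        batch.add_data_object(
            data_object={"identifier": str(i), "text": text},
            class_name="EmbedchainStore",
            vector=[0.1] * 8,
        )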

View File

@@ -124,7 +124,7 @@ class TestWeaviateDb(unittest.TestCase):
         db = WeaviateDB()
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
-        db.BATCH_SIZE = 1
+        db.config.batch_size = 1
         documents = ["This is test document"]
         metadatas = [None]