From 0a78198bb5046a3f4f8d00e7dc4b38c9a82347a6 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Fri, 28 Jun 2024 03:15:58 +0530 Subject: [PATCH] Add batch_size in config for VectorDB (#1448) --- docs/api-reference/advanced/configuration.mdx | 1 + embedchain/config/vectordb/base.py | 4 ++++ embedchain/vectordb/chroma.py | 10 ++++------ embedchain/vectordb/elasticsearch.py | 6 +++--- embedchain/vectordb/lancedb.py | 2 -- embedchain/vectordb/opensearch.py | 8 ++++---- embedchain/vectordb/pinecone.py | 9 +++------ embedchain/vectordb/qdrant.py | 14 ++++++-------- embedchain/vectordb/weaviate.py | 6 ++---- tests/vectordb/test_weaviate.py | 2 +- 10 files changed, 28 insertions(+), 34 deletions(-) diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index da90a5da..d9938d3e 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -217,6 +217,7 @@ Alright, let's dive into what each key means in the yaml config above: - `collection_name` (String): The initial collection name for the vectordb, set to 'full-stack-app'. - `dir` (String): The directory for the local database, set to 'db'. - `allow_reset` (Boolean): Indicates whether resetting the vectordb is allowed, set to true. + - `batch_size` (Integer): The batch size for docs insertion in vectordb, defaults to `100` We recommend you to checkout vectordb specific config [here](https://docs.embedchain.ai/components/vector-databases) 4. `embedder` Section: - `provider` (String): The provider for the embedder, set to 'openai'. You can find the full list of embedding model providers in [our docs](/components/embedding-models). diff --git a/embedchain/config/vectordb/base.py b/embedchain/config/vectordb/base.py index 3252880a..84ed5b13 100644 --- a/embedchain/config/vectordb/base.py +++ b/embedchain/config/vectordb/base.py @@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig): dir: str = "db", host: Optional[str] = None, port: Optional[str] = None, + batch_size: Optional[int] = 100, **kwargs, ): """ @@ -23,6 +24,8 @@ class BaseVectorDbConfig(BaseConfig): :type host: Optional[str], optional :param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None :type port: Optional[str], optional + :param batch_size: Number of items to insert in one batch, defaults to 100 + :type batch_size: Optional[int], optional :param kwargs: Additional keyword arguments :type kwargs: dict """ @@ -30,6 +33,7 @@ class BaseVectorDbConfig(BaseConfig): self.dir = dir self.host = host self.port = port + self.batch_size = batch_size # Assign additional keyword arguments if kwargs: for key, value in kwargs.items(): diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 31dc2615..9b6c88c8 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -29,8 +29,6 @@ logger = logging.getLogger(__name__) class ChromaDB(BaseVectorDB): """Vector database using ChromaDB.""" - BATCH_SIZE = 100 - def __init__(self, config: Optional[ChromaDbConfig] = None): """Initialize a new ChromaDB instance @@ -155,11 +153,11 @@ class ChromaDB(BaseVectorDB): " Ids size: {}".format(len(documents), len(metadatas), len(ids)) ) - for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"): + for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"): self.collection.add( - documents=documents[i : i + self.BATCH_SIZE], - metadatas=metadatas[i : i + self.BATCH_SIZE], - ids=ids[i : i + self.BATCH_SIZE], + documents=documents[i : i + self.config.batch_size], + metadatas=metadatas[i : i + self.config.batch_size], + ids=ids[i : i + self.config.batch_size], ) @staticmethod diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index 5611af3e..c9d21ca3 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -23,8 +23,6 @@ class ElasticsearchDB(BaseVectorDB): Elasticsearch as vector database """ - BATCH_SIZE = 100 - def __init__( self, config: Optional[ElasticsearchDBConfig] = None, @@ -140,7 +138,9 @@ class ElasticsearchDB(BaseVectorDB): embeddings = self.embedder.embedding_fn(documents) for chunk in chunks( - list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch" + list(zip(ids, documents, metadatas, embeddings)), + self.config.batch_size, + desc="Inserting batches in elasticsearch", ): # noqa: E501 ids, docs, metadatas, embeddings = [], [], [], [] for id, text, metadata, embedding in chunk: diff --git a/embedchain/vectordb/lancedb.py b/embedchain/vectordb/lancedb.py index a502db65..3179eb5c 100644 --- a/embedchain/vectordb/lancedb.py +++ b/embedchain/vectordb/lancedb.py @@ -18,8 +18,6 @@ class LanceDB(BaseVectorDB): LanceDB as vector database """ - BATCH_SIZE = 100 - def __init__( self, config: Optional[LanceDBConfig] = None, diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index a5339b3a..e043a361 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -28,8 +28,6 @@ class OpenSearchDB(BaseVectorDB): OpenSearch as vector database """ - BATCH_SIZE = 100 - def __init__(self, config: OpenSearchDBConfig): """OpenSearch as vector database. @@ -120,8 +118,10 @@ class OpenSearchDB(BaseVectorDB): """Adds documents to the opensearch index""" embeddings = self.embedder.embedding_fn(documents) - for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"): - batch_end = batch_start + self.BATCH_SIZE + for batch_start in tqdm( + range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch" + ): + batch_end = batch_start + self.config.batch_size batch_documents = documents[batch_start:batch_end] batch_embeddings = embeddings[batch_start:batch_end] diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index bbf5b2ca..becbe08f 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -25,8 +25,6 @@ class PineconeDB(BaseVectorDB): Pinecone as vector database """ - BATCH_SIZE = 100 - def __init__( self, config: Optional[PineconeDBConfig] = None, @@ -103,10 +101,9 @@ class PineconeDB(BaseVectorDB): existing_ids = list() metadatas = [] - batch_size = 100 if ids is not None: - for i in range(0, len(ids), batch_size): - result = self.pinecone_index.fetch(ids=ids[i : i + batch_size]) + for i in range(0, len(ids), self.config.batch_size): + result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size]) vectors = result.get("vectors") batch_existing_ids = list(vectors.keys()) existing_ids.extend(batch_existing_ids) @@ -145,7 +142,7 @@ class PineconeDB(BaseVectorDB): }, ) - for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"): + for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"): self.pinecone_index.upsert(chunk, **kwargs) def query( diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index c356332c..322e1562 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -21,8 +21,6 @@ class QdrantDB(BaseVectorDB): Qdrant as vector database """ - BATCH_SIZE = 10 - def __init__(self, config: QdrantDBConfig = None): """ Qdrant as vector database @@ -116,7 +114,7 @@ class QdrantDB(BaseVectorDB): collection_name=self.collection_name, scroll_filter=models.Filter(must=qdrant_must_filters), offset=offset, - limit=self.BATCH_SIZE, + limit=self.config.batch_size, ) offset = response[1] for doc in response[0]: @@ -148,13 +146,13 @@ class QdrantDB(BaseVectorDB): qdrant_ids.append(id) payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)}) - for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"): + for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"): self.client.upsert( collection_name=self.collection_name, points=Batch( - ids=qdrant_ids[i : i + self.BATCH_SIZE], - payloads=payloads[i : i + self.BATCH_SIZE], - vectors=embeddings[i : i + self.BATCH_SIZE], + ids=qdrant_ids[i : i + self.config.batch_size], + payloads=payloads[i : i + self.config.batch_size], + vectors=embeddings[i : i + self.config.batch_size], ), **kwargs, ) @@ -251,4 +249,4 @@ class QdrantDB(BaseVectorDB): def delete(self, where: dict): db_filter = self._generate_query(where) - self.client.delete(collection_name=self.collection_name, points_selector=db_filter) \ No newline at end of file + self.client.delete(collection_name=self.collection_name, points_selector=db_filter) diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index b5a76b84..2d450051 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -20,8 +20,6 @@ class WeaviateDB(BaseVectorDB): Weaviate as vector database """ - BATCH_SIZE = 100 - def __init__( self, config: Optional[WeaviateDBConfig] = None, @@ -169,7 +167,7 @@ class WeaviateDB(BaseVectorDB): ) .with_where(weaviate_where_clause) .with_additional(["id"]) - .with_limit(limit or self.BATCH_SIZE), + .with_limit(limit or self.config.batch_size), offset, ) @@ -198,7 +196,7 @@ class WeaviateDB(BaseVectorDB): :type ids: list[str] """ embeddings = self.embedder.embedding_fn(documents) - self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch + self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3) # Configure batch with self.client.batch as batch: # Initialize a batch process for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): doc = {"identifier": id, "text": text} diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py index 805e4aaf..01068741 100644 --- a/tests/vectordb/test_weaviate.py +++ b/tests/vectordb/test_weaviate.py @@ -124,7 +124,7 @@ class TestWeaviateDb(unittest.TestCase): db = WeaviateDB() app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedding_model=embedder) - db.BATCH_SIZE = 1 + db.config.batch_size = 1 documents = ["This is test document"] metadatas = [None]