From 50c0285cb21604bc06564c0148a06fa0b9ce9e78 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Fri, 28 Jun 2024 23:48:22 +0530
Subject: [PATCH] Fix batch_size for vectordb (#1449)

---
 Makefile                                    |  2 +-
 embedchain/config/vectordb/base.py          |  4 ----
 embedchain/config/vectordb/chroma.py        |  4 ++++
 embedchain/config/vectordb/elasticsearch.py |  7 +++++++
 embedchain/config/vectordb/opensearch.py    |  4 ++++
 embedchain/config/vectordb/pinecone.py      |  2 ++
 embedchain/config/vectordb/qdrant.py        |  4 ++++
 embedchain/config/vectordb/weaviate.py      |  2 ++
 embedchain/vectordb/chroma.py               |  9 +++++----
 embedchain/vectordb/elasticsearch.py        |  3 ++-
 embedchain/vectordb/opensearch.py           |  7 +++----
 embedchain/vectordb/pinecone.py             |  7 ++++---
 embedchain/vectordb/qdrant.py               | 11 ++++++-----
 embedchain/vectordb/weaviate.py             |  5 +++--
 tests/vectordb/test_weaviate.py             |  3 +--
 15 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/Makefile b/Makefile
index 5f10b5f7..ab23978d 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
 
 install_es:
 	poetry install --extras elasticsearch
diff --git a/embedchain/config/vectordb/base.py b/embedchain/config/vectordb/base.py
index 84ed5b13..3252880a 100644
--- a/embedchain/config/vectordb/base.py
+++ b/embedchain/config/vectordb/base.py
@@ -10,7 +10,6 @@ class BaseVectorDbConfig(BaseConfig):
         dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
-        batch_size: Optional[int] = 100,
         **kwargs,
     ):
         """
@@ -24,8 +23,6 @@ class BaseVectorDbConfig(BaseConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
-        :param batch_size: Number of items to insert in one batch, defaults to 100
-        :type batch_size: Optional[int], optional
         :param kwargs: Additional keyword arguments
         :type kwargs: dict
         """
@@ -33,7 +30,6 @@ class BaseVectorDbConfig(BaseConfig):
         self.dir = dir
         self.host = host
         self.port = port
-        self.batch_size = batch_size
         # Assign additional keyword arguments
        if kwargs:
            for key, value in kwargs.items():
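
Note: with batch_size removed from BaseVectorDbConfig, the parameter now belongs to each concrete store config, as the diffs below show. A minimal sketch of the resulting call pattern, using the ChromaDbConfig class changed in this patch (the values are illustrative, not part of the patch):

    # Sketch only; import path follows the file layout shown in this patch.
    from embedchain.config.vectordb.chroma import ChromaDbConfig

    # batch_size is now a first-class parameter on the concrete config;
    # passing it to BaseVectorDbConfig would only be absorbed via **kwargs.
    config = ChromaDbConfig(collection_name="my_collection", batch_size=50)
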
diff --git a/embedchain/config/vectordb/chroma.py b/embedchain/config/vectordb/chroma.py
index d25de1c3..b99f1d94 100644
--- a/embedchain/config/vectordb/chroma.py
+++ b/embedchain/config/vectordb/chroma.py
@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         host: Optional[str] = None,
         port: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         allow_reset=False,
         chroma_settings: Optional[dict] = None,
     ):
@@ -26,6 +27,8 @@ class ChromaDbConfig(BaseVectorDbConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param allow_reset: Resets the database. defaults to False
         :type allow_reset: bool
         :param chroma_settings: Chroma settings dict, defaults to None
@@ -34,4 +37,5 @@ class ChromaDbConfig(BaseVectorDbConfig):
         """
         self.chroma_settings = chroma_settings
         self.allow_reset = allow_reset
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
diff --git a/embedchain/config/vectordb/elasticsearch.py b/embedchain/config/vectordb/elasticsearch.py
index 700a7192..1679700a 100644
--- a/embedchain/config/vectordb/elasticsearch.py
+++ b/embedchain/config/vectordb/elasticsearch.py
@@ -13,6 +13,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         es_url: Union[str, list[str]] = None,
         cloud_id: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **ES_EXTRA_PARAMS: dict[str, any],
     ):
         """
@@ -24,6 +25,10 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :type dir: Optional[str], optional
         :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
         :type es_url: Union[str, list[str]], optional
+        :param cloud_id: cloud id of the elasticsearch cluster, defaults to None
+        :type cloud_id: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
         :type ES_EXTRA_PARAMS: dict[str, Any], optional
         """
@@ -46,4 +51,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
             and not self.ES_EXTRA_PARAMS.get("bearer_auth")
         ):
             self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
+
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/config/vectordb/opensearch.py b/embedchain/config/vectordb/opensearch.py
index 1e112772..3d0b6099 100644
--- a/embedchain/config/vectordb/opensearch.py
+++ b/embedchain/config/vectordb/opensearch.py
@@ -13,6 +13,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         vector_dimension: int = 1536,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         """
@@ -28,10 +29,13 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         :type vector_dimension: int, optional
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         """
         self.opensearch_url = opensearch_url
         self.http_auth = http_auth
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
 
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir)
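
Note: every config above defaults batch_size to 100, and the insertion loops in the store classes below all share the same slicing shape. A self-contained sketch of that pattern (hypothetical data; tqdm is the same progress-bar dependency the diffs use):

    from tqdm import tqdm

    batch_size = 100
    documents = [f"doc {i}" for i in range(250)]  # hypothetical payload

    for i in tqdm(range(0, len(documents), batch_size), desc="Inserting batches"):
        batch = documents[i : i + batch_size]  # the final slice may be shorter
        # hand `batch` to the store's bulk-insert call here
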
diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py
index f021639d..4eb66f21 100644
--- a/embedchain/config/vectordb/pinecone.py
+++ b/embedchain/config/vectordb/pinecone.py
@@ -17,6 +17,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         serverless_config: Optional[dict[str, any]] = None,
         hybrid_search: bool = False,
         bm25_encoder: any = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         self.metric = metric
@@ -26,6 +27,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         self.extra_params = extra_params
         self.hybrid_search = hybrid_search
         self.bm25_encoder = bm25_encoder
+        self.batch_size = batch_size
         if pod_config is None and serverless_config is None:
             # If no config is provided, use the default pod spec config
             pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vectordb/qdrant.py
index 1268913e..2520226a 100644
--- a/embedchain/config/vectordb/qdrant.py
+++ b/embedchain/config/vectordb/qdrant.py
@@ -18,6 +18,7 @@ class QdrantDBConfig(BaseVectorDbConfig):
         hnsw_config: Optional[dict[str, any]] = None,
         quantization_config: Optional[dict[str, any]] = None,
         on_disk: Optional[bool] = None,
+        batch_size: Optional[int] = 10,
         **extra_params: dict[str, any],
     ):
         """
@@ -36,9 +37,12 @@ class QdrantDBConfig(BaseVectorDbConfig):
             This setting saves RAM by (slightly) increasing the response time.
             Note: those payload values that are involved in filtering and are indexed - remain in RAM.
         :type on_disk: bool, optional, defaults to None
+        :param batch_size: Number of items to insert in one batch, defaults to 10
+        :type batch_size: Optional[int], optional
         """
         self.hnsw_config = hnsw_config
         self.quantization_config = quantization_config
         self.on_disk = on_disk
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/config/vectordb/weaviate.py b/embedchain/config/vectordb/weaviate.py
index 3f5a353a..8277e88f 100644
--- a/embedchain/config/vectordb/weaviate.py
+++ b/embedchain/config/vectordb/weaviate.py
@@ -10,7 +10,9 @@ class WeaviateDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index 9b6c88c8..de73397e 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -42,6 +42,7 @@ class ChromaDB(BaseVectorDB):
 
         self.settings = Settings(anonymized_telemetry=False)
         self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
+        self.batch_size = self.config.batch_size
         if self.config.chroma_settings:
             for key, value in self.config.chroma_settings.items():
                 if hasattr(self.settings, key):
@@ -153,12 +154,12 @@ class ChromaDB(BaseVectorDB):
             " Ids size: {}".format(len(documents), len(metadatas), len(ids))
         )
 
-        for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
+        for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in chromadb"):
             self.collection.add(
-                documents=documents[i : i + self.config.batch_size],
-                metadatas=metadatas[i : i + self.config.batch_size],
-                ids=ids[i : i + self.config.batch_size],
+                documents=documents[i : i + self.batch_size],
+                metadatas=metadatas[i : i + self.batch_size],
+                ids=ids[i : i + self.batch_size],
             )
 
     @staticmethod
     def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
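
Note: QdrantDBConfig above deliberately uses a smaller default (10) than the other stores (100). Each store class below also caches the configured value once at construction; a minimal sketch of that pattern with a hypothetical class (not from the patch):

    class SomeVectorDB:  # hypothetical stand-in for the stores in this patch
        def __init__(self, config):
            self.config = config
            # Read once at init; the hot insert loops then use self.batch_size
            # rather than dereferencing self.config on every iteration.
            self.batch_size = config.batch_size
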
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index c9d21ca3..12b87176 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -55,6 +55,7 @@ class ElasticsearchDB(BaseVectorDB):
                 "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`"  # noqa: E501
             )
 
+        self.batch_size = self.config.batch_size
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
 
@@ -139,7 +140,7 @@ class ElasticsearchDB(BaseVectorDB):
 
         for chunk in chunks(
             list(zip(ids, documents, metadatas, embeddings)),
-            self.config.batch_size,
+            self.batch_size,
             desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []
diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py
index e043a361..accec432 100644
--- a/embedchain/vectordb/opensearch.py
+++ b/embedchain/vectordb/opensearch.py
@@ -37,6 +37,7 @@ class OpenSearchDB(BaseVectorDB):
         if config is None:
             raise ValueError("OpenSearchDBConfig is required")
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = OpenSearch(
             hosts=[self.config.opensearch_url],
             http_auth=self.config.http_auth,
@@ -118,10 +119,8 @@ class OpenSearchDB(BaseVectorDB):
         """Adds documents to the opensearch index"""
         embeddings = self.embedder.embedding_fn(documents)
 
-        for batch_start in tqdm(
-            range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
-        ):
-            batch_end = batch_start + self.config.batch_size
+        for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
+            batch_end = batch_start + self.batch_size
             batch_documents = documents[batch_start:batch_end]
             batch_embeddings = embeddings[batch_start:batch_end]
 
diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index becbe08f..c89dbcda 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -48,6 +48,7 @@ class PineconeDB(BaseVectorDB):
 
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
+        self.batch_size = self.config.batch_size
         if self.config.hybrid_search:
             logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
@@ -102,8 +103,8 @@ class PineconeDB(BaseVectorDB):
         metadatas = []
 
         if ids is not None:
-            for i in range(0, len(ids), self.config.batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
+            for i in range(0, len(ids), self.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -142,7 +143,7 @@ class PineconeDB(BaseVectorDB):
                 },
             )
 
-        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.batch_size, desc="Adding chunks in batches"):
             self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(
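
Note: the Elasticsearch and Pinecone stores above iterate through a chunks() helper rather than an explicit range loop. A hedged sketch of what such a helper can look like; embedchain's real helper may differ in signature and module location:

    from typing import Iterator, TypeVar

    from tqdm import tqdm

    T = TypeVar("T")

    def chunks(items: list[T], batch_size: int, desc: str = "") -> Iterator[list[T]]:
        # Yield successive batch_size-sized slices with a progress bar,
        # mirroring call sites like chunks(docs, self.batch_size, desc=...).
        for i in tqdm(range(0, len(items), batch_size), desc=desc):
            yield items[i : i + batch_size]
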
diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py
index 322e1562..d9b05020 100644
--- a/embedchain/vectordb/qdrant.py
+++ b/embedchain/vectordb/qdrant.py
@@ -35,6 +35,7 @@ class QdrantDB(BaseVectorDB):
                     "Please make sure the type is right and that you are passing an instance."
                 )
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -114,7 +115,7 @@ class QdrantDB(BaseVectorDB):
             collection_name=self.collection_name,
             scroll_filter=models.Filter(must=qdrant_must_filters),
             offset=offset,
-            limit=self.config.batch_size,
+            limit=self.batch_size,
         )
         offset = response[1]
         for doc in response[0]:
@@ -146,13 +147,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
 
-        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.config.batch_size],
-                    payloads=payloads[i : i + self.config.batch_size],
-                    vectors=embeddings[i : i + self.config.batch_size],
+                    ids=qdrant_ids[i : i + self.batch_size],
+                    payloads=payloads[i : i + self.batch_size],
+                    vectors=embeddings[i : i + self.batch_size],
                 ),
                 **kwargs,
             )
diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py
index 2d450051..f4632436 100644
--- a/embedchain/vectordb/weaviate.py
+++ b/embedchain/vectordb/weaviate.py
@@ -38,6 +38,7 @@ class WeaviateDB(BaseVectorDB):
                     "Please make sure the type is right and that you are passing an instance."
                 )
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = weaviate.Client(
             url=os.environ.get("WEAVIATE_ENDPOINT"),
             auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
@@ -167,7 +168,7 @@ class WeaviateDB(BaseVectorDB):
             )
             .with_where(weaviate_where_clause)
             .with_additional(["id"])
-            .with_limit(limit or self.config.batch_size),
+            .with_limit(limit or self.batch_size),
             offset,
         )
 
@@ -196,7 +197,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}
diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py
index 01068741..b263579f 100644
--- a/tests/vectordb/test_weaviate.py
+++ b/tests/vectordb/test_weaviate.py
@@ -124,7 +124,6 @@ class TestWeaviateDb(unittest.TestCase):
         db = WeaviateDB()
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
-        db.config.batch_size = 1
 
         documents = ["This is test document"]
         metadatas = [None]
@@ -132,7 +131,7 @@ class TestWeaviateDb(unittest.TestCase):
         ids = ["doc_1"]
         db.add(documents, metadatas, ids)
         # Check if the document was added to the database.
-        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
+        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
         weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
             data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
         )
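
Note: the test change reflects the new behaviour: WeaviateDB now reads batch_size once from its config at construction, so mutating db.config.batch_size after init no longer reaches the batch writer, and the mock asserts the default of 100. A sketch of setting a non-default value the supported way (assumes a reachable Weaviate instance and the WEAVIATE_ENDPOINT / WEAVIATE_API_KEY environment variables used above):

    from embedchain.config.vectordb.weaviate import WeaviateDBConfig
    from embedchain.vectordb.weaviate import WeaviateDB

    config = WeaviateDBConfig(batch_size=25)  # pass at construction time
    db = WeaviateDB(config=config)
    assert db.batch_size == 25
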