From d8a7d7134494102b4508ed97000fbc4fc157eca5 Mon Sep 17 00:00:00 2001 From: Rupesh Bansal Date: Wed, 18 Oct 2023 10:52:29 +0530 Subject: [PATCH] [Feature] Batch uploading in chromadb (#814) --- embedchain/vectordb/chroma.py | 38 ++++++++++++++++++++++++-------- tests/vectordb/test_chroma_db.py | 34 ++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index be4e8604..c77e83c1 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -25,6 +25,8 @@ except RuntimeError: class ChromaDB(BaseVectorDB): """Vector database using ChromaDB.""" + BATCH_SIZE = 100 + def __init__(self, config: Optional[ChromaDbConfig] = None): """Initialize a new ChromaDB instance @@ -123,10 +125,6 @@ class ChromaDB(BaseVectorDB): args["limit"] = limit return self.collection.get(**args) - def get_advanced(self, where): - where_clause = self._generate_where_clause(where) - return self.collection.get(where=where_clause, limit=1) - def add( self, embeddings: List[List[float]], @@ -149,10 +147,31 @@ class ChromaDB(BaseVectorDB): :param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated. :type skip_embedding: bool """ - if skip_embedding: - self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids) - else: - self.collection.add(documents=documents, metadatas=metadatas, ids=ids) + size = len(documents) + if skip_embedding and (embeddings is None or len(embeddings) != len(documents)): + raise ValueError("Cannot add documents to chromadb with inconsistent embeddings") + + if len(documents) != size or len(metadatas) != size or len(ids) != size: + raise ValueError( + "Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {}," + " Ids size: {}".format(len(documents), len(metadatas), len(ids)) + ) + + for i in range(0, len(documents), self.BATCH_SIZE): + print("Inserting batches from {} to {} in chromadb".format(i, min(len(documents), i + self.BATCH_SIZE))) + if skip_embedding: + self.collection.add( + embeddings=embeddings[i : i + self.BATCH_SIZE], + documents=documents[i : i + self.BATCH_SIZE], + metadatas=metadatas[i : i + self.BATCH_SIZE], + ids=ids[i : i + self.BATCH_SIZE], + ) + else: + self.collection.add( + documents=documents[i : i + self.BATCH_SIZE], + metadatas=metadatas[i : i + self.BATCH_SIZE], + ids=ids[i : i + self.BATCH_SIZE], + ) def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]: """ @@ -208,7 +227,8 @@ class ChromaDB(BaseVectorDB): except InvalidDimensionException as e: raise InvalidDimensionException( e.message() - + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501 + + ". This is commonly a side-effect when an embedding function, different from the one used to add the" + " embeddings, is used to retrieve an embedding from the database." ) from None results_formatted = self._format_result(result) contents = [result[0].page_content for result in results_formatted] diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 0472ec0f..8caf14ac 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -228,6 +228,40 @@ class TestChromaDbCollection(unittest.TestCase): expected_value = ["document"] self.assertEqual(data, expected_value) + def test_add_with_invalid_inputs(self): + """ + Test add fails with invalid inputs + """ + # Start with a clean app + self.app_with_settings.reset() + # app = App(config=AppConfig(collect_metrics=False), db=db) + + # Collection should be empty when created + self.assertEqual(self.app_with_settings.db.count(), 0) + + with self.assertRaises(ValueError): + self.app_with_settings.db.add( + embeddings=[[0, 0, 0]], + documents=["document", "document2"], + metadatas=[{"value": "somevalue"}], + ids=["id"], + skip_embedding=True, + ) + # After adding, should contain no item + self.assertEqual(self.app_with_settings.db.count(), 0) + + with self.assertRaises(ValueError): + self.app_with_settings.db.add( + embeddings=None, + documents=["document", "document2"], + metadatas=[{"value": "somevalue"}], + ids=["id"], + skip_embedding=True, + ) + + # After adding, should contain no item + self.assertEqual(self.app_with_settings.db.count(), 0) + def test_collections_are_persistent(self): """ Test that a collection can be picked up later.