[Feature] Batch uploading in chromadb (#814)
This commit is contained in:
@@ -25,6 +25,8 @@ except RuntimeError:
|
|||||||
class ChromaDB(BaseVectorDB):
|
class ChromaDB(BaseVectorDB):
|
||||||
"""Vector database using ChromaDB."""
|
"""Vector database using ChromaDB."""
|
||||||
|
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChromaDbConfig] = None):
|
def __init__(self, config: Optional[ChromaDbConfig] = None):
|
||||||
"""Initialize a new ChromaDB instance
|
"""Initialize a new ChromaDB instance
|
||||||
|
|
||||||
@@ -123,10 +125,6 @@ class ChromaDB(BaseVectorDB):
|
|||||||
args["limit"] = limit
|
args["limit"] = limit
|
||||||
return self.collection.get(**args)
|
return self.collection.get(**args)
|
||||||
|
|
||||||
def get_advanced(self, where):
|
|
||||||
where_clause = self._generate_where_clause(where)
|
|
||||||
return self.collection.get(where=where_clause, limit=1)
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: List[List[float]],
|
embeddings: List[List[float]],
|
||||||
@@ -149,10 +147,31 @@ class ChromaDB(BaseVectorDB):
|
|||||||
:param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
|
:param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
"""
|
"""
|
||||||
|
size = len(documents)
|
||||||
|
if skip_embedding and (embeddings is None or len(embeddings) != len(documents)):
|
||||||
|
raise ValueError("Cannot add documents to chromadb with inconsistent embeddings")
|
||||||
|
|
||||||
|
if len(documents) != size or len(metadatas) != size or len(ids) != size:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
|
||||||
|
" Ids size: {}".format(len(documents), len(metadatas), len(ids))
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, len(documents), self.BATCH_SIZE):
|
||||||
|
print("Inserting batches from {} to {} in chromadb".format(i, min(len(documents), i + self.BATCH_SIZE)))
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
|
self.collection.add(
|
||||||
|
embeddings=embeddings[i : i + self.BATCH_SIZE],
|
||||||
|
documents=documents[i : i + self.BATCH_SIZE],
|
||||||
|
metadatas=metadatas[i : i + self.BATCH_SIZE],
|
||||||
|
ids=ids[i : i + self.BATCH_SIZE],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
self.collection.add(
|
||||||
|
documents=documents[i : i + self.BATCH_SIZE],
|
||||||
|
metadatas=metadatas[i : i + self.BATCH_SIZE],
|
||||||
|
ids=ids[i : i + self.BATCH_SIZE],
|
||||||
|
)
|
||||||
|
|
||||||
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
|
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
|
||||||
"""
|
"""
|
||||||
@@ -208,7 +227,8 @@ class ChromaDB(BaseVectorDB):
|
|||||||
except InvalidDimensionException as e:
|
except InvalidDimensionException as e:
|
||||||
raise InvalidDimensionException(
|
raise InvalidDimensionException(
|
||||||
e.message()
|
e.message()
|
||||||
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
|
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the"
|
||||||
|
" embeddings, is used to retrieve an embedding from the database."
|
||||||
) from None
|
) from None
|
||||||
results_formatted = self._format_result(result)
|
results_formatted = self._format_result(result)
|
||||||
contents = [result[0].page_content for result in results_formatted]
|
contents = [result[0].page_content for result in results_formatted]
|
||||||
|
|||||||
@@ -228,6 +228,40 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
expected_value = ["document"]
|
expected_value = ["document"]
|
||||||
self.assertEqual(data, expected_value)
|
self.assertEqual(data, expected_value)
|
||||||
|
|
||||||
|
def test_add_with_invalid_inputs(self):
|
||||||
|
"""
|
||||||
|
Test add fails with invalid inputs
|
||||||
|
"""
|
||||||
|
# Start with a clean app
|
||||||
|
self.app_with_settings.reset()
|
||||||
|
# app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||||
|
|
||||||
|
# Collection should be empty when created
|
||||||
|
self.assertEqual(self.app_with_settings.db.count(), 0)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.app_with_settings.db.add(
|
||||||
|
embeddings=[[0, 0, 0]],
|
||||||
|
documents=["document", "document2"],
|
||||||
|
metadatas=[{"value": "somevalue"}],
|
||||||
|
ids=["id"],
|
||||||
|
skip_embedding=True,
|
||||||
|
)
|
||||||
|
# After adding, should contain no item
|
||||||
|
self.assertEqual(self.app_with_settings.db.count(), 0)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.app_with_settings.db.add(
|
||||||
|
embeddings=None,
|
||||||
|
documents=["document", "document2"],
|
||||||
|
metadatas=[{"value": "somevalue"}],
|
||||||
|
ids=["id"],
|
||||||
|
skip_embedding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# After adding, should contain no item
|
||||||
|
self.assertEqual(self.app_with_settings.db.count(), 0)
|
||||||
|
|
||||||
def test_collections_are_persistent(self):
|
def test_collections_are_persistent(self):
|
||||||
"""
|
"""
|
||||||
Test that a collection can be picked up later.
|
Test that a collection can be picked up later.
|
||||||
|
|||||||
Reference in New Issue
Block a user