Fix batch_size for vectordb (#1449)

Dev Khant authored 2024-06-28 23:48:22 +05:30 · committed by GitHub
parent 0a78198bb5 · commit 50c0285cb2
15 changed files with 49 additions and 26 deletions
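Every file below follows the same two-part fix: `__init__` reads `batch_size` from the config once and caches it as `self.batch_size`, and each insert/fetch loop then slices by the cached attribute instead of dereferencing `self.config` on every use. A minimal sketch of the idiom, with illustrative class names (only the `batch_size` handling itself comes from the diff):

    from tqdm import tqdm

    class BatchedVectorDB:
        def __init__(self, config):
            self.config = config
            # Cache once; every later code path batches by the same value.
            self.batch_size = self.config.batch_size

        def add(self, documents):
            # Walk the list in batch_size-sized windows.
            for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches"):
                batch = documents[i : i + self.batch_size]
                ...  # send `batch` to the backing store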

@@ -42,6 +42,7 @@ class ChromaDB(BaseVectorDB):
         self.settings = Settings(anonymized_telemetry=False)
         self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
+        self.batch_size = self.config.batch_size
         if self.config.chroma_settings:
             for key, value in self.config.chroma_settings.items():
                 if hasattr(self.settings, key):
@@ -153,12 +154,13 @@ class ChromaDB(BaseVectorDB):
" Ids size: {}".format(len(documents), len(metadatas), len(ids))
)
for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in chromadb"):
self.collection.add(
documents=documents[i : i + self.config.batch_size],
metadatas=metadatas[i : i + self.config.batch_size],
ids=ids[i : i + self.config.batch_size],
documents=documents[i : i + self.batch_size],
metadatas=metadatas[i : i + self.batch_size],
ids=ids[i : i + self.batch_size],
)
self.config
@staticmethod
def _format_result(results: QueryResult) -> list[tuple[Document, float]]:

@@ -55,6 +55,7 @@ class ElasticsearchDB(BaseVectorDB):
"Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
)
self.batch_size = self.config.batch_size
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -139,7 +140,7 @@ class ElasticsearchDB(BaseVectorDB):
         for chunk in chunks(
             list(zip(ids, documents, metadatas, embeddings)),
-            self.config.batch_size,
+            self.batch_size,
             desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []
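Both the Elasticsearch and Pinecone loops batch through a chunks() helper. The actual helper lives elsewhere in the codebase; a sketch of an equivalent, assuming it pairs slicing with a tqdm progress bar driven by the desc argument seen above:

    from tqdm import tqdm

    def chunks(sequence, batch_size, desc=""):
        # Yield successive batch_size-sized slices of `sequence`.
        for i in tqdm(range(0, len(sequence), batch_size), desc=desc):
            yield sequence[i : i + batch_size]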

@@ -37,6 +37,7 @@ class OpenSearchDB(BaseVectorDB):
         if config is None:
             raise ValueError("OpenSearchDBConfig is required")
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = OpenSearch(
             hosts=[self.config.opensearch_url],
             http_auth=self.config.http_auth,
@@ -118,10 +119,8 @@ class OpenSearchDB(BaseVectorDB):
"""Adds documents to the opensearch index"""
embeddings = self.embedder.embedding_fn(documents)
for batch_start in tqdm(
range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
):
batch_end = batch_start + self.config.batch_size
for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
batch_end = batch_start + self.batch_size
batch_documents = documents[batch_start:batch_end]
batch_embeddings = embeddings[batch_start:batch_end]
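A small point about the OpenSearch loop above: batch_end can run past len(documents) on the final iteration, but Python slices clamp out-of-range bounds, so documents and embeddings stay aligned with no special case. For example:

    documents = ["a", "b", "c", "d", "e"]
    batch_size = 2
    for batch_start in range(0, len(documents), batch_size):
        batch_end = batch_start + batch_size  # 6 on the last pass
        print(documents[batch_start:batch_end])  # final batch is just ["e"]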

@@ -48,6 +48,7 @@ class PineconeDB(BaseVectorDB):
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
+        self.batch_size = self.config.batch_size
         if self.config.hybrid_search:
             logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
@@ -102,8 +103,8 @@ class PineconeDB(BaseVectorDB):
         metadatas = []
         if ids is not None:
-            for i in range(0, len(ids), self.config.batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
+            for i in range(0, len(ids), self.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -142,7 +143,7 @@ class PineconeDB(BaseVectorDB):
             },
         )
-        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.batch_size, desc="Adding chunks in batches"):
             self.pinecone_index.upsert(chunk, **kwargs)

     def query(
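The get() change keeps the Pinecone existence check batched too: ids are fetched batch_size at a time and the keys of the returned vectors are accumulated. In isolation (assuming `index` is a pinecone.Index and `ids`/`batch_size` come from the surrounding scope):

    existing_ids = []
    for i in range(0, len(ids), batch_size):
        result = index.fetch(ids=ids[i : i + batch_size])
        # fetch() returns only the vectors it found, keyed by id,
        # so absent keys mean the id is not yet in the index.
        existing_ids.extend(result.get("vectors", {}).keys())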

@@ -35,6 +35,7 @@ class QdrantDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.batch_size = self.config.batch_size
self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -114,7 +115,7 @@ class QdrantDB(BaseVectorDB):
             collection_name=self.collection_name,
             scroll_filter=models.Filter(must=qdrant_must_filters),
             offset=offset,
-            limit=self.config.batch_size,
+            limit=self.batch_size,
         )
         offset = response[1]
         for doc in response[0]:
@@ -146,13 +147,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
-        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.config.batch_size],
-                    payloads=payloads[i : i + self.config.batch_size],
-                    vectors=embeddings[i : i + self.config.batch_size],
+                    ids=qdrant_ids[i : i + self.batch_size],
+                    payloads=payloads[i : i + self.batch_size],
+                    vectors=embeddings[i : i + self.batch_size],
                 ),
                 **kwargs,
             )
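Unlike the slice-based loops, the Qdrant get() path pages with scroll(), which returns a (points, next_offset) tuple; self.batch_size now sets the page size. A standalone sketch of that pagination pattern (the collection name is illustrative, and `client`/`batch_size` are assumed from scope):

    offset = None
    while True:
        points, offset = client.scroll(
            collection_name="my_collection",
            offset=offset,
            limit=batch_size,
        )
        for point in points:
            ...  # read point.payload
        if offset is None:  # no further pages
            break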

@@ -38,6 +38,7 @@ class WeaviateDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.batch_size = self.config.batch_size
self.client = weaviate.Client(
url=os.environ.get("WEAVIATE_ENDPOINT"),
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
@@ -167,7 +168,7 @@ class WeaviateDB(BaseVectorDB):
             )
             .with_where(weaviate_where_clause)
             .with_additional(["id"])
-            .with_limit(limit or self.config.batch_size),
+            .with_limit(limit or self.batch_size),
             offset,
         )
@@ -196,7 +197,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}
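Weaviate is the one backend here that batches client-side rather than by slicing: batch.configure(batch_size=...) sets the auto-flush threshold, and the context manager ships accumulated objects to the server in groups of that size. A sketch against the v3 weaviate-client API used above (the class name and object shape are illustrative, and `client`/`batch_size` are assumed from scope):

    client.batch.configure(batch_size=batch_size, timeout_retries=3)
    with client.batch as batch:
        for id, text, embedding in zip(ids, documents, embeddings):
            batch.add_data_object(
                data_object={"identifier": id, "text": text},
                class_name="Document",
                vector=embedding,
            )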