From d94aee812b90249c304e78811d9bef2463029ce4 Mon Sep 17 00:00:00 2001 From: Michael Date: Sun, 11 Feb 2024 15:45:02 -0800 Subject: [PATCH] [Improvements] Fixes to null data results and OpenAI embedding limits (#1238) --- embedchain/embedchain.py | 29 ++++++++++++++++++++++++++--- embedchain/vectordb/weaviate.py | 3 +++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 173e7928..b07a0ee0 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -429,16 +429,36 @@ class EmbedChain(JSONSerializable): if dry_run: return list(documents), metadatas, ids, 0 - + # Count before, to calculate a delta in the end. chunks_before_addition = self.db.count() - self.db.add(documents=documents, metadatas=metadatas, ids=ids, **kwargs) - count_new_chunks = self.db.count() - chunks_before_addition + + # Filter out empty documents and ensure they meet the API requirements + valid_documents = [doc for doc in documents if doc and isinstance(doc, str)] + documents = valid_documents + + # Chunk documents into batches of 2048 and handle each batch + # helps with large loads of embeddings that hit OpenAI limits + document_batches = [documents[i:i+2048] for i in range(0, len(documents), 2048)] + for batch in document_batches: + try: + # Add only valid batches + if batch: + self.db.add(documents=batch, metadatas=metadatas, ids=ids, **kwargs) + except Exception as e: + print(f"Failed to add batch due to a bad request: {e}") + # Handle the error, e.g., by logging, retrying, or skipping + pass + + + count_new_chunks = self.db.count() - chunks_before_addition print(f"Successfully saved {src} ({chunker.data_type}). 
New chunks count: {count_new_chunks}") + return list(documents), metadatas, ids, count_new_chunks + @staticmethod def _format_result(results): return [ @@ -473,7 +493,9 @@ class EmbedChain(JSONSerializable): :return: List of contents of the document that matched your query :rtype: list[str] """ + print("Query passed in config:", config) query_config = config or self.llm.config + print("Final config:", query_config) if where is not None: where = where else: @@ -484,6 +506,7 @@ class EmbedChain(JSONSerializable): if self.config.id is not None: where.update({"app_id": self.config.id}) + print('Number documents', query_config) contexts = self.db.query( input_query=input_query, n_results=query_config.number_documents, diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 446db496..13693f61 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -274,6 +274,9 @@ class WeaviateDB(BaseVectorDB): .do() ) + if results["data"]["Get"].get(self.index_name) is None: + return [] + docs = results["data"]["Get"].get(self.index_name) contexts = [] for doc in docs: