Add support for image dataset (#571)

Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Author: Rupesh Bansal
Date: 2023-10-04 09:50:40 +05:30
Committed by: GitHub
Parent: 55e9a1cbd6
Commit: d0af018b8d

19 changed files with 498 additions and 31 deletions


@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         if self.config.collect_metrics:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
+            word_count = data_formatter.chunker.get_word_count(documents)
             extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
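The inline whitespace-split word count is replaced by a chunker-level helper, so each data type can define what counting means; images in particular have no prose to count. A minimal sketch of how such a helper might look (the method name comes from the diff above; the class layout and the zero-count override for images are assumptions):

```python
class BaseChunker:
    def get_word_count(self, documents):
        # same whitespace-split counting the old inline expression performed
        return sum(len(document.split(" ")) for document in documents)


class ImagesChunker(BaseChunker):
    def get_word_count(self, documents):
        # image "documents" are paths or vectors, not prose, so a word
        # count is meaningless here (assumed behaviour, not shown in diff)
        return 0
```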
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
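The hunk below reads an optional "embeddings" key from this dict, which suggests that for image data the chunker computes vectors up front rather than leaving embedding to the database. A hypothetical illustration of the shape this code consumes (all field values invented):

```python
# "documents", "metadatas" and "ids" are always present; "embeddings" is
# optional and, judging by the add() call below, only populated for images.
embeddings_data = {
    "documents": ["/data/cat.jpg"],
    "metadatas": [{"url": "local", "data_type": "images"}],
    "ids": ["0f3a"],  # content-derived id; exact scheme is not shown here
    "embeddings": [[0.12, -0.03, 0.57]],  # precomputed vector, truncated
}

documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
embeddings = embeddings_data.get("embeddings", None)  # None for text loaders
```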
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
-        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
+                    ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
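db.add() now receives any precomputed vectors plus a skip_embedding flag, set exactly when the chunker handled the image data type. A sketch of what the flag plausibly controls inside a Chroma-style database wrapper (not the actual embedchain implementation):

```python
class VectorDB:
    def __init__(self, collection):
        self.collection = collection  # e.g. a chromadb collection

    def add(self, embeddings, documents, metadatas, ids, skip_embedding):
        if skip_embedding:
            # image path: vectors were already computed by the chunker,
            # so store them as-is instead of re-embedding the documents
            self.collection.add(embeddings=embeddings, documents=documents,
                                metadatas=metadatas, ids=ids)
        else:
            # text path: the collection's embedding function embeds the
            # documents itself, exactly as before this change
            self.collection.add(documents=documents, metadatas=metadatas,
                                ids=ids)
```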
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
         if self.config.id is not None:
             where.update({"app_id": self.config.id})
+        # We cannot query the database with the input query in case of an image search. This is because we need
+        # to bring down both the image and text to the same dimension to be able to compare them.
+        db_query = input_query
+        if config.query_type == "Images":
+            # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
+            # image dataset is not being used
+            from embedchain.models.clip_processor import ClipProcessor
+            db_query = ClipProcessor.get_text_features(query=input_query)
         contents = self.db.query(
-            input_query=input_query,
+            input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
+            skip_embedding = (config.query_type == "Images")
         )
         return contents
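For image search the text query is first projected into CLIP's joint image-text embedding space, so the query vector and the stored image vectors become comparable; skip_embedding then tells the database not to re-embed the already-vectorized query. The diff only shows the call site, so this is a hedged sketch of what ClipProcessor.get_text_features could look like, using Hugging Face's CLIP purely for illustration (checkpoint name and internals are assumptions, not the module's actual code):

```python
import torch
from transformers import CLIPModel, CLIPTokenizerFast

_MODEL_ID = "openai/clip-vit-base-patch32"  # assumed checkpoint
_model = CLIPModel.from_pretrained(_MODEL_ID)
_tokenizer = CLIPTokenizerFast.from_pretrained(_MODEL_ID)


def get_text_features(query: str) -> list[float]:
    # encode the query with CLIP's text tower; the resulting vector lives
    # in the same space as CLIP image embeddings, so nearest-neighbour
    # search against stored image vectors works directly
    inputs = _tokenizer([query], padding=True, return_tensors="pt")
    with torch.no_grad():
        features = _model.get_text_features(**inputs)
    return features[0].tolist()
```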