Add support for image dataset (#571)

Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Rupesh Bansal
2023-10-04 09:50:40 +05:30
committed by GitHub
parent 55e9a1cbd6
commit d0af018b8d
19 changed files with 498 additions and 31 deletions

View File

@@ -66,3 +66,6 @@ class BaseChunker(JSONSerializable):
        self.data_type = data_type
        # TODO: This should be done during initialization. This means it has to be done in the child classes.

    def get_word_count(self, documents):
        return sum([len(document.split(" ")) for document in documents])
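For reference, the base implementation simply counts whitespace-separated tokens across all documents; a quick illustrative check (not part of the diff):

    documents = ["hello world", "one two three"]
    word_count = sum([len(document.split(" ")) for document in documents])
    assert word_count == 5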

View File

@@ -0,0 +1,63 @@
import hashlib
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig


class ImagesChunker(BaseChunker):
    """Chunker for images."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
        image_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(image_splitter)

    def create_chunks(self, loader, src):
        """
        Loads the image(s) and creates their corresponding embeddings. Exactly one chunk is created per image.

        :param loader: The loader whose `load_data` method is used to create
        the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
        remote sources or local content for local loaders.
        """
        documents = []
        embeddings = []
        ids = []
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        metadatas = []
        for data in data_records:
            meta_data = data["meta_data"]
            # Add the data type to the metadata so queries can filter by data type.
            meta_data["data_type"] = self.data_type.value
            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
            ids.append(chunk_id)
            documents.append(data["content"])
            embeddings.append(data["embedding"])
            meta_data["doc_id"] = doc_id
            metadatas.append(meta_data)
        return {
            "documents": documents,
            "embeddings": embeddings,
            "ids": ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }

    def get_word_count(self, documents):
        """
        The word count for an image is fixed to 1, since exactly one embedding is created per image.
        """
        return 1
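A minimal sketch of the contract `create_chunks` expects from a loader and the shape it returns (the stub loader and all values below are hypothetical, not part of this commit):

    from embedchain.models.data_type import DataType

    class StubImageLoader:
        def load_data(self, src):
            return {
                "doc_id": "doc-123",  # hypothetical id
                "data": [{
                    "content": src,            # the image path doubles as the document content
                    "embedding": [0.1] * 512,  # CLIP ViT-B/32 yields 512-dim vectors
                    "meta_data": {"url": src},
                }],
            }

    chunker = ImagesChunker()
    chunker.set_data_type(DataType.IMAGES)  # normally handled by DataFormatter
    result = chunker.create_chunks(StubImageLoader(), "/tmp/cat.jpg")
    # result["embeddings"] carries the precomputed CLIP vectors, which is why the
    # vector DB can be told to skip its own embedding step (skip_embedding=True).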

View File

@@ -67,6 +67,7 @@ class BaseLlmConfig(BaseConfig):
        deployment_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        where: Dict[str, Any] = None,
        query_type: Optional[str] = None,
    ):
        """
        Initializes a configuration class instance for the LLM.
@@ -112,6 +113,7 @@ class BaseLlmConfig(BaseConfig):
        self.top_p = top_p
        self.deployment_name = deployment_name
        self.system_prompt = system_prompt
        self.query_type = query_type

        if self.validate_template(template):
            self.template = template
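A small sketch of the new option in use (assuming `BaseLlmConfig` is exported from `embedchain.config`, as elsewhere in the codebase):

    from embedchain.config import BaseLlmConfig

    # "Images" switches query-time behavior: the input query is embedded with CLIP
    # and the matching contexts (image paths) are returned instead of an LLM answer.
    config = BaseLlmConfig(query_type="Images")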

View File

@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.images import ImagesChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -16,6 +17,7 @@ from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.csv import CsvLoader
from embedchain.loaders.docs_site_loader import DocsSiteLoader
from embedchain.loaders.docx_file import DocxFileLoader
from embedchain.loaders.images import ImagesLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
from embedchain.loaders.local_text import LocalTextLoader
from embedchain.loaders.mdx import MdxLoader
@@ -68,6 +70,7 @@ class DataFormatter(JSONSerializable):
            DataType.DOCS_SITE: DocsSiteLoader,
            DataType.CSV: CsvLoader,
            DataType.MDX: MdxLoader,
            DataType.IMAGES: ImagesLoader,
        }
        lazy_loaders = {DataType.NOTION}
        if data_type in loaders:
@@ -102,11 +105,11 @@ class DataFormatter(JSONSerializable):
            DataType.QNA_PAIR: QnaPairChunker,
            DataType.TEXT: TextChunker,
            DataType.DOCX: DocxFileChunker,
            DataType.WEB_PAGE: WebPageChunker,
            DataType.DOCS_SITE: DocsSiteChunker,
            DataType.NOTION: NotionChunker,
            DataType.CSV: TableChunker,
            DataType.MDX: MdxChunker,
            DataType.IMAGES: ImagesChunker,
        }
        if data_type in chunker_classes:
            chunker_class: type = chunker_classes[data_type]
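Sketch of the lookup this wires up (the `AddConfig` argument is assumed from the surrounding constructor, which this diff elides):

    from embedchain.chunkers.images import ImagesChunker
    from embedchain.config import AddConfig
    from embedchain.data_formatter import DataFormatter
    from embedchain.loaders.images import ImagesLoader
    from embedchain.models.data_type import DataType

    formatter = DataFormatter(DataType.IMAGES, config=AddConfig())
    # The formatter now resolves the images data type to its handlers:
    assert isinstance(formatter.loader, ImagesLoader)
    assert isinstance(formatter.chunker, ImagesChunker)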

View File

@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
        # Send anonymous telemetry
        if self.config.collect_metrics:
            # it's quicker to check the variable twice than to count words when they won't be submitted.
-           word_count = sum([len(document.split(" ")) for document in documents])
+           word_count = data_formatter.chunker.get_word_count(documents)
            extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
            thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
        # Create chunks
        embeddings_data = chunker.create_chunks(loader, src)

        # unpack the chunking results
        documents = embeddings_data["documents"]
        metadatas = embeddings_data["metadatas"]
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
        # Count before, to calculate a delta in the end.
        chunks_before_addition = self.db.count()

-       self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+       self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
+                   ids=ids, skip_embedding=(chunker.data_type == DataType.IMAGES))
        count_new_chunks = self.db.count() - chunks_before_addition
        print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
        return list(documents), metadatas, ids, count_new_chunks
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
        if self.config.id is not None:
            where.update({"app_id": self.config.id})

+       # The raw input query cannot be sent to the database for an image search: images and
+       # text must first be projected into the same embedding space before they can be compared.
+       db_query = input_query
+       if config.query_type == "Images":
+           # Import ClipProcessor lazily so the package does not depend on the CLIP extras
+           # unless the image data type is actually used.
+           from embedchain.models.clip_processor import ClipProcessor
+
+           db_query = ClipProcessor.get_text_features(query=input_query)
+
        contents = self.db.query(
-           input_query=input_query,
+           input_query=db_query,
            n_results=query_config.number_documents,
            where=where,
+           skip_embedding=(config.query_type == "Images"),
        )

        return contents
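End to end, the image path through `add` and `query` looks roughly like this (a sketch; the `App` API and the `"images"` data type string come from the rest of this diff, the paths are illustrative):

    from embedchain import App
    from embedchain.config import BaseLlmConfig

    app = App()
    # ImagesLoader embeds every image with CLIP up front, so the vector DB
    # stores the precomputed vectors (skip_embedding=True).
    app.add("/path/to/images", data_type="images")

    # The text query is embedded with CLIP as well, and the nearest image
    # paths are retrieved by vector similarity.
    results = app.query("a photo of a dog", config=BaseLlmConfig(query_type="Images"))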

View File

@@ -191,6 +191,9 @@ class BaseLlm(JSONSerializable):
        prev_config = self.config.serialize()
        self.config = config

        if config is not None and config.query_type == "Images":
            return contexts

        if self.is_docs_site_instance:
            self.config.template = DOCS_SITE_PROMPT_TEMPLATE
            self.config.number_documents = 5
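Worth noting: because ImagesChunker stores each image's path as the chunk content, the contexts returned by this early exit are image paths rather than text passages. Illustratively (values invented):

    contexts = app.query("a red bicycle", config=BaseLlmConfig(query_type="Images"))
    print(contexts)  # e.g. ["/data/images/bike_01.jpg", "/data/images/bike_07.jpg"]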

View File

@@ -0,0 +1,37 @@
import hashlib
import logging
import os

from embedchain.loaders.base_loader import BaseLoader


class ImagesLoader(BaseLoader):
    def load_data(self, image_url):
        """
        Loads images from the supplied file or directory and applies the CLIP model
        to represent each image in vector form.

        :param image_url: The path to an image file or to a directory of images.
        """
        # Import lazily so the CLIP extras are only required when images are used.
        from embedchain.models.clip_processor import ClipProcessor

        # load the model and the image preprocessing transform
        model, preprocess = ClipProcessor.load_model()
        if os.path.isfile(image_url):
            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
        else:
            data = []
            for filename in os.listdir(image_url):
                filepath = os.path.join(image_url, filename)
                try:
                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
                except Exception:
                    # Log (with traceback) the file that could not be loaded.
                    logging.exception("Failed to load the file {}.".format(filepath))
        # Derive a stable document id from the path's size and timestamps
        # (Size, Last Modified, Last Created).
        image_path_metadata = [
            str(os.path.getsize(image_url)),
            str(os.path.getmtime(image_url)),
            str(os.path.getctime(image_url)),
        ]
        doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
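A minimal usage sketch for the loader (paths illustrative):

    loader = ImagesLoader()
    result = loader.load_data("/data/images")  # a directory, or a single image file

    print(result["doc_id"])           # sha256 over the path's size/timestamps + the path
    first = result["data"][0]
    print(first["content"])           # the image's file path
    print(len(first["embedding"]))    # 512 for CLIP ViT-B/32
    print(first["meta_data"]["url"])  # same path; hashed into the chunk id by ImagesChunker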

View File

@@ -0,0 +1,64 @@
try:
    import clip
    import torch
    from PIL import Image, UnidentifiedImageError
except ImportError:
    raise ImportError("The images data type requires extra dependencies. Install with `pip install embedchain[images]`") from None

MODEL_NAME = "ViT-B/32"


class ClipProcessor:
    @staticmethod
    def load_model():
        """Load the CLIP model and its image preprocessing transform."""
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # load the model and the image preprocessing transform
        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
        return model, preprocess

    @staticmethod
    def get_image_features(image_url, model, preprocess):
        """
        Applies the CLIP model to compute the vector representation of the supplied image.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # load the image
            image = Image.open(image_url)
        except FileNotFoundError:
            raise FileNotFoundError("The supplied file does not exist")
        except UnidentifiedImageError:
            raise UnidentifiedImageError("The supplied file is not an image")

        # pre-process the image
        processed_image = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(processed_image)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        image_features = image_features.cpu().detach().numpy().tolist()[0]
        meta_data = {"url": image_url}
        return {
            "content": image_url,
            "embedding": image_features,
            "meta_data": meta_data,
        }

    @staticmethod
    def get_text_features(query):
        """
        Applies the CLIP model to compute the vector representation of the supplied text.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = ClipProcessor.load_model()

        text = clip.tokenize(query).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy().tolist()[0]
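Since both `get_image_features` and `get_text_features` L2-normalize their outputs, a plain dot product between the two vectors is their cosine similarity. A small sketch (path illustrative):

    model, preprocess = ClipProcessor.load_model()
    image = ClipProcessor.get_image_features("/data/images/cat.jpg", model, preprocess)
    text_vector = ClipProcessor.get_text_features("a photo of a cat")

    # Unit-norm vectors: dot product == cosine similarity.
    similarity = sum(a * b for a, b in zip(image["embedding"], text_vector))
    print(f"cosine similarity: {similarity:.3f}")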

View File

@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
    NOTION = "notion"
    CSV = "csv"
    MDX = "mdx"
    IMAGES = "images"


class SpecialDataType(Enum):
@@ -45,3 +46,4 @@ class DataType(Enum):
    CSV = IndirectDataType.CSV.value
    MDX = IndirectDataType.MDX.value
    QNA_PAIR = SpecialDataType.QNA_PAIR.value
    IMAGES = IndirectDataType.IMAGES.value
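This lets the string form round-trip through the enum, e.g. when `add()` is called with `data_type="images"`:

    from embedchain.models.data_type import DataType

    assert DataType("images") is DataType.IMAGES
    assert DataType.IMAGES.value == "images"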

View File

@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
    def get_advanced(self, where):
        return self.collection.get(where=where, limit=1)

-   def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+   def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+           ids: List[str], skip_embedding: bool) -> Any:
        """
        Add vectors to chroma database
@@ -126,7 +127,10 @@
        :param ids: ids
        :type ids: List[str]
        """
-       self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+       if skip_embedding:
+           self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
+       else:
+           self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
    def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
        """
@@ -146,7 +150,7 @@
            )
        ]

-   def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+   def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
        """
        Query contents from vector database based on vector similarity
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
        :rtype: List[str]
        """
        try:
-           result = self.collection.query(
-               query_texts=[
-                   input_query,
-               ],
-               n_results=n_results,
-               where=where,
-           )
+           if skip_embedding:
+               result = self.collection.query(
+                   query_embeddings=[
+                       input_query,
+                   ],
+                   n_results=n_results,
+                   where=where,
+               )
+           else:
+               result = self.collection.query(
+                   query_texts=[
+                       input_query,
+                   ],
+                   n_results=n_results,
+                   where=where,
+               )
        except InvalidDimensionException as e:
            raise InvalidDimensionException(
                e.message()
                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
            ) from None
        results_formatted = self._format_result(result)
        contents = [result[0].page_content for result in results_formatted]
        return contents
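A sketch of the two call paths this enables (`db` is an initialized ChromaDB; vectors and ids illustrative):

    # Precomputed CLIP vectors are stored as-is, bypassing Chroma's embedding function.
    db.add(
        embeddings=[[0.1] * 512],
        documents=["/data/images/cat.jpg"],
        metadatas=[{"url": "/data/images/cat.jpg", "data_type": "images"}],
        ids=["chunk-1"],
        skip_embedding=True,
    )

    # With skip_embedding=True, input_query is already a vector (despite the
    # List[str] annotation) and is passed through as query_embeddings.
    paths = db.query(input_query=[0.2] * 512, n_results=1, where={}, skip_embedding=True)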

View File

@@ -1,5 +1,5 @@
import logging

-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set

try:
    from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
        ids = [doc["_id"] for doc in docs]
        return {"ids": set(ids)}

-   def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
-       """add data in vector database
+   def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+           ids: List[str], skip_embedding: bool) -> Any:
+       """
+       add data in vector database

        :param documents: list of texts to add
        :type documents: List[str]
        :param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
"""
docs = []
embeddings = self.embedder.embedding_fn(documents)
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index())
-   def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+   def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
"""
query contents from vector data base based on vector similarity
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
:return: Database contents that are the result of the query
:rtype: List[str]
"""
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
if skip_embedding:
query_vector = input_query
else:
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
query = {
"script_score": {
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},