Add support for image dataset (#571)
Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
This commit is contained in:
64
embedchain/models/clip_processor.py
Normal file
64
embedchain/models/clip_processor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Guarded imports: image support is an optional extra, so a missing
# dependency is surfaced with an actionable install hint instead of a
# bare ImportError from deep inside torch/clip/PIL.
try:
    import torch
    import clip
    from PIL import Image, UnidentifiedImageError
except ImportError:
    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None

# Name of the pretrained CLIP checkpoint passed to clip.load().
MODEL_NAME = "ViT-B/32"
|
||||
|
||||
|
||||
class ClipProcessor:
    """Embeds images and text queries into CLIP's shared vector space.

    All embeddings are unit-normalized, so cosine similarity between an
    image embedding and a text embedding reduces to a dot product.
    """

    @staticmethod
    def load_model():
        """Load the CLIP model and its image-preprocessing pipeline.

        Returns:
            tuple: ``(model, preprocess)`` where ``model`` is the CLIP model
            placed on the selected device and ``preprocess`` is the callable
            that turns a PIL image into a model-ready tensor.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # jit=False loads the regular (non-TorchScript) model, which is the
        # form needed for feature extraction through the Python API.
        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
        return model, preprocess

    @staticmethod
    def get_image_features(image_url, model, preprocess):
        """Compute the CLIP vector representation of the supplied image.

        Args:
            image_url: Path to the image file on disk.
            model: CLIP model returned by :meth:`load_model`.
            preprocess: Preprocessing callable returned by :meth:`load_model`.

        Returns:
            dict: ``{"content": image_url, "embedding": list[float],
            "meta_data": {"url": image_url}}`` with a unit-normalized
            embedding.

        Raises:
            FileNotFoundError: If the supplied file does not exist.
            UnidentifiedImageError: If the supplied file is not an image.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            image = Image.open(image_url)
        except FileNotFoundError as err:
            # Chain the original exception; message had a stray trailing
            # backtick in the original implementation.
            raise FileNotFoundError("The supplied file does not exist") from err
        except UnidentifiedImageError as err:
            raise UnidentifiedImageError("The supplied file is not an image") from err

        # Pre-process inside a context manager so the underlying file
        # handle is released as soon as the tensor has been produced.
        with image:
            processed_image = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(processed_image)
            # Unit-normalize the embedding.
            image_features /= image_features.norm(dim=-1, keepdim=True)

        image_features = image_features.cpu().detach().numpy().tolist()[0]
        return {
            "content": image_url,
            "embedding": image_features,
            "meta_data": {"url": image_url},
        }

    @staticmethod
    def get_text_features(query, model=None):
        """Compute the CLIP vector representation of the supplied text.

        Args:
            query: Text to embed.
            model: Optional pre-loaded CLIP model. When omitted, the model
                is loaded on every call (the original behavior, kept for
                backward compatibility) — pass one in to avoid the costly
                per-call reload.

        Returns:
            list[float]: Unit-normalized embedding of ``query``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if model is None:
            # Preprocess pipeline is not needed for text.
            model, _ = ClipProcessor.load_model()
        text = clip.tokenize(query).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy().tolist()[0]
|
||||
@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
|
||||
NOTION = "notion"
|
||||
CSV = "csv"
|
||||
MDX = "mdx"
|
||||
IMAGES = "images"
|
||||
|
||||
|
||||
class SpecialDataType(Enum):
|
||||
@@ -45,3 +46,4 @@ class DataType(Enum):
|
||||
CSV = IndirectDataType.CSV.value
|
||||
MDX = IndirectDataType.MDX.value
|
||||
QNA_PAIR = SpecialDataType.QNA_PAIR.value
|
||||
IMAGES = IndirectDataType.IMAGES.value
|
||||
|
||||
Reference in New Issue
Block a user