Added Clip dependency (#778)

Author: Rupesh Bansal
Date: 2023-10-10 00:32:45 +05:30
Committed by: GitHub
Parent: bc649b9a85
Commit: 19a9141c2d
4 changed files with 47 additions and 70 deletions

View File

@@ -16,15 +16,15 @@ class ImagesLoader(BaseLoader):
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor

-        model, preprocess = ClipProcessor.load_model()
+        model = ClipProcessor.load_model()
         if os.path.isfile(image_url):
-            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
+            data = [ClipProcessor.get_image_features(image_url, model)]
         else:
             data = []
             for filename in os.listdir(image_url):
                 filepath = os.path.join(image_url, filename)
                 try:
-                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
+                    data.append(ClipProcessor.get_image_features(filepath, model))
                 except Exception as e:
                     # Log the file that was not loaded
                     logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
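With `preprocess` gone, the loader call path shrinks to the two-argument form shown above. A minimal sketch of driving it end to end, assuming the `embedchain.loaders.images` module path and a BaseLoader-style `load_data(source)` entry point that returns the list built above (both inferred, not confirmed by this diff):

from embedchain.loaders.images import ImagesLoader  # module path assumed

loader = ImagesLoader()
# A single image file or a directory of images is accepted; each entry
# pairs the file path with its CLIP embedding and {"url": ...} metadata.
entries = loader.load_data("photos/")  # directory name is hypothetical
for entry in entries:
    print(entry["content"], len(entry["embedding"]))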

View File

@@ -1,31 +1,27 @@
 try:
-    import clip
-    import torch
     from PIL import Image, UnidentifiedImageError
+    from sentence_transformers import SentenceTransformer
 except ImportError:
     raise ImportError(
-        "Images requires extra dependencies. Install with `pip install 'embedchain[images]' git+https://github.com/openai/CLIP.git#a1d0717`"  # noqa: E501
+        "Images requires extra dependencies. Install with `pip install 'embedchain[images]'`"
     ) from None

-MODEL_NAME = "ViT-B/32"
+MODEL_NAME = "clip-ViT-B-32"


 class ClipProcessor:
     @staticmethod
     def load_model():
         """Load data from a directory of images."""
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
         # load model and image preprocessing
-        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
-        return model, preprocess
+        model = SentenceTransformer(MODEL_NAME)
+        return model

     @staticmethod
-    def get_image_features(image_url, model, preprocess):
+    def get_image_features(image_url, model):
         """
         Applies the CLIP model to evaluate the vector representation of the supplied image
         """
-        device = "cuda" if torch.cuda.is_available() else "cpu"
         try:
             # load image
             image = Image.open(image_url)
@@ -34,27 +30,15 @@ class ClipProcessor:
         except UnidentifiedImageError:
             raise UnidentifiedImageError("The supplied file is not an image")

-        # pre-process image
-        processed_image = preprocess(image).unsqueeze(0).to(device)
-        with torch.no_grad():
-            image_features = model.encode_image(processed_image)
-            image_features /= image_features.norm(dim=-1, keepdim=True)
-
-        image_features = image_features.cpu().detach().numpy().tolist()[0]
-
+        image_features = model.encode(image)
         meta_data = {"url": image_url}
-        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
+        return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}

     @staticmethod
     def get_text_features(query):
         """
         Applies the CLIP model to evaluate the vector representation of the supplied text
         """
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        model, preprocess = ClipProcessor.load_model()
-        text = clip.tokenize(query).to(device)
-        with torch.no_grad():
-            text_features = model.encode_text(text)
-            text_features /= text_features.norm(dim=-1, keepdim=True)
-
-        return text_features.cpu().numpy().tolist()[0]
+        model = ClipProcessor.load_model()
+        text_features = model.encode(query)
+        return text_features.tolist()
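The rewrite collapses the hand-rolled torch pipeline (device selection, `clip.tokenize`, `encode_image`/`encode_text`, manual normalization) into `SentenceTransformer.encode`, which accepts both PIL images and strings for CLIP checkpoints. A minimal sketch of the new call path; the image path and query string are hypothetical:

import math

from embedchain.models.clip_processor import ClipProcessor

model = ClipProcessor.load_model()  # SentenceTransformer("clip-ViT-B-32")
image = ClipProcessor.get_image_features("photos/cat.jpg", model)
text = ClipProcessor.get_text_features("a photo of a cat")

# Both embeddings live in the same 512-dimensional CLIP space, so cosine
# similarity scores image-text relevance. Unlike the old code, encode()
# does not normalize by default, hence the explicit norms.
dot = sum(a * b for a, b in zip(image["embedding"], text))
norm_i = math.sqrt(sum(a * a for a in image["embedding"]))
norm_t = math.sqrt(sum(b * b for b in text))
print(dot / (norm_i * norm_t))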

View File

@@ -128,7 +128,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)

     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
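Only the import is re-wrapped here, but the surrounding logic is worth noting: detect_datatype defers to langchain's YouTube netloc allow-list instead of keeping its own. A hedged sketch of exercising it; the embedchain.utils module path is an assumption based on the function shown:

from embedchain.utils import detect_datatype  # module path assumed

# Any URL whose netloc is in langchain's ALLOWED_NETLOCK set should come
# back as the `youtube_video` data type logged above.
print(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))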

View File

@@ -1,51 +1,43 @@
-# import os
-# import tempfile
-# import urllib
-
-# import pytest
-# from PIL import Image
-
-# TODO: Uncomment after fixing clip dependency issue
-# from embedchain.models.clip_processor import ClipProcessor
-
-
-# class TestClipProcessor:
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_load_model(self):
-#         # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
-#         model, preprocess = ClipProcessor.load_model()
-#         assert model is not None
-#         assert preprocess is not None
-
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_get_image_features(self):
-#         # Clone the image to a temporary folder.
-#         with tempfile.TemporaryDirectory() as tmp_dir:
-#             urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
-
-#             image = Image.open("image.jpg")
-#             image.save(os.path.join(tmp_dir, "image.jpg"))
-
-#             # Get the image features.
-#             model, preprocess = ClipProcessor.load_model()
-#             ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
-
-#             # Delete the temporary file.
-#             os.remove(os.path.join(tmp_dir, "image.jpg"))
-
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_get_text_features(self):
-#         # Test that the `get_text_features()` method returns a list containing the text embedding.
-#         query = "This is a text query."
-#         model, preprocess = ClipProcessor.load_model()
-#         text_features = ClipProcessor.get_text_features(query)
-
-#         # Assert that the text embedding is not None.
-#         assert text_features is not None
-
-#         # Assert that the text embedding is a list of floats.
-#         assert isinstance(text_features, list)
-
-#         # Assert that the text embedding has the correct length.
-#         assert len(text_features) == 512
+import os
+import tempfile
+import urllib
+
+from PIL import Image
+
+from embedchain.models.clip_processor import ClipProcessor
+
+
+class TestClipProcessor:
+    def test_load_model(self):
+        # Test that the `load_model()` method loads the CLIP model correctly.
+        model = ClipProcessor.load_model()
+        assert model is not None
+
+    def test_get_image_features(self):
+        # Clone the image to a temporary folder.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
+
+            image = Image.open("image.jpg")
+            image.save(os.path.join(tmp_dir, "image.jpg"))
+
+            # Get the image features.
+            model = ClipProcessor.load_model()
+            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
+
+            # Delete the temporary file.
+            os.remove(os.path.join(tmp_dir, "image.jpg"))
+
+    def test_get_text_features(self):
+        # Test that the `get_text_features()` method returns a list containing the text embedding.
+        query = "This is a text query."
+        text_features = ClipProcessor.get_text_features(query)
+
+        # Assert that the text embedding is not None.
+        assert text_features is not None
+
+        # Assert that the text embedding is a list of floats.
+        assert isinstance(text_features, list)
+
+        # Assert that the text embedding has the correct length.
+        assert len(text_features) == 512
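The re-enabled tests need network access: they fetch the Wikipedia sample image and pull the clip-ViT-B-32 checkpoint on first run. The 512 in the final assertion is the embedding width of CLIP ViT-B/32. A sketch of running just this module, with the test file path assumed from the class shown above:

import sys

import pytest

# Run only the ClipProcessor tests; the path is an assumption.
sys.exit(pytest.main(["-q", "tests/models/test_clip_processor.py"]))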