From 19a9141c2dac16915240e4566be29134ed7f475e Mon Sep 17 00:00:00 2001 From: Rupesh Bansal Date: Tue, 10 Oct 2023 00:32:45 +0530 Subject: [PATCH] Added Clip dependency (#778) --- embedchain/loaders/images.py | 6 +-- embedchain/models/clip_processor.py | 38 +++++----------- embedchain/utils.py | 3 +- tests/models/test_clip_processor.py | 70 +++++++++++++---------------- 4 files changed, 47 insertions(+), 70 deletions(-) diff --git a/embedchain/loaders/images.py b/embedchain/loaders/images.py index f80afa9b..bd954b0d 100644 --- a/embedchain/loaders/images.py +++ b/embedchain/loaders/images.py @@ -16,15 +16,15 @@ class ImagesLoader(BaseLoader): # load model and image preprocessing from embedchain.models.clip_processor import ClipProcessor - model, preprocess = ClipProcessor.load_model() + model = ClipProcessor.load_model() if os.path.isfile(image_url): - data = [ClipProcessor.get_image_features(image_url, model, preprocess)] + data = [ClipProcessor.get_image_features(image_url, model)] else: data = [] for filename in os.listdir(image_url): filepath = os.path.join(image_url, filename) try: - data.append(ClipProcessor.get_image_features(filepath, model, preprocess)) + data.append(ClipProcessor.get_image_features(filepath, model)) except Exception as e: # Log the file that was not loaded logging.exception("Failed to load the file {}. Exception {}".format(filepath, e)) diff --git a/embedchain/models/clip_processor.py b/embedchain/models/clip_processor.py index f6f8b54f..e349f4ae 100644 --- a/embedchain/models/clip_processor.py +++ b/embedchain/models/clip_processor.py @@ -1,31 +1,27 @@ try: - import clip - import torch from PIL import Image, UnidentifiedImageError + from sentence_transformers import SentenceTransformer except ImportError: raise ImportError( - "Images requires extra dependencies. Install with `pip install 'embedchain[images]' git+https://github.com/openai/CLIP.git#a1d0717`" # noqa: E501 + "Images requires extra dependencies. Install with `pip install 'embedchain[images]'" ) from None -MODEL_NAME = "ViT-B/32" +MODEL_NAME = "clip-ViT-B-32" class ClipProcessor: @staticmethod def load_model(): """Load data from a director of images.""" - device = "cuda" if torch.cuda.is_available() else "cpu" - # load model and image preprocessing - model, preprocess = clip.load(MODEL_NAME, device=device, jit=False) - return model, preprocess + model = SentenceTransformer(MODEL_NAME) + return model @staticmethod - def get_image_features(image_url, model, preprocess): + def get_image_features(image_url, model): """ Applies the CLIP model to evaluate the vector representation of the supplied image """ - device = "cuda" if torch.cuda.is_available() else "cpu" try: # load image image = Image.open(image_url) @@ -34,27 +30,15 @@ class ClipProcessor: except UnidentifiedImageError: raise UnidentifiedImageError("The supplied file is not an image`") - # pre-process image - processed_image = preprocess(image).unsqueeze(0).to(device) - with torch.no_grad(): - image_features = model.encode_image(processed_image) - image_features /= image_features.norm(dim=-1, keepdim=True) - - image_features = image_features.cpu().detach().numpy().tolist()[0] + image_features = model.encode(image) meta_data = {"url": image_url} - return {"content": image_url, "embedding": image_features, "meta_data": meta_data} + return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data} @staticmethod def get_text_features(query): """ Applies the CLIP model to evaluate the vector representation of the supplied text """ - device = "cuda" if torch.cuda.is_available() else "cpu" - - model, preprocess = ClipProcessor.load_model() - text = clip.tokenize(query).to(device) - with torch.no_grad(): - text_features = model.encode_text(text) - text_features /= text_features.norm(dim=-1, keepdim=True) - - return text_features.cpu().numpy().tolist()[0] + model = ClipProcessor.load_model() + text_features = model.encode(query) + return text_features.tolist() diff --git a/embedchain/utils.py b/embedchain/utils.py index 9b5709c3..748cc852 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -128,7 +128,8 @@ def detect_datatype(source: Any) -> DataType: formatted_source = format_source(str(source), 30) if url: - from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS + from langchain.document_loaders.youtube import \ + ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS if url.netloc in YOUTUBE_ALLOWED_NETLOCS: logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.") diff --git a/tests/models/test_clip_processor.py b/tests/models/test_clip_processor.py index 9e8500f0..e625662e 100644 --- a/tests/models/test_clip_processor.py +++ b/tests/models/test_clip_processor.py @@ -1,51 +1,43 @@ -# import os -# import tempfile -# import urllib +import os +import tempfile +import urllib -# import pytest -# from PIL import Image +from PIL import Image -# TODO: Uncomment after fixing clip dependency issue -# from embedchain.models.clip_processor import ClipProcessor +from embedchain.models.clip_processor import ClipProcessor -# class TestClipProcessor: -# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.") -# def test_load_model(self): -# # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly. -# model, preprocess = ClipProcessor.load_model() -# assert model is not None -# assert preprocess is not None +class TestClipProcessor: + def test_load_model(self): + # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly. + model = ClipProcessor.load_model() + assert model is not None -# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.") -# def test_get_image_features(self): -# # Clone the image to a temporary folder. -# with tempfile.TemporaryDirectory() as tmp_dir: -# urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg") + def test_get_image_features(self): + # Clone the image to a temporary folder. + with tempfile.TemporaryDirectory() as tmp_dir: + urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg") -# image = Image.open("image.jpg") -# image.save(os.path.join(tmp_dir, "image.jpg")) + image = Image.open("image.jpg") + image.save(os.path.join(tmp_dir, "image.jpg")) -# # Get the image features. -# model, preprocess = ClipProcessor.load_model() -# ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess) + # Get the image features. + model = ClipProcessor.load_model() + ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model) -# # Delete the temporary file. -# os.remove(os.path.join(tmp_dir, "image.jpg")) + # Delete the temporary file. + os.remove(os.path.join(tmp_dir, "image.jpg")) -# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.") -# def test_get_text_features(self): -# # Test that the `get_text_features()` method returns a list containing the text embedding. -# query = "This is a text query." -# model, preprocess = ClipProcessor.load_model() + def test_get_text_features(self): + # Test that the `get_text_features()` method returns a list containing the text embedding. + query = "This is a text query." + text_features = ClipProcessor.get_text_features(query) -# text_features = ClipProcessor.get_text_features(query) + # Assert that the text embedding is not None. + assert text_features is not None -# # Assert that the text embedding is not None. -# assert text_features is not None + # Assert that the text embedding is a list of floats. + assert isinstance(text_features, list) -# # Assert that the text embedding is a list of floats. -# assert isinstance(text_features, list) - -# # Assert that the text embedding has the correct length. -# assert len(text_features) == 512 + # Assert that the text embedding has the correct length. + assert len(text_features) == 512