Added Clip dependency (#778)
This commit is contained in:
@@ -16,15 +16,15 @@ class ImagesLoader(BaseLoader):
|
|||||||
# load model and image preprocessing
|
# load model and image preprocessing
|
||||||
from embedchain.models.clip_processor import ClipProcessor
|
from embedchain.models.clip_processor import ClipProcessor
|
||||||
|
|
||||||
model, preprocess = ClipProcessor.load_model()
|
model = ClipProcessor.load_model()
|
||||||
if os.path.isfile(image_url):
|
if os.path.isfile(image_url):
|
||||||
data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
|
data = [ClipProcessor.get_image_features(image_url, model)]
|
||||||
else:
|
else:
|
||||||
data = []
|
data = []
|
||||||
for filename in os.listdir(image_url):
|
for filename in os.listdir(image_url):
|
||||||
filepath = os.path.join(image_url, filename)
|
filepath = os.path.join(image_url, filename)
|
||||||
try:
|
try:
|
||||||
data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
|
data.append(ClipProcessor.get_image_features(filepath, model))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the file that was not loaded
|
# Log the file that was not loaded
|
||||||
logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
|
logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
|
||||||
|
|||||||
@@ -1,31 +1,27 @@
|
|||||||
try:
|
try:
|
||||||
import clip
|
|
||||||
import torch
|
|
||||||
from PIL import Image, UnidentifiedImageError
|
from PIL import Image, UnidentifiedImageError
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Images requires extra dependencies. Install with `pip install 'embedchain[images]' git+https://github.com/openai/CLIP.git#a1d0717`" # noqa: E501
|
"Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
MODEL_NAME = "ViT-B/32"
|
MODEL_NAME = "clip-ViT-B-32"
|
||||||
|
|
||||||
|
|
||||||
class ClipProcessor:
|
class ClipProcessor:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_model():
|
def load_model():
|
||||||
"""Load data from a director of images."""
|
"""Load data from a director of images."""
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
# load model and image preprocessing
|
# load model and image preprocessing
|
||||||
model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
|
model = SentenceTransformer(MODEL_NAME)
|
||||||
return model, preprocess
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_image_features(image_url, model, preprocess):
|
def get_image_features(image_url, model):
|
||||||
"""
|
"""
|
||||||
Applies the CLIP model to evaluate the vector representation of the supplied image
|
Applies the CLIP model to evaluate the vector representation of the supplied image
|
||||||
"""
|
"""
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
try:
|
try:
|
||||||
# load image
|
# load image
|
||||||
image = Image.open(image_url)
|
image = Image.open(image_url)
|
||||||
@@ -34,27 +30,15 @@ class ClipProcessor:
|
|||||||
except UnidentifiedImageError:
|
except UnidentifiedImageError:
|
||||||
raise UnidentifiedImageError("The supplied file is not an image`")
|
raise UnidentifiedImageError("The supplied file is not an image`")
|
||||||
|
|
||||||
# pre-process image
|
image_features = model.encode(image)
|
||||||
processed_image = preprocess(image).unsqueeze(0).to(device)
|
|
||||||
with torch.no_grad():
|
|
||||||
image_features = model.encode_image(processed_image)
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
image_features = image_features.cpu().detach().numpy().tolist()[0]
|
|
||||||
meta_data = {"url": image_url}
|
meta_data = {"url": image_url}
|
||||||
return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
|
return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_text_features(query):
|
def get_text_features(query):
|
||||||
"""
|
"""
|
||||||
Applies the CLIP model to evaluate the vector representation of the supplied text
|
Applies the CLIP model to evaluate the vector representation of the supplied text
|
||||||
"""
|
"""
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
model = ClipProcessor.load_model()
|
||||||
|
text_features = model.encode(query)
|
||||||
model, preprocess = ClipProcessor.load_model()
|
return text_features.tolist()
|
||||||
text = clip.tokenize(query).to(device)
|
|
||||||
with torch.no_grad():
|
|
||||||
text_features = model.encode_text(text)
|
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return text_features.cpu().numpy().tolist()[0]
|
|
||||||
|
|||||||
@@ -128,7 +128,8 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
formatted_source = format_source(str(source), 30)
|
formatted_source = format_source(str(source), 30)
|
||||||
|
|
||||||
if url:
|
if url:
|
||||||
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
from langchain.document_loaders.youtube import \
|
||||||
|
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||||
|
|
||||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
||||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||||
|
|||||||
@@ -1,51 +1,43 @@
|
|||||||
# import os
|
import os
|
||||||
# import tempfile
|
import tempfile
|
||||||
# import urllib
|
import urllib
|
||||||
|
|
||||||
# import pytest
|
from PIL import Image
|
||||||
# from PIL import Image
|
|
||||||
|
|
||||||
# TODO: Uncomment after fixing clip dependency issue
|
from embedchain.models.clip_processor import ClipProcessor
|
||||||
# from embedchain.models.clip_processor import ClipProcessor
|
|
||||||
|
|
||||||
|
|
||||||
# class TestClipProcessor:
|
class TestClipProcessor:
|
||||||
# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
|
def test_load_model(self):
|
||||||
# def test_load_model(self):
|
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
|
||||||
# # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
|
model = ClipProcessor.load_model()
|
||||||
# model, preprocess = ClipProcessor.load_model()
|
assert model is not None
|
||||||
# assert model is not None
|
|
||||||
# assert preprocess is not None
|
|
||||||
|
|
||||||
# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
|
def test_get_image_features(self):
|
||||||
# def test_get_image_features(self):
|
# Clone the image to a temporary folder.
|
||||||
# # Clone the image to a temporary folder.
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
# with tempfile.TemporaryDirectory() as tmp_dir:
|
urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
|
||||||
# urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
|
|
||||||
|
|
||||||
# image = Image.open("image.jpg")
|
image = Image.open("image.jpg")
|
||||||
# image.save(os.path.join(tmp_dir, "image.jpg"))
|
image.save(os.path.join(tmp_dir, "image.jpg"))
|
||||||
|
|
||||||
# # Get the image features.
|
# Get the image features.
|
||||||
# model, preprocess = ClipProcessor.load_model()
|
model = ClipProcessor.load_model()
|
||||||
# ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
|
ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
|
||||||
|
|
||||||
# # Delete the temporary file.
|
# Delete the temporary file.
|
||||||
# os.remove(os.path.join(tmp_dir, "image.jpg"))
|
os.remove(os.path.join(tmp_dir, "image.jpg"))
|
||||||
|
|
||||||
# @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
|
def test_get_text_features(self):
|
||||||
# def test_get_text_features(self):
|
# Test that the `get_text_features()` method returns a list containing the text embedding.
|
||||||
# # Test that the `get_text_features()` method returns a list containing the text embedding.
|
query = "This is a text query."
|
||||||
# query = "This is a text query."
|
text_features = ClipProcessor.get_text_features(query)
|
||||||
# model, preprocess = ClipProcessor.load_model()
|
|
||||||
|
|
||||||
# text_features = ClipProcessor.get_text_features(query)
|
# Assert that the text embedding is not None.
|
||||||
|
assert text_features is not None
|
||||||
|
|
||||||
# # Assert that the text embedding is not None.
|
# Assert that the text embedding is a list of floats.
|
||||||
# assert text_features is not None
|
assert isinstance(text_features, list)
|
||||||
|
|
||||||
# # Assert that the text embedding is a list of floats.
|
# Assert that the text embedding has the correct length.
|
||||||
# assert isinstance(text_features, list)
|
assert len(text_features) == 512
|
||||||
|
|
||||||
# # Assert that the text embedding has the correct length.
|
|
||||||
# assert len(text_features) == 512
|
|
||||||
|
|||||||
Reference in New Issue
Block a user