diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index faa1bd1b..53c3d57f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,4 +34,4 @@ jobs:
           file: coverage.xml
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
-
+
diff --git a/Makefile b/Makefile
index a83d57e8..cf1974b3 100644
--- a/Makefile
+++ b/Makefile
@@ -9,6 +9,9 @@ PROJECT_NAME := embedchain
 install:
 	poetry install
 
+install_all:
+	poetry install --all-extras
+
 install_es:
 	poetry install --extras elasticsearch
diff --git a/embedchain/config/llm/base_llm_config.py b/embedchain/config/llm/base_llm_config.py
index 075dd131..e9c11515 100644
--- a/embedchain/config/llm/base_llm_config.py
+++ b/embedchain/config/llm/base_llm_config.py
@@ -67,7 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
-        query_type: Optional[str] = None
+        query_type: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index 4e3be6ff..10829474 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -1,8 +1,8 @@
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
-from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.images import ImagesChunker
+from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 584601ae..dc8fd578 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -392,8 +392,13 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
-                    ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
+        self.db.add(
+            embeddings=embeddings_data.get("embeddings", None),
+            documents=documents,
+            metadatas=metadatas,
+            ids=ids,
+            skip_embedding=(chunker.data_type == DataType.IMAGES),
+        )
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -437,17 +442,18 @@ class EmbedChain(JSONSerializable):
         # We cannot query the database with the input query in case of an image search. This is because we need
         # to bring down both the image and text to the same dimension to be able to compare them.
         db_query = input_query
-        if config.query_type == "Images":
+        if hasattr(config, "query_type") and config.query_type == "Images":
             # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
             # image dataset is not being used
             from embedchain.models.clip_processor import ClipProcessor
+
             db_query = ClipProcessor.get_text_features(query=input_query)
 
         contents = self.db.query(
             input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
-            skip_embedding = (config.query_type == "Images")
+            skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
 
         return contents
diff --git a/embedchain/llm/gpt4all.py b/embedchain/llm/gpt4all.py
index 586ac5be..91667fa7 100644
--- a/embedchain/llm/gpt4all.py
+++ b/embedchain/llm/gpt4all.py
@@ -22,7 +22,7 @@ class GPT4ALLLlm(BaseLlm):
             from gpt4all import GPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
-                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
+                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
         return GPT4All(model_name=model)
diff --git a/embedchain/loaders/images.py b/embedchain/loaders/images.py
index b7b53487..f80afa9b 100644
--- a/embedchain/loaders/images.py
+++ b/embedchain/loaders/images.py
@@ -1,11 +1,11 @@
-import os
-import logging
 import hashlib
+import logging
+import os
+
 from embedchain.loaders.base_loader import BaseLoader
 
 
 class ImagesLoader(BaseLoader):
-
     def load_data(self, image_url):
         """
         Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
@@ -15,6 +15,7 @@ class ImagesLoader(BaseLoader):
         """
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor
+
         model, preprocess = ClipProcessor.load_model()
         if os.path.isfile(image_url):
             data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
@@ -28,8 +29,11 @@
                     # Log the file that was not loaded
                     logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
         # Get the metadata like Size, Last Modified and Last Created timestamps
-        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
-                               str(os.path.getctime(image_url))]
+        image_path_metadata = [
+            str(os.path.getsize(image_url)),
+            str(os.path.getmtime(image_url)),
+            str(os.path.getctime(image_url)),
+        ]
         doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
         return {
             "doc_id": doc_id,
diff --git a/embedchain/models/clip_processor.py b/embedchain/models/clip_processor.py
index 5bc3108c..1c5c404f 100644
--- a/embedchain/models/clip_processor.py
+++ b/embedchain/models/clip_processor.py
@@ -1,6 +1,6 @@
 try:
-    import torch
     import clip
+    import torch
     from PIL import Image, UnidentifiedImageError
 except ImportError:
     raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
@@ -39,14 +39,8 @@ class ClipProcessor:
         image_features /= image_features.norm(dim=-1, keepdim=True)
         image_features = image_features.cpu().detach().numpy().tolist()[0]
 
-        meta_data = {
-            "url": image_url
-        }
-        return {
-            "content": image_url,
-            "embedding": image_features,
-            "meta_data": meta_data
-        }
+        meta_data = {"url": image_url}
+        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
 
     @staticmethod
     def get_text_features(query):
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index 504dee95..fa615a41 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -115,8 +115,14 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         Add vectors to chroma database
 
@@ -184,7 +190,7 @@
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index 4c45b1f8..48a6d44d 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -100,8 +100,14 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         add data in vector database
         :param documents: list of texts to add
diff --git a/pyproject.toml b/pyproject.toml
index 4eb3a223..9bf49fcd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,7 +94,7 @@ pytube = "^15.0.0"
 duckduckgo-search = "^3.8.5"
 llama-hub = { version = "^0.0.29", optional = true }
 sentence-transformers = { version = "^2.2.2", optional = true }
-torch = { version = ">=2.0.0, !=2.0.1", optional = true }
+torch = { version = "2.0.0", optional = true }
 # Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974)
 gpt4all = { version = "1.0.8", optional = true }
 # 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394)
@@ -107,6 +107,8 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
 clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
+pillow = { version = "10.0.1", optional = true }
+torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
@@ -131,7 +133,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
-images = ["torch", "ftfy", "regex", "clip"]
+images = ["torch", "ftfy", "regex", "clip", "pillow", "torchvision"]
 
 [tool.poetry.group.docs.dependencies]
diff --git a/tests/chunkers/test_image_chunker.py b/tests/chunkers/test_image_chunker.py
index 7b298337..299ff98d 100644
--- a/tests/chunkers/test_image_chunker.py
+++ b/tests/chunkers/test_image_chunker.py
@@ -19,11 +19,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_chunks_with_default_config(self):
@@ -37,11 +39,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):
diff --git a/tests/models/test_clip_processor.py b/tests/models/test_clip_processor.py
index debe618b..de60fdb8 100644
--- a/tests/models/test_clip_processor.py
+++ b/tests/models/test_clip_processor.py
@@ -1,29 +1,23 @@
-import tempfile
-import unittest
 import os
+import tempfile
 import urllib
+
 from PIL import Image
+
 from embedchain.models.clip_processor import ClipProcessor
 
 
-class ClipProcessorTest(unittest.TestCase):
-
+class TestClipProcessor:
     def test_load_model(self):
         # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
         model, preprocess = ClipProcessor.load_model()
-
-        # Assert that the model is not None.
-        self.assertIsNotNone(model)
-
-        # Assert that the preprocess is not None.
-        self.assertIsNotNone(preprocess)
+        assert model is not None
+        assert preprocess is not None
 
     def test_get_image_features(self):
         # Clone the image to a temporary folder.
         with tempfile.TemporaryDirectory() as tmp_dir:
-            urllib.request.urlretrieve(
-                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
-                "image.jpg")
+            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
 
             image = Image.open("image.jpg")
             image.save(os.path.join(tmp_dir, "image.jpg"))
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
             # Delete the temporary file.
             os.remove(os.path.join(tmp_dir, "image.jpg"))
 
-        # Assert that the test passes.
-        self.assertTrue(True)
-
     def test_get_text_features(self):
         # Test that the `get_text_features()` method returns a list containing the text embedding.
query = "This is a text query." @@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase): text_features = ClipProcessor.get_text_features(query) # Assert that the text embedding is not None. - self.assertIsNotNone(text_features) + assert text_features is not None # Assert that the text embedding is a list of floats. - self.assertIsInstance(text_features, list) + assert isinstance(text_features, list) # Assert that the text embedding has the correct length. - self.assertEqual(len(text_features), 512) + assert len(text_features) == 512 diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index f9cbeadf..3be89d49 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -197,21 +197,29 @@ class TestChromaDbCollection(unittest.TestCase): # Collection should be empty when created self.assertEqual(self.app_with_settings.db.count(), 0) - self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True) + self.app_with_settings.db.add( + embeddings=[[0, 0, 0]], + documents=["document"], + metadatas=[{"value": "somevalue"}], + ids=["id"], + skip_embedding=True, + ) # After adding, should contain one item self.assertEqual(self.app_with_settings.db.count(), 1) # Validate if the get utility of the database is working as expected data = self.app_with_settings.db.get(["id"], limit=1) - expected_value = {'documents': ['document'], - 'embeddings': None, - 'ids': ['id'], - 'metadatas': [{'value': 'somevalue'}]} + expected_value = { + "documents": ["document"], + "embeddings": None, + "ids": ["id"], + "metadatas": [{"value": "somevalue"}], + } self.assertEqual(data, expected_value) # Validate if the query utility of the database is working as expected data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) - expected_value = ['document'] + expected_value = ["document"] self.assertEqual(data, expected_value) def test_collections_are_persistent(self): diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 8e090718..ed7036c9 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -4,11 +4,11 @@ from unittest.mock import patch from embedchain import App from embedchain.config import AppConfig, ElasticsearchDBConfig -from embedchain.vectordb.elasticsearch import ElasticsearchDB from embedchain.embedder.gpt4all import GPT4AllEmbedder +from embedchain.vectordb.elasticsearch import ElasticsearchDB + class TestEsDB(unittest.TestCase): - @patch("embedchain.vectordb.elasticsearch.Elasticsearch") def test_setUp(self, mock_client): self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200")) @@ -37,17 +37,11 @@ class TestEsDB(unittest.TestCase): # Add the data to the database. self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False) - search_response = {"hits": - {"hits": - [ - { - "_source": {"text": "This is a document."}, - "_score": 0.9 - }, - { - "_source": {"text": "This is another document."}, - "_score": 0.8 - } + search_response = { + "hits": { + "hits": [ + {"_source": {"text": "This is a document."}, "_score": 0.9}, + {"_source": {"text": "This is another document."}, "_score": 0.8}, ] } } @@ -80,17 +74,11 @@ class TestEsDB(unittest.TestCase): # Add the data to the database. 
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
 
-        search_response = {"hits":
-                           {"hits":
-                            [
-                                {
-                                    "_source": {"text": "This is a document."},
-                                    "_score": 0.9
-                                },
-                                {
-                                    "_source": {"text": "This is another document."},
-                                    "_score": 0.8
-                                }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
                 ]
             }
         }
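
Note on the overall flow: taken together, these changes wire CLIP-based image search through the add/query path. ImagesLoader and ImagesChunker supply precomputed CLIP embeddings, the vector databases accept a skip_embedding flag so those vectors bypass the store's own embedding function, and query_type="Images" (now guarded with hasattr) sends the query text through ClipProcessor.get_text_features(). A minimal usage sketch of that flow follows; the data_type/query_type values and the images extra come from the diff above, while the App wiring, directory path, and query string are illustrative assumptions, not part of the patch:

    # Sketch only, not part of the patch. Assumes the extras added in
    # pyproject.toml: `pip install --upgrade "embedchain[images]"`
    # (torch, torchvision, pillow, ftfy, regex, clip), and an OPENAI_API_KEY
    # in the environment for the default App LLM/embedder.
    from embedchain import App
    from embedchain.config import BaseLlmConfig

    app = App()

    # ImagesLoader embeds each image with CLIP, so db.add() is called with
    # skip_embedding=True and the vectors are stored as-is.
    app.add("./photos/", data_type="images")  # "./photos/" is a hypothetical directory

    # query_type="Images" routes the query text through
    # ClipProcessor.get_text_features(), putting text and images in the same
    # 512-dimensional CLIP space before the similarity search.
    config = BaseLlmConfig(query_type="Images")
    answer = app.query("a dog playing in a park", config=config)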