diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py
index 42c030be..09fb8805 100644
--- a/embedchain/chunkers/base_chunker.py
+++ b/embedchain/chunkers/base_chunker.py
@@ -66,3 +66,6 @@ class BaseChunker(JSONSerializable):
         self.data_type = data_type
 
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
+
+    def get_word_count(self, documents):
+        return sum([len(document.split(" ")) for document in documents])
diff --git a/embedchain/chunkers/images.py b/embedchain/chunkers/images.py
new file mode 100644
index 00000000..22e4a7fd
--- /dev/null
+++ b/embedchain/chunkers/images.py
@@ -0,0 +1,63 @@
+import hashlib
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+
+class ImagesChunker(BaseChunker):
+    """Chunker for an Image."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        image_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(image_splitter)
+
+    def create_chunks(self, loader, src):
+        """
+        Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image.
+
+        :param loader: The loader whose `load_data` method is used to create
+        the raw data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        """
+        documents = []
+        embeddings = []
+        ids = []
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
+        metadatas = []
+        for data in data_records:
+            meta_data = data["meta_data"]
+            # add data type to meta data to allow query using data type
+            meta_data["data_type"] = self.data_type.value
+            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
+            ids.append(chunk_id)
+            documents.append(data["content"])
+            embeddings.append(data["embedding"])
+            meta_data["doc_id"] = doc_id
+            metadatas.append(meta_data)
+
+        return {
+            "documents": documents,
+            "embeddings": embeddings,
+            "ids": ids,
+            "metadatas": metadatas,
+            "doc_id": doc_id,
+        }
+
+    def get_word_count(self, documents):
+        """
+        The number of chunks and the corresponding word count for an image is fixed to 1, as 1 embedding is created
+        for each image.
+        """
+        return 1
diff --git a/embedchain/config/llm/base_llm_config.py b/embedchain/config/llm/base_llm_config.py
index 276903e7..075dd131 100644
--- a/embedchain/config/llm/base_llm_config.py
+++ b/embedchain/config/llm/base_llm_config.py
@@ -67,6 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
+        query_type: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -112,6 +113,7 @@ class BaseLlmConfig(BaseConfig):
         self.top_p = top_p
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
+        self.query_type = query_type
 
         if self.validate_template(template):
             self.template = template
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index 4312cd9c..4e3be6ff 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.mdx import MdxChunker
+from embedchain.chunkers.images import ImagesChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -16,6 +17,7 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
+from embedchain.loaders.images import ImagesLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
 from embedchain.loaders.mdx import MdxLoader
@@ -68,6 +70,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCS_SITE: DocsSiteLoader,
             DataType.CSV: CsvLoader,
             DataType.MDX: MdxLoader,
+            DataType.IMAGES: ImagesLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -102,11 +105,11 @@ class DataFormatter(JSONSerializable):
             DataType.QNA_PAIR: QnaPairChunker,
             DataType.TEXT: TextChunker,
             DataType.DOCX: DocxFileChunker,
-            DataType.WEB_PAGE: WebPageChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
             DataType.CSV: TableChunker,
             DataType.MDX: MdxChunker,
+            DataType.IMAGES: ImagesChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 90162863..584601ae 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
 
         # Send anonymous telemetry
         if self.config.collect_metrics:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
+            word_count = data_formatter.chunker.get_word_count(documents)
             extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
 
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
-
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
+                    ids=ids, skip_embedding=(chunker.data_type == DataType.IMAGES))
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
         if self.config.id is not None:
             where.update({"app_id": self.config.id})
 
+        # For an image search we cannot query the database with the raw input query: the text has to be
+        # projected into the same embedding space as the images before the two can be compared.
+        db_query = input_query
+        if config.query_type == "Images":
+            # Import the CLIP processor lazily so the package does not require the clip extra unless the
+            # images data type is actually used.
+            from embedchain.models.clip_processor import ClipProcessor
+            db_query = ClipProcessor.get_text_features(query=input_query)
+
         contents = self.db.query(
-            input_query=input_query,
+            input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
+            skip_embedding=(config.query_type == "Images"),
         )
 
         return contents
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index 868fb86f..35251719 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -191,6 +191,9 @@ class BaseLlm(JSONSerializable):
             prev_config = self.config.serialize()
             self.config = config
 
+        if config is not None and config.query_type == "Images":
+            return contexts
+
         if self.is_docs_site_instance:
             self.config.template = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
diff --git a/embedchain/loaders/images.py b/embedchain/loaders/images.py
new file mode 100644
index 00000000..b7b53487
--- /dev/null
+++ b/embedchain/loaders/images.py
@@ -0,0 +1,37 @@
+import os
+import logging
+import hashlib
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class ImagesLoader(BaseLoader):
+
+    def load_data(self, image_url):
+        """
+        Loads images from the supplied directory/file and applies the CLIP model transformation to represent
+        these images in vector form.
+
+        :param image_url: The path from which the images are to be loaded
+        """
+        # load model and image preprocessing
+        from embedchain.models.clip_processor import ClipProcessor
+        model, preprocess = ClipProcessor.load_model()
+        if os.path.isfile(image_url):
+            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
+        else:
+            data = []
+            for filename in os.listdir(image_url):
+                filepath = os.path.join(image_url, filename)
+                try:
+                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
+                except Exception as e:
+                    # Log the file that could not be loaded and continue with the rest
+                    logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
+        # Get metadata like size, last-modified and created timestamps
+        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
+                               str(os.path.getctime(image_url))]
+        doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
diff --git a/embedchain/models/clip_processor.py b/embedchain/models/clip_processor.py
new file mode 100644
index 00000000..5bc3108c
--- /dev/null
+++ b/embedchain/models/clip_processor.py
@@ -0,0 +1,64 @@
+try:
+    import torch
+    import clip
+    from PIL import Image, UnidentifiedImageError
+except ImportError:
+    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
+
+MODEL_NAME = "ViT-B/32"
+
+
+class ClipProcessor:
+    @staticmethod
+    def load_model():
+        """Load the CLIP model and its image preprocessing transform."""
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # load model and image preprocessing
+        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
+        return model, preprocess
+
+    @staticmethod
+    def get_image_features(image_url, model, preprocess):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied image
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        try:
+            # load image
+            image = Image.open(image_url)
+        except FileNotFoundError:
+            raise FileNotFoundError("The supplied file does not exist")
+        except UnidentifiedImageError:
+            raise UnidentifiedImageError("The supplied file is not an image")
+
+        # pre-process image
+        processed_image = preprocess(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            image_features = model.encode_image(processed_image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        image_features = image_features.cpu().detach().numpy().tolist()[0]
+        meta_data = {
+            "url": image_url
+        }
+        return {
+            "content": image_url,
+            "embedding": image_features,
+            "meta_data": meta_data
+        }
+
+    @staticmethod
+    def get_text_features(query):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied text
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        model, preprocess = ClipProcessor.load_model()
+        text = clip.tokenize(query).to(device)
+        with torch.no_grad():
+            text_features = model.encode_text(text)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        return text_features.cpu().numpy().tolist()[0]
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index d41bf0bd..90d7dd91 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
+    IMAGES = "images"
 
 
 class SpecialDataType(Enum):
@@ -45,3 +46,4 @@ class DataType(Enum):
     CSV = IndirectDataType.CSV.value
     MDX = IndirectDataType.MDX.value
     QNA_PAIR = SpecialDataType.QNA_PAIR.value
+    IMAGES = IndirectDataType.IMAGES.value
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index 9319d909..504dee95 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
         """
         Add vectors to chroma database
 
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
         :param ids: ids
         :type ids: List[str]
         """
-        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+        if skip_embedding:
+            self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
+        else:
+            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
 
     def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
         """
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
         :rtype: List[str]
         """
         try:
-            result = self.collection.query(
-                query_texts=[
-                    input_query,
-                ],
-                n_results=n_results,
-                where=where,
-            )
+            if skip_embedding:
+                result = self.collection.query(
+                    query_embeddings=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
+            else:
+                result = self.collection.query(
+                    query_texts=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
-
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
         return contents
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index 71da4a22..4c45b1f8 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 try:
     from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
-        """add data in vector database
-
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
+        """
+        Add data to the vector database.
         :param documents: list of texts to add
         :type documents: List[str]
         :param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
         :type ids: List[str]
         """
         docs = []
-        embeddings = self.embedder.embedding_fn(documents)
+        if not skip_embedding:
+            embeddings = self.embedder.embedding_fn(documents)
+
         for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
             docs.append(
                 {
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
         """
         query contents from vector data base based on vector similarity
 
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
         :return: Database contents that are the result of the query
         :rtype: List[str]
         """
-        input_query_vector = self.embedder.embedding_fn(input_query)
-        query_vector = input_query_vector[0]
+        if skip_embedding:
+            query_vector = input_query
+        else:
+            input_query_vector = self.embedder.embedding_fn(input_query)
+            query_vector = input_query_vector[0]
+
         query = {
             "script_score": {
                 "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
diff --git a/image.jpg b/image.jpg
new file mode 100644
index 00000000..a686d104
Binary files /dev/null and b/image.jpg differ
diff --git a/pyproject.toml b/pyproject.toml
index 89bfa112..4eb3a223 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -106,8 +106,9 @@ fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
-
-
+clip = { git = "https://github.com/openai/CLIP.git#a1d0717", optional = true }
+ftfy = { version = "6.1.1", optional = true }
+regex = { version = "2023.8.8", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -130,6 +131,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+images = ["torch", "ftfy", "regex", "clip"]
 
 
 [tool.poetry.group.docs.dependencies]
diff --git a/tests/chunkers/test_image_chunker.py b/tests/chunkers/test_image_chunker.py
new file mode 100644
index 00000000..7b298337
--- /dev/null
+++ b/tests/chunkers/test_image_chunker.py
@@ -0,0 +1,72 @@
+import unittest
+
+from embedchain.chunkers.images import ImagesChunker
+from embedchain.config import ChunkerConfig
+from embedchain.models.data_type import DataType
+
+
+class TestImageChunker(unittest.TestCase):
+    def test_chunks(self):
+        """
+        Test the chunks generated by ImagesChunker.
+        # TODO: Not a very precise test.
+        """
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_chunks_with_default_config(self):
+        """
+        Test the chunks generated by ImagesChunker with the default config.
+        """
+        chunker = ImagesChunker()
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        chunker.set_data_type(DataType.IMAGES)
+
+        document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 1)
+
+
+class MockLoader:
+    def load_data(self, src):
+        """
+        Mock loader that returns a list of data dictionaries.
+        Adjust this method to return different data for testing.
+        """
+        return {
+            "doc_id": "123",
+            "data": [
+                {
+                    "content": src,
+                    "embedding": "embedding",
+                    "meta_data": {"url": "none"},
+                }
+            ],
+        }
diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py
index e5bc32ab..c6fabe98 100644
--- a/tests/chunkers/test_text.py
+++ b/tests/chunkers/test_text.py
@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
 
         self.assertEqual(len(documents), len(text))
 
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        chunker.set_data_type(DataType.TEXT)
+
+        document = ["ab cd", "ef gh"]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 4)
+
 
 class MockLoader:
     def load_data(self, src):
diff --git a/tests/models/image.jpg b/tests/models/image.jpg
new file mode 100644
index 00000000..a686d104
Binary files /dev/null and b/tests/models/image.jpg differ
diff --git a/tests/models/test_clip_processor.py b/tests/models/test_clip_processor.py
new file mode 100644
index 00000000..debe618b
--- /dev/null
+++ b/tests/models/test_clip_processor.py
@@ -0,0 +1,55 @@
+import tempfile
+import unittest
+import os
+import urllib.request
+from PIL import Image
+from embedchain.models.clip_processor import ClipProcessor
+
+
+class ClipProcessorTest(unittest.TestCase):
+
+    def test_load_model(self):
+        # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
+        model, preprocess = ClipProcessor.load_model()
+
+        # Assert that the model is not None.
+        self.assertIsNotNone(model)
+
+        # Assert that the preprocess is not None.
+        self.assertIsNotNone(preprocess)
+
+    def test_get_image_features(self):
+        # Download a sample image into a temporary folder.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            urllib.request.urlretrieve(
+                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
+                "image.jpg")
+
+            image = Image.open("image.jpg")
+            image.save(os.path.join(tmp_dir, "image.jpg"))
+
+            # Get the image features.
+            model, preprocess = ClipProcessor.load_model()
+            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
+
+            # Delete the temporary file.
+            os.remove(os.path.join(tmp_dir, "image.jpg"))
+
+            # Assert that the test passes.
+            self.assertTrue(True)
+
+    def test_get_text_features(self):
+        # Test that the `get_text_features()` method returns a list containing the text embedding.
+        query = "This is a text query."
+        model, preprocess = ClipProcessor.load_model()
+
+        text_features = ClipProcessor.get_text_features(query)
+
+        # Assert that the text embedding is not None.
+        self.assertIsNotNone(text_features)
+
+        # Assert that the text embedding is a list of floats.
+        self.assertIsInstance(text_features, list)
+
+        # Assert that the text embedding has the correct length.
+        self.assertEqual(len(text_features), 512)
diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py
index 818db1cf..f9cbeadf 100644
--- a/tests/vectordb/test_chroma_db.py
+++ b/tests/vectordb/test_chroma_db.py
@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
         # Should still be 1, not 2.
         self.assertEqual(app.db.count(), 1)
 
+    def test_add_with_skip_embedding(self):
+        """
+        Test that add() with a precomputed embedding and skip_embedding=True stores and retrieves the document.
+        """
+        # Start with a clean app
+        self.app_with_settings.reset()
+        # app = App(config=AppConfig(collect_metrics=False), db=db)
+
+        # Collection should be empty when created
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
+        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        # After adding, should contain one item
+        self.assertEqual(self.app_with_settings.db.count(), 1)
+
+        # Validate if the get utility of the database is working as expected
+        data = self.app_with_settings.db.get(["id"], limit=1)
+        expected_value = {'documents': ['document'],
+                          'embeddings': None,
+                          'ids': ['id'],
+                          'metadatas': [{'value': 'somevalue'}]}
+        self.assertEqual(data, expected_value)
+
+        # Validate if the query utility of the database is working as expected
+        data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
+        expected_value = ['document']
+        self.assertEqual(data, expected_value)
+
     def test_collections_are_persistent(self):
         """
         Test that a collection can be picked up later.
diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py
index 4a536f56..8e090718 100644
--- a/tests/vectordb/test_elasticsearch_db.py
+++ b/tests/vectordb/test_elasticsearch_db.py
@@ -1,14 +1,109 @@
 import os
 import unittest
+from unittest.mock import patch
 
-from embedchain.config import ElasticsearchDBConfig
+from embedchain import App
+from embedchain.config import AppConfig, ElasticsearchDBConfig
 from embedchain.vectordb.elasticsearch import ElasticsearchDB
-
+from embedchain.embedder.gpt4all import GPT4AllEmbedder
 
 class TestEsDB(unittest.TestCase):
-    def setUp(self):
-        self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_setUp(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
         self.vector_dim = 384
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
+
+        search_response = {
+            "hits": {
+                "hits": [
+                    {
+                        "_source": {"text": "This is a document."},
+                        "_score": 0.9,
+                    },
+                    {
+                        "_source": {"text": "This is another document."},
+                        "_score": 0.8,
+                    },
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query_with_skip_embedding(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
+
+        search_response = {
+            "hits": {
+                "hits": [
+                    {
+                        "_source": {"text": "This is a document."},
+                        "_score": 0.9,
+                    },
+                    {
+                        "_source": {"text": "This is another document."},
+                        "_score": 0.8,
+                    },
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
 
     def test_init_without_url(self):
         # Make sure it's not loaded from env
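
Illustrative usage of the image support introduced by this patch (a sketch, not part of the diff): it assumes the existing App.add / App.query entry points, a local ./images directory, and that BaseLlmConfig is importable from embedchain.config; the "images" data type and the query_type="Images" value are the ones added above. The optional dependencies come from the new extra: pip install "embedchain[images]".

    from embedchain import App
    from embedchain.config import BaseLlmConfig

    app = App()

    # Index a directory of images. The new ImagesLoader/ImagesChunker compute one
    # CLIP embedding per image, and the vector DB stores it with skip_embedding=True.
    app.add("./images", data_type="images")

    # Query with text. query_type="Images" makes EmbedChain embed the query with
    # ClipProcessor.get_text_features and return the matching image paths as contexts.
    config = BaseLlmConfig(query_type="Images", number_documents=3)
    results = app.query("a photo of a dog playing on the beach", config=config)
    print(results)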