Add support for image dataset (#571)
Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
@@ -66,3 +66,6 @@ class BaseChunker(JSONSerializable):
         self.data_type = data_type
 
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
+
+    def get_word_count(self, documents):
+        return sum([len(document.split(" ")) for document in documents])
embedchain/chunkers/images.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+import hashlib
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+
+class ImagesChunker(BaseChunker):
+    """Chunker for an Image."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        image_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(image_splitter)
+
+    def create_chunks(self, loader, src):
+        """
+        Loads the image(s) and creates their corresponding embedding. This creates one chunk for each image.
+
+        :param loader: The loader whose `load_data` method is used to create
+        the raw data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        """
+        documents = []
+        embeddings = []
+        ids = []
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
+        metadatas = []
+        for data in data_records:
+            meta_data = data["meta_data"]
+            # add data type to meta data to allow query using data type
+            meta_data["data_type"] = self.data_type.value
+            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
+            ids.append(chunk_id)
+            documents.append(data["content"])
+            embeddings.append(data["embedding"])
+            meta_data["doc_id"] = doc_id
+            metadatas.append(meta_data)
+
+        return {
+            "documents": documents,
+            "embeddings": embeddings,
+            "ids": ids,
+            "metadatas": metadatas,
+            "doc_id": doc_id,
+        }
+
+    def get_word_count(self, documents):
+        """
+        The number of chunks and the corresponding word count for an image is fixed to 1, as 1 embedding is
+        created for each image.
+        """
+        return 1
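For context, a minimal sketch of how this chunker is driven. The stand-in loader below is hypothetical (it only mimics the record shape that ImagesLoader, added further down, produces), so the snippet runs without the CLIP extras:

from embedchain.chunkers.images import ImagesChunker
from embedchain.models.data_type import DataType


class FakeImageLoader:
    # hypothetical stand-in for ImagesLoader, returning the same record shape
    def load_data(self, src):
        return {
            "doc_id": "doc-1",
            "data": [{"content": src, "embedding": [0.1] * 512, "meta_data": {"url": src}}],
        }


chunker = ImagesChunker()
chunker.set_data_type(DataType.IMAGES)  # normally done by DataFormatter
result = chunker.create_chunks(FakeImageLoader(), "./images/cat.jpeg")
print(result["ids"][0])  # sha256 of the image url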
@@ -67,6 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
+        query_type: Optional[str] = None
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -112,6 +113,7 @@ class BaseLlmConfig(BaseConfig):
         self.top_p = top_p
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
+        self.query_type = query_type
 
         if self.validate_template(template):
             self.template = template
@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.mdx import MdxChunker
+from embedchain.chunkers.images import ImagesChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -16,6 +17,7 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
+from embedchain.loaders.images import ImagesLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
 from embedchain.loaders.mdx import MdxLoader
@@ -68,6 +70,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCS_SITE: DocsSiteLoader,
             DataType.CSV: CsvLoader,
             DataType.MDX: MdxLoader,
+            DataType.IMAGES: ImagesLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -102,11 +105,11 @@ class DataFormatter(JSONSerializable):
             DataType.QNA_PAIR: QnaPairChunker,
             DataType.TEXT: TextChunker,
             DataType.DOCX: DocxFileChunker,
             DataType.WEB_PAGE: WebPageChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
             DataType.CSV: TableChunker,
             DataType.MDX: MdxChunker,
+            DataType.IMAGES: ImagesChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]
||||
@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         if self.config.collect_metrics:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
+            word_count = data_formatter.chunker.get_word_count(documents)
 
             extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
 
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
-
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
+                    ids=ids, skip_embedding=(chunker.data_type == DataType.IMAGES))
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
         if self.config.id is not None:
             where.update({"app_id": self.config.id})
 
+        # We cannot query the database with the input query in case of an image search. This is because we need
+        # to bring down both the image and text to the same dimension to be able to compare them.
+        db_query = input_query
+        if config.query_type == "Images":
+            # We import the clip processor here to make sure the package does not depend on the clip dependency
+            # when the image dataset is not being used
+            from embedchain.models.clip_processor import ClipProcessor
+            db_query = ClipProcessor.get_text_features(query=input_query)
+
         contents = self.db.query(
-            input_query=input_query,
+            input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
+            skip_embedding=(config.query_type == "Images"),
         )
 
         return contents
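Taken together, the intended end-to-end usage would look roughly like the sketch below (the App wiring follows the APIs added in this commit; the ./images directory and the query string are illustrative):

from embedchain import App
from embedchain.config import BaseLlmConfig

app = App()

# Routed through ImagesLoader/ImagesChunker; the CLIP embeddings are precomputed,
# so the vector database is called with skip_embedding=True.
app.add("./images", data_type="images")

# For an image search the query text is also embedded with CLIP, and the raw
# matching contexts (image paths) are returned instead of an LLM-generated answer.
results = app.query("a photo of a dog", config=BaseLlmConfig(query_type="Images"))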
@@ -191,6 +191,9 @@ class BaseLlm(JSONSerializable):
             prev_config = self.config.serialize()
             self.config = config
 
+        if config is not None and config.query_type == "Images":
+            return contexts
+
         if self.is_docs_site_instance:
             self.config.template = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
embedchain/loaders/images.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+import os
+import logging
+import hashlib
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class ImagesLoader(BaseLoader):
+
+    def load_data(self, image_url):
+        """
+        Loads images from the supplied directory/file and applies the CLIP model transformation to represent
+        these images in vector form
+
+        :param image_url: The URL from which the images are to be loaded
+        """
+        # load model and image preprocessing
+        from embedchain.models.clip_processor import ClipProcessor
+        model, preprocess = ClipProcessor.load_model()
+        if os.path.isfile(image_url):
+            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
+        else:
+            data = []
+            for filename in os.listdir(image_url):
+                filepath = os.path.join(image_url, filename)
+                try:
+                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
+                except Exception as e:
+                    # Log the file that was not loaded
+                    logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
+        # Get the metadata like Size, Last Modified and Last Created timestamps
+        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
+                               str(os.path.getctime(image_url))]
+        doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
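A small sketch of calling the loader directly (assumes the embedchain[images] extras are installed and that ./images exists):

from embedchain.loaders.images import ImagesLoader

loader = ImagesLoader()
result = loader.load_data("./images")  # a single file or a directory of images
print(result["doc_id"])  # sha256 over the path metadata plus the path itself
for record in result["data"]:
    print(record["content"], len(record["embedding"]))  # image path, 512 for ViT-B/32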
embedchain/models/clip_processor.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+try:
+    import torch
+    import clip
+    from PIL import Image, UnidentifiedImageError
+except ImportError:
+    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
+
+MODEL_NAME = "ViT-B/32"
+
+
+class ClipProcessor:
+    @staticmethod
+    def load_model():
+        """Load the CLIP model and its image preprocessing transform."""
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # load model and image preprocessing
+        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
+        return model, preprocess
+
+    @staticmethod
+    def get_image_features(image_url, model, preprocess):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied image
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        try:
+            # load image
+            image = Image.open(image_url)
+        except FileNotFoundError:
+            raise FileNotFoundError("The supplied file does not exist")
+        except UnidentifiedImageError:
+            raise UnidentifiedImageError("The supplied file is not an image")
+
+        # pre-process image
+        processed_image = preprocess(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            image_features = model.encode_image(processed_image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        image_features = image_features.cpu().detach().numpy().tolist()[0]
+        meta_data = {
+            "url": image_url
+        }
+        return {
+            "content": image_url,
+            "embedding": image_features,
+            "meta_data": meta_data
+        }
+
+    @staticmethod
+    def get_text_features(query):
+        """
+        Applies the CLIP model to evaluate the vector representation of the supplied text
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        model, preprocess = ClipProcessor.load_model()
+        text = clip.tokenize(query).to(device)
+        with torch.no_grad():
+            text_features = model.encode_text(text)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        return text_features.cpu().numpy().tolist()[0]
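Because both feature vectors are normalized in place before being returned, the dot product of an image embedding and a text embedding is their cosine similarity. A hedged sketch (the image path is illustrative):

model, preprocess = ClipProcessor.load_model()
image = ClipProcessor.get_image_features("./images/cat.jpeg", model, preprocess)
text_embedding = ClipProcessor.get_text_features("a photo of a cat")

# dot product of two unit vectors == cosine similarity
similarity = sum(i * t for i, t in zip(image["embedding"], text_embedding))
print(similarity)  # closer to 1.0 means a better image/text match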
@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
+    IMAGES = "images"
 
 
 class SpecialDataType(Enum):
@@ -45,3 +46,4 @@ class DataType(Enum):
     CSV = IndirectDataType.CSV.value
     MDX = IndirectDataType.MDX.value
     QNA_PAIR = SpecialDataType.QNA_PAIR.value
+    IMAGES = IndirectDataType.IMAGES.value
@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
         """
         Add vectors to chroma database
 
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
         :param ids: ids
         :type ids: List[str]
         """
-        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+        if skip_embedding:
+            self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
+        else:
+            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
 
     def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
         """
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any], skip_embedding: bool) -> List[str]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
         :rtype: List[str]
         """
         try:
-            result = self.collection.query(
-                query_texts=[
-                    input_query,
-                ],
-                n_results=n_results,
-                where=where,
-            )
+            if skip_embedding:
+                result = self.collection.query(
+                    query_embeddings=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
+            else:
+                result = self.collection.query(
+                    query_texts=[
+                        input_query,
+                    ],
+                    n_results=n_results,
+                    where=where,
+                )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
                 + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
 
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
         return contents
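A minimal sketch of the two code paths, mirroring the tests added later in this commit (db stands for a configured ChromaDB instance; the 3-dimensional vectors are purely illustrative):

# Precomputed (CLIP) embeddings: Chroma stores them as-is.
db.add(embeddings=[[0, 0, 0]], documents=["image.jpeg"],
       metadatas=[{"url": "image.jpeg"}], ids=["id1"], skip_embedding=True)

# Text documents: Chroma computes the embeddings with its own embedding function.
db.add(embeddings=None, documents=["some text"],
       metadatas=[{}], ids=["id2"], skip_embedding=False)

# An image search passes a query vector rather than query text.
db.query(input_query=[0, 0, 0], n_results=1, where={}, skip_embedding=True)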
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 try:
     from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
-        """add data in vector database
+    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+            ids: List[str], skip_embedding: bool) -> Any:
+        """
+        add data in vector database
         :param documents: list of texts to add
         :type documents: List[str]
         :param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
         """
 
         docs = []
-        embeddings = self.embedder.embedding_fn(documents)
+        if not skip_embedding:
+            embeddings = self.embedder.embedding_fn(documents)
+
         for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
             docs.append(
                 {
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
         """
         query contents from vector data base based on vector similarity
 
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
         :return: Database contents that are the result of the query
         :rtype: List[str]
         """
-        input_query_vector = self.embedder.embedding_fn(input_query)
-        query_vector = input_query_vector[0]
+        if skip_embedding:
+            query_vector = input_query
+        else:
+            input_query_vector = self.embedder.embedding_fn(input_query)
+            query_vector = input_query_vector[0]
 
         query = {
             "script_score": {
                 "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
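For context, the fragment above is the head of a script_score query; the cosine-similarity script it typically wraps looks roughly like the sketch below (an assumption based on the surrounding code, not quoted from this diff):

query = {
    "script_score": {
        "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
        "script": {
            "source": "cosineSimilarity(params.input_query_vector, 'embeddings') + 1.0",
            "params": {"input_query_vector": query_vector},
        },
    }
}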
@@ -106,8 +106,9 @@ fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
-
-
+clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
+ftfy = { version = "6.1.1", optional = true }
+regex = { version = "2023.8.8", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -130,6 +131,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+images = ["torch", "ftfy", "regex", "clip"]
 
 [tool.poetry.group.docs.dependencies]
tests/chunkers/test_image_chunker.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import unittest
+
+from embedchain.chunkers.images import ImagesChunker
+from embedchain.config import ChunkerConfig
+from embedchain.models.data_type import DataType
+
+
+class TestImageChunker(unittest.TestCase):
+    def test_chunks(self):
+        """
+        Test the chunks generated by ImagesChunker.
+        # TODO: Not a very precise test.
+        """
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_chunks_with_default_config(self):
+        """
+        Test the chunks generated by ImagesChunker with the default config.
+        """
+        chunker = ImagesChunker()
+        # Data type must be set manually in the test
+        chunker.set_data_type(DataType.IMAGES)
+
+        image_path = "./tmp/image.jpeg"
+        result = chunker.create_chunks(MockLoader(), image_path)
+
+        expected_chunks = {'doc_id': '123',
+                           'documents': [image_path],
+                           'embeddings': ['embedding'],
+                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
+                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        self.assertEqual(expected_chunks, result)
+
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = ImagesChunker(config=chunker_config)
+        chunker.set_data_type(DataType.IMAGES)
+
+        document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 1)
+
+
+class MockLoader:
+    def load_data(self, src):
+        """
+        Mock loader that returns a list of data dictionaries.
+        Adjust this method to return different data for testing.
+        """
+        return {
+            "doc_id": "123",
+            "data": [
+                {
+                    "content": src,
+                    "embedding": "embedding",
+                    "meta_data": {"url": "none"},
+                }
+            ],
+        }
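The long hex id in expected_chunks is not arbitrary: per ImagesChunker.create_chunks it is the sha256 digest of the mock metadata url "none", which can be recomputed directly:

import hashlib

print(hashlib.sha256("none".encode()).hexdigest())
# 140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe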
@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
 
         self.assertEqual(len(documents), len(text))
 
+    def test_word_count(self):
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        chunker.set_data_type(DataType.TEXT)
+
+        document = ["ab cd", "ef gh"]
+        result = chunker.get_word_count(document)
+        self.assertEqual(result, 4)
+
 
 class MockLoader:
     def load_data(self, src):
BIN tests/models/image.jpg (new binary file, 27 KiB; not shown)
tests/models/test_clip_processor.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import tempfile
+import unittest
+import os
+import urllib.request
+
+from PIL import Image
+
+from embedchain.models.clip_processor import ClipProcessor
+
+
+class ClipProcessorTest(unittest.TestCase):
+
+    def test_load_model(self):
+        # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
+        model, preprocess = ClipProcessor.load_model()
+
+        # Assert that the model is not None.
+        self.assertIsNotNone(model)
+
+        # Assert that the preprocess is not None.
+        self.assertIsNotNone(preprocess)
+
+    def test_get_image_features(self):
+        # Download the example image to a temporary folder.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            urllib.request.urlretrieve(
+                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
+                "image.jpg")
+
+            image = Image.open("image.jpg")
+            image.save(os.path.join(tmp_dir, "image.jpg"))
+
+            # Get the image features.
+            model, preprocess = ClipProcessor.load_model()
+            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
+
+            # Delete the temporary file.
+            os.remove(os.path.join(tmp_dir, "image.jpg"))
+
+            # Assert that the test passes.
+            self.assertTrue(True)
+
+    def test_get_text_features(self):
+        # Test that the `get_text_features()` method returns a list containing the text embedding.
+        query = "This is a text query."
+        model, preprocess = ClipProcessor.load_model()
+
+        text_features = ClipProcessor.get_text_features(query)
+
+        # Assert that the text embedding is not None.
+        self.assertIsNotNone(text_features)
+
+        # Assert that the text embedding is a list of floats.
+        self.assertIsInstance(text_features, list)
+
+        # Assert that the text embedding has the correct length.
+        self.assertEqual(len(text_features), 512)
@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
         # Should still be 1, not 2.
         self.assertEqual(app.db.count(), 1)
 
+    def test_add_with_skip_embedding(self):
+        """
+        Test that an add with skip_embedding stores the precomputed embedding and that the data
+        can be retrieved and queried back.
+        """
+        # Start with a clean app
+        self.app_with_settings.reset()
+
+        # Collection should be empty when created
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
+        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"],
+                                      metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        # After adding, should contain one item
+        self.assertEqual(self.app_with_settings.db.count(), 1)
+
+        # Validate if the get utility of the database is working as expected
+        data = self.app_with_settings.db.get(["id"], limit=1)
+        expected_value = {'documents': ['document'],
+                          'embeddings': None,
+                          'ids': ['id'],
+                          'metadatas': [{'value': 'somevalue'}]}
+        self.assertEqual(data, expected_value)
+
+        # Validate if the query utility of the database is working as expected
+        data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
+        expected_value = ['document']
+        self.assertEqual(data, expected_value)
+
     def test_collections_are_persistent(self):
         """
         Test that a collection can be picked up later.
@@ -1,14 +1,109 @@
 import os
 import unittest
 from unittest.mock import patch
 
-from embedchain.config import ElasticsearchDBConfig
+from embedchain import App
+from embedchain.config import AppConfig, ElasticsearchDBConfig
 from embedchain.vectordb.elasticsearch import ElasticsearchDB
+from embedchain.embedder.gpt4all import GPT4AllEmbedder
 
 
 class TestEsDB(unittest.TestCase):
     def setUp(self):
         self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
 
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_setUp(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        self.vector_dim = 384
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
+
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
+
+    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
+    def test_query_with_skip_embedding(self, mock_client):
+        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
+        app_config = AppConfig(collection_name=False, collect_metrics=False)
+        self.app = App(config=app_config, db=self.db)
+
+        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
+        self.assertEqual(self.db.client, mock_client.return_value)
+
+        # Create some dummy data.
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc_1", "doc_2"]
+
+        # Add the data to the database.
+        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
+
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }
+
+        # Configure the mock client to return the mocked response.
+        mock_client.return_value.search.return_value = search_response
+
+        # Query the database for the documents that are most similar to the query "This is a document".
+        query = ["This is a document"]
+        results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
+
+        # Assert that the results are correct.
+        self.assertEqual(results, ["This is a document.", "This is another document."])
+
     def test_init_without_url(self):
         # Make sure it's not loaded from env