Add support for image dataset (#571)

Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Rupesh Bansal
2023-10-04 09:50:40 +05:30
committed by GitHub
parent 55e9a1cbd6
commit d0af018b8d
19 changed files with 498 additions and 31 deletions

View File

@@ -66,3 +66,6 @@ class BaseChunker(JSONSerializable):
self.data_type = data_type
# TODO: This should be done during initialization. This means it has to be done in the child classes.
def get_word_count(self, documents):
return sum([len(document.split(" ")) for document in documents])

View File

@@ -0,0 +1,63 @@
import hashlib
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
class ImagesChunker(BaseChunker):
"""Chunker for an Image."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
image_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(image_splitter)
def create_chunks(self, loader, src):
"""
Loads the image(s) and creates their corresponding embeddings. Exactly one chunk is created per image.
:param loader: The loader whose `load_data` method is used to create
the raw data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
"""
documents = []
embeddings = []
ids = []
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
metadatas = []
for data in data_records:
meta_data = data["meta_data"]
# add the data type to the metadata so queries can filter by data type
meta_data["data_type"] = self.data_type.value
chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
ids.append(chunk_id)
documents.append(data["content"])
embeddings.append(data["embedding"])
meta_data["doc_id"] = doc_id
metadatas.append(meta_data)
return {
"documents": documents,
"embeddings": embeddings,
"ids": ids,
"metadatas": metadatas,
"doc_id": doc_id,
}
def get_word_count(self, documents):
"""
The word count for images is fixed to 1, since exactly one chunk and one embedding are created per image.
"""
return 1
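A minimal usage sketch of the new chunker, assuming a stub loader that returns the same {"doc_id": ..., "data": [...]} shape as the real ImagesLoader added later in this commit; the path and embedding values are hypothetical:

from embedchain.chunkers.images import ImagesChunker
from embedchain.models.data_type import DataType

class StubImageLoader:
    # Hypothetical stand-in for ImagesLoader, returning the expected record shape.
    def load_data(self, src):
        return {
            "doc_id": "doc-1",
            "data": [
                {
                    "content": src,
                    "embedding": [0.1, 0.2, 0.3],
                    "meta_data": {"url": src},
                }
            ],
        }

chunker = ImagesChunker()
chunker.set_data_type(DataType.IMAGES)  # normally done by DataFormatter
result = chunker.create_chunks(StubImageLoader(), "./my_photos/cat.jpg")
# One chunk per image: `documents` holds the path, `embeddings` the precomputed vector.
print(result["documents"], result["ids"])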

View File

@@ -67,6 +67,7 @@ class BaseLlmConfig(BaseConfig):
deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None,
where: Dict[str, Any] = None,
query_type: Optional[str] = None
):
"""
Initializes a configuration class instance for the LLM.
@@ -112,6 +113,7 @@ class BaseLlmConfig(BaseConfig):
self.top_p = top_p
self.deployment_name = deployment_name
self.system_prompt = system_prompt
self.query_type = query_type
if self.validate_template(template):
self.template = template

View File

@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.images import ImagesChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -16,6 +17,7 @@ from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.csv import CsvLoader
from embedchain.loaders.docs_site_loader import DocsSiteLoader
from embedchain.loaders.docx_file import DocxFileLoader
from embedchain.loaders.images import ImagesLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
from embedchain.loaders.local_text import LocalTextLoader
from embedchain.loaders.mdx import MdxLoader
@@ -68,6 +70,7 @@ class DataFormatter(JSONSerializable):
DataType.DOCS_SITE: DocsSiteLoader,
DataType.CSV: CsvLoader,
DataType.MDX: MdxLoader,
DataType.IMAGES: ImagesLoader,
}
lazy_loaders = {DataType.NOTION}
if data_type in loaders:
@@ -102,11 +105,11 @@ class DataFormatter(JSONSerializable):
DataType.QNA_PAIR: QnaPairChunker,
DataType.TEXT: TextChunker,
DataType.DOCX: DocxFileChunker,
DataType.WEB_PAGE: WebPageChunker,
DataType.DOCS_SITE: DocsSiteChunker,
DataType.NOTION: NotionChunker,
DataType.CSV: TableChunker,
DataType.MDX: MdxChunker,
DataType.IMAGES: ImagesChunker,
}
if data_type in chunker_classes:
chunker_class: type = chunker_classes[data_type]
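For illustration, a simplified mirror of how the registries above resolve the new data type; the direct dictionary lookup is a sketch of DataFormatter's behaviour, not its actual code path:

from embedchain.chunkers.images import ImagesChunker
from embedchain.loaders.images import ImagesLoader
from embedchain.models.data_type import DataType

# Simplified copies of the loader/chunker registries shown in the hunks above.
loaders = {DataType.IMAGES: ImagesLoader}
chunkers = {DataType.IMAGES: ImagesChunker}

data_type = DataType.IMAGES
loader = loaders[data_type]()
chunker = chunkers[data_type]()
chunker.set_data_type(data_type)  # the chunker tags every chunk's metadata with this type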

View File

@@ -212,7 +212,7 @@ class EmbedChain(JSONSerializable):
# Send anonymous telemetry
if self.config.collect_metrics:
# it's quicker to check the variable twice than to count words when they won't be submitted.
word_count = sum([len(document.split(" ")) for document in documents])
word_count = data_formatter.chunker.get_word_count(documents)
extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
@@ -329,7 +329,6 @@ class EmbedChain(JSONSerializable):
# Create chunks
embeddings_data = chunker.create_chunks(loader, src)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
@@ -393,7 +392,8 @@ class EmbedChain(JSONSerializable):
# Count before, to calculate a delta in the end.
chunks_before_addition = self.db.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
@@ -434,10 +434,20 @@ class EmbedChain(JSONSerializable):
if self.config.id is not None:
where.update({"app_id": self.config.id})
# For an image search we cannot query the database with the raw text query: the text must first be
# projected into the same embedding space as the images so the two can be compared.
db_query = input_query
if config.query_type == "Images":
# Import the CLIP processor lazily so the package does not require the clip dependency
# unless the image data type is actually used.
from embedchain.models.clip_processor import ClipProcessor
db_query = ClipProcessor.get_text_features(query=input_query)
contents = self.db.query(
input_query=input_query,
input_query=db_query,
n_results=query_config.number_documents,
where=where,
skip_embedding = (config.query_type == "Images")
)
return contents
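Taken together, a hedged end-to-end sketch of the intended flow; the App.add/App.query call signatures are assumptions based on embedchain's public API and are not part of this diff:

from embedchain import App
from embedchain.config import BaseLlmConfig

app = App()

# Index a directory of images: each image becomes one chunk whose CLIP
# embedding is stored directly (skip_embedding=True on the vector DB side).
app.add("./my_photos", data_type="images")

# For image search the text query is converted with ClipProcessor.get_text_features
# and compared against the stored image vectors; the LLM step is bypassed and the
# matching image paths are returned as-is.
results = app.query("a dog playing in the snow", config=BaseLlmConfig(query_type="Images"))
print(results)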

View File

@@ -191,6 +191,9 @@ class BaseLlm(JSONSerializable):
prev_config = self.config.serialize()
self.config = config
if config is not None and config.query_type == "Images":
return contexts
if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5

View File

@@ -0,0 +1,37 @@
import os
import logging
import hashlib
from embedchain.loaders.base_loader import BaseLoader
class ImagesLoader(BaseLoader):
def load_data(self, image_url):
"""
Loads images from the supplied file or directory and applies the CLIP model to represent each image in vector form.
:param image_url: Path to a single image file or to a directory of images
"""
# load model and image preprocessing
from embedchain.models.clip_processor import ClipProcessor
model, preprocess = ClipProcessor.load_model()
if os.path.isfile(image_url):
data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
else:
data = []
for filename in os.listdir(image_url):
filepath = os.path.join(image_url, filename)
try:
data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
except Exception as e:
# Log the file that was not loaded
logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
# Use the path's size, last-modified and creation timestamps to derive a stable doc id
image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
str(os.path.getctime(image_url))]
doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
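A short usage sketch of the loader on its own; the directory path is hypothetical:

from embedchain.loaders.images import ImagesLoader

loader = ImagesLoader()
result = loader.load_data("./my_photos")  # a directory of images, or a single image file

# Each record in result["data"] has the shape
# {"content": <path>, "embedding": <512-dim CLIP vector>, "meta_data": {"url": <path>}}.
for record in result["data"]:
    print(record["meta_data"]["url"], len(record["embedding"]))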

View File

@@ -0,0 +1,64 @@
try:
import torch
import clip
from PIL import Image, UnidentifiedImageError
except ImportError:
raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
MODEL_NAME = "ViT-B/32"
class ClipProcessor:
@staticmethod
def load_model():
"""Load data from a director of images."""
device = "cuda" if torch.cuda.is_available() else "cpu"
# load model and image preprocessing
model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
return model, preprocess
@staticmethod
def get_image_features(image_url, model, preprocess):
"""
Applies the CLIP model to compute the vector representation of the supplied image
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
# load image
image = Image.open(image_url)
except FileNotFoundError:
raise FileNotFoundError("The supplied file does not exist")
except UnidentifiedImageError:
raise UnidentifiedImageError("The supplied file is not an image")
# pre-process image
processed_image = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
image_features = model.encode_image(processed_image)
image_features /= image_features.norm(dim=-1, keepdim=True)
image_features = image_features.cpu().detach().numpy().tolist()[0]
meta_data = {
"url": image_url
}
return {
"content": image_url,
"embedding": image_features,
"meta_data": meta_data
}
@staticmethod
def get_text_features(query):
"""
Applies the CLIP model to compute the vector representation of the supplied text
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = ClipProcessor.load_model()
text = clip.tokenize(query).to(device)
with torch.no_grad():
text_features = model.encode_text(text)
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().tolist()[0]
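Because get_image_features and get_text_features both return L2-normalised 512-dimensional vectors (for ViT-B/32), their dot product is the cosine similarity, which is what makes text-to-image retrieval possible. A minimal sketch, assuming a local image exists at the hypothetical path:

from embedchain.models.clip_processor import ClipProcessor

model, preprocess = ClipProcessor.load_model()
image_record = ClipProcessor.get_image_features("./my_photos/cat.jpg", model, preprocess)
text_vector = ClipProcessor.get_text_features("a photo of a cat")

# Both vectors are already normalised, so the dot product equals cosine similarity.
similarity = sum(a * b for a, b in zip(image_record["embedding"], text_vector))
print(f"cosine similarity: {similarity:.3f}")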

View File

@@ -23,6 +23,7 @@ class IndirectDataType(Enum):
NOTION = "notion"
CSV = "csv"
MDX = "mdx"
IMAGES = "images"
class SpecialDataType(Enum):
@@ -45,3 +46,4 @@ class DataType(Enum):
CSV = IndirectDataType.CSV.value
MDX = IndirectDataType.MDX.value
QNA_PAIR = SpecialDataType.QNA_PAIR.value
IMAGES = IndirectDataType.IMAGES.value

View File

@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
ids: List[str], skip_embedding: bool) -> Any:
"""
Add vectors to chroma database
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
:param ids: ids
:type ids: List[str]
"""
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
if skip_embedding:
self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
else:
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
"""
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
)
]
def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
Query contents from vector data base based on vector similarity
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
:rtype: List[str]
"""
try:
result = self.collection.query(
query_texts=[
input_query,
],
n_results=n_results,
where=where,
)
if skip_embedding:
result = self.collection.query(
query_embeddings=[
input_query,
],
n_results=n_results,
where=where,
)
else:
result = self.collection.query(
query_texts=[
input_query,
],
n_results=n_results,
where=where,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(
e.message()
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
) from None
results_formatted = self._format_result(result)
contents = [result[0].page_content for result in results_formatted]
return contents
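A sketch of the new skip_embedding path against the Chroma wrapper, mirroring the test added further below; it assumes an App (and therefore an embedder) can be initialised in this environment, and the 3-dimensional vectors are placeholders for real CLIP embeddings:

from embedchain import App
from embedchain.config import AppConfig

app = App(config=AppConfig(collect_metrics=False))
db = app.db  # the configured ChromaDB wrapper

db.add(
    embeddings=[[0.0, 0.0, 1.0]],
    documents=["./my_photos/cat.jpg"],
    metadatas=[{"data_type": "images"}],
    ids=["cat"],
    skip_embedding=True,  # store the precomputed vector; do not re-embed the document text
)
matches = db.query(input_query=[0.0, 0.0, 1.0], n_results=1, where={}, skip_embedding=True)
print(matches)  # expected: ["./my_photos/cat.jpg"]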

View File

@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set
try:
from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
ids = [doc["_id"] for doc in docs]
return {"ids": set(ids)}
def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
"""add data in vector database
def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
ids: List[str], skip_embedding: bool) -> Any:
"""
add data in vector database
:param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
"""
docs = []
embeddings = self.embedder.embedding_fn(documents)
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index())
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector data base based on vector similarity
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
:return: Database contents that are the result of the query
:rtype: List[str]
"""
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
if skip_embedding:
query_vector = input_query
else:
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
query = {
"script_score": {
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},

BIN  image.jpg (new binary file, 27 KiB; not shown)

View File

@@ -106,8 +106,9 @@ fastapi-poe = { version = "0.0.16", optional = true }
discord = { version = "^2.3.2", optional = true }
slack-sdk = { version = "3.21.3", optional = true }
docx2txt = "^0.8"
clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
ftfy = { version = "6.1.1", optional = true }
regex = { version = "2023.8.8", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -130,6 +131,7 @@ poe = ["fastapi-poe"]
discord = ["discord"]
slack = ["slack-sdk", "flask"]
whatsapp = ["twilio", "flask"]
images = ["torch", "ftfy", "regex", "clip"]
[tool.poetry.group.docs.dependencies]

View File

@@ -0,0 +1,72 @@
import unittest
from embedchain.chunkers.images import ImagesChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType
class TestImageChunker(unittest.TestCase):
def test_chunks(self):
"""
Test the chunks generated by ImagesChunker.
# TODO: Not a very precise test.
"""
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
chunker = ImagesChunker(config=chunker_config)
# Data type must be set manually in the test
chunker.set_data_type(DataType.IMAGES)
image_path = "./tmp/image.jpeg"
result = chunker.create_chunks(MockLoader(), image_path)
expected_chunks = {'doc_id': '123',
'documents': [image_path],
'embeddings': ['embedding'],
'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
self.assertEqual(expected_chunks, result)
def test_chunks_with_default_config(self):
"""
Test the chunks generated by ImagesChunker with the default config.
"""
chunker = ImagesChunker()
# Data type must be set manually in the test
chunker.set_data_type(DataType.IMAGES)
image_path = "./tmp/image.jpeg"
result = chunker.create_chunks(MockLoader(), image_path)
expected_chunks = {'doc_id': '123',
'documents': [image_path],
'embeddings': ['embedding'],
'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
self.assertEqual(expected_chunks, result)
def test_word_count(self):
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
chunker = ImagesChunker(config=chunker_config)
chunker.set_data_type(DataType.IMAGES)
document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
result = chunker.get_word_count(document)
self.assertEqual(result, 1)
class MockLoader:
def load_data(self, src):
"""
Mock loader that returns a list of data dictionaries.
Adjust this method to return different data for testing.
"""
return {
"doc_id": "123",
"data": [
{
"content": src,
"embedding": "embedding",
"meta_data": {"url": "none"},
}
],
}

View File

@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
self.assertEqual(len(documents), len(text))
def test_word_count(self):
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
chunker = TextChunker(config=chunker_config)
chunker.set_data_type(DataType.TEXT)
document = ["ab cd", "ef gh"]
result = chunker.get_word_count(document)
self.assertEqual(result, 4)
class MockLoader:
def load_data(self, src):

BIN  tests/models/image.jpg (new binary file, 27 KiB; not shown)

View File

@@ -0,0 +1,55 @@
import tempfile
import unittest
import os
import urllib
from PIL import Image
from embedchain.models.clip_processor import ClipProcessor
class ClipProcessorTest(unittest.TestCase):
def test_load_model(self):
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
model, preprocess = ClipProcessor.load_model()
# Assert that the model is not None.
self.assertIsNotNone(model)
# Assert that the preprocess is not None.
self.assertIsNotNone(preprocess)
def test_get_image_features(self):
# Download a sample image and copy it to a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir:
urllib.request.urlretrieve(
'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
"image.jpg")
image = Image.open("image.jpg")
image.save(os.path.join(tmp_dir, "image.jpg"))
# Get the image features.
model, preprocess = ClipProcessor.load_model()
ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
# Delete the temporary file.
os.remove(os.path.join(tmp_dir, "image.jpg"))
# Assert that the test passes.
self.assertTrue(True)
def test_get_text_features(self):
# Test that the `get_text_features()` method returns a list containing the text embedding.
query = "This is a text query."
model, preprocess = ClipProcessor.load_model()
text_features = ClipProcessor.get_text_features(query)
# Assert that the text embedding is not None.
self.assertIsNotNone(text_features)
# Assert that the text embedding is a list of floats.
self.assertIsInstance(text_features, list)
# Assert that the text embedding has the correct length.
self.assertEqual(len(text_features), 512)

View File

@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
# Should still be 1, not 2.
self.assertEqual(app.db.count(), 1)
def test_add_with_skip_embedding(self):
"""
Test that add() and query() work when precomputed embeddings are supplied with skip_embedding=True
"""
# Start with a clean app
self.app_with_settings.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created
self.assertEqual(self.app_with_settings.db.count(), 0)
self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
# After adding, should contain one item
self.assertEqual(self.app_with_settings.db.count(), 1)
# Validate if the get utility of the database is working as expected
data = self.app_with_settings.db.get(["id"], limit=1)
expected_value = {'documents': ['document'],
'embeddings': None,
'ids': ['id'],
'metadatas': [{'value': 'somevalue'}]}
self.assertEqual(data, expected_value)
# Validate if the query utility of the database is working as expected
data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ['document']
self.assertEqual(data, expected_value)
def test_collections_are_persistent(self):
"""
Test that a collection can be picked up later.

View File

@@ -1,14 +1,109 @@
import os
import unittest
from unittest.mock import patch
from embedchain.config import ElasticsearchDBConfig
from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.vectordb.elasticsearch import ElasticsearchDB
from embedchain.embedder.gpt4all import GPT4AllEmbedder
class TestEsDB(unittest.TestCase):
def setUp(self):
self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_setUp(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
self.vector_dim = 384
app_config = AppConfig(collection_name=False, collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
# Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
search_response = {"hits":
{"hits":
[
{
"_source": {"text": "This is a document."},
"_score": 0.9
},
{
"_source": {"text": "This is another document."},
"_score": 0.8
}
]
}
}
# Configure the mock client to return the mocked response.
mock_client.return_value.search.return_value = search_response
# Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"]
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
# Assert that the results are correct.
self.assertEqual(results, ["This is a document.", "This is another document."])
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
# Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
search_response = {"hits":
{"hits":
[
{
"_source": {"text": "This is a document."},
"_score": 0.9
},
{
"_source": {"text": "This is another document."},
"_score": 0.8
}
]
}
}
# Configure the mock client to return the mocked response.
mock_client.return_value.search.return_value = search_response
# Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"]
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
# Assert that the results are correct.
self.assertEqual(results, ["This is a document.", "This is another document."])
def test_init_without_url(self):
# Make sure it's not loaded from env