[bugfix] Fix issue when llm config is not defined (#763)

Author: Deshraj Yadav
Date: 2023-10-04 12:08:21 -07:00 (committed by GitHub)
Parent: d0af018b8d
Commit: 87d0b5c76f
15 changed files with 100 additions and 88 deletions

View File

@@ -34,4 +34,4 @@ jobs:
           file: coverage.xml
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -9,6 +9,9 @@ PROJECT_NAME := embedchain
 install:
 	poetry install
 
+install_all:
+	poetry install --all-extras
+
 install_es:
 	poetry install --extras elasticsearch

View File

@@ -67,7 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
-        query_type: Optional[str] = None
+        query_type: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
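Note: `query_type` is the option the rest of this commit keeps touching. A hedged usage sketch, assuming embedchain's public config export (the kwarg itself is from the signature above):

    from embedchain.config import BaseLlmConfig

    # Opt into CLIP-based image retrieval; leaving query_type unset
    # keeps the default text behaviour.
    config = BaseLlmConfig(query_type="Images")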

View File

@@ -1,8 +1,8 @@
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.images import ImagesChunker from embedchain.chunkers.images import ImagesChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker from embedchain.chunkers.qna_pair import QnaPairChunker

View File

@@ -392,8 +392,13 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
-                    ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
+        self.db.add(
+            embeddings=embeddings_data.get("embeddings", None),
+            documents=documents,
+            metadatas=metadatas,
+            ids=ids,
+            skip_embedding=(chunker.data_type == DataType.IMAGES),
+        )
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
@@ -437,17 +442,18 @@ class EmbedChain(JSONSerializable):
         # We cannot query the database with the input query in case of an image search. This is because we need
         # to bring down both the image and text to the same dimension to be able to compare them.
         db_query = input_query
-        if config.query_type == "Images":
+        if hasattr(config, "query_type") and config.query_type == "Images":
             # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
             # image dataset is not being used
             from embedchain.models.clip_processor import ClipProcessor
 
             db_query = ClipProcessor.get_text_features(query=input_query)
 
         contents = self.db.query(
             input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
-            skip_embedding = (config.query_type == "Images")
+            skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
 
         return contents
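Note: this hunk is the actual bugfix behind the commit title. When an app is created without an explicit llm config, the config object reaching this code path can lack the `query_type` attribute entirely, so plain attribute access raised `AttributeError`. A self-contained sketch of the failure and the committed guard (the stand-in class is hypothetical):

    class LegacyLlmConfig:
        """Stand-in for a config created before query_type existed."""

    config = LegacyLlmConfig()

    try:
        config.query_type == "Images"  # the old check: raises AttributeError
    except AttributeError:
        pass  # this is the crash the issue reported

    # The committed form degrades to False instead of raising:
    assert (hasattr(config, "query_type") and config.query_type == "Images") is False

An equivalent alternative would be `getattr(config, "query_type", None) == "Images"`.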

View File

@@ -22,7 +22,7 @@ class GPT4ALLLlm(BaseLlm):
             from gpt4all import GPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
         return GPT4All(model_name=model)
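Note: only comment formatting changed here, but the surrounding context shows embedchain's optional-dependency pattern, worth restating in isolation as a sketch: the import is deferred to call time so the base install never needs gpt4all, and the failure is rewritten into an actionable install hint.

    def _load_gpt4all(model: str):
        # Deferred import: resolved only when a GPT4All model is requested.
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. "
                "Please install it with `pip install --upgrade embedchain[opensource]`"
            ) from None
        return GPT4All(model_name=model)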

View File

@@ -1,11 +1,11 @@
import os
import logging
import hashlib import hashlib
import logging
import os
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
class ImagesLoader(BaseLoader): class ImagesLoader(BaseLoader):
def load_data(self, image_url): def load_data(self, image_url):
""" """
Loads images from the supplied directory/file and applies CLIP model transformation to represent these images Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
@@ -15,6 +15,9 @@ class ImagesLoader(BaseLoader):
         """
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor
+
         model, preprocess = ClipProcessor.load_model()
+
         if os.path.isfile(image_url):
             data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
@@ -28,8 +29,11 @@ class ImagesLoader(BaseLoader):
                     # Log the file that was not loaded
                     logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
 
         # Get the metadata like Size, Last Modified and Last Created timestamps
-        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
-                               str(os.path.getctime(image_url))]
+        image_path_metadata = [
+            str(os.path.getsize(image_url)),
+            str(os.path.getmtime(image_url)),
+            str(os.path.getctime(image_url)),
+        ]
         doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
         return {
             "doc_id": doc_id,

View File

@@ -1,6 +1,6 @@
try: try:
import torch
import clip import clip
import torch
from PIL import Image, UnidentifiedImageError from PIL import Image, UnidentifiedImageError
except ImportError: except ImportError:
raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
@@ -39,14 +39,8 @@ class ClipProcessor:
         image_features /= image_features.norm(dim=-1, keepdim=True)
         image_features = image_features.cpu().detach().numpy().tolist()[0]
 
-        meta_data = {
-            "url": image_url
-        }
-        return {
-            "content": image_url,
-            "embedding": image_features,
-            "meta_data": meta_data
-        }
+        meta_data = {"url": image_url}
+        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
 
     @staticmethod
     def get_text_features(query):
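Note: the unchanged `norm(dim=-1, keepdim=True)` division in the context above L2-normalizes the CLIP feature vector, which is what lets downstream similarity search treat dot products as cosine similarity. A tiny numpy restatement, illustrative only:

    import numpy as np

    v = np.array([3.0, 4.0])
    v_hat = v / np.linalg.norm(v)  # unit length, direction preserved
    assert np.isclose(np.linalg.norm(v_hat), 1.0)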

View File

@@ -115,8 +115,14 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         Add vectors to chroma database
@@ -184,7 +190,7 @@ class ChromaDB(BaseVectorDB):
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
                 + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]

View File

@@ -100,8 +100,14 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         add data in vector database
         :param documents: list of texts to add

View File

@@ -94,7 +94,7 @@ pytube = "^15.0.0"
 duckduckgo-search = "^3.8.5"
 llama-hub = { version = "^0.0.29", optional = true }
 sentence-transformers = { version = "^2.2.2", optional = true }
-torch = { version = ">=2.0.0, !=2.0.1", optional = true }
+torch = { version = "2.0.0", optional = true }
 # Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974)
 gpt4all = { version = "1.0.8", optional = true }
 # 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394)
@@ -107,6 +107,8 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
 clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
+pillow = { version = "10.0.1", optional = true }
+torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
@@ -131,7 +133,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
-images = ["torch", "ftfy", "regex", "clip"]
+images = ["torch", "ftfy", "regex", "clip", "pillow", "torchvision"]
 
 [tool.poetry.group.docs.dependencies]
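Note: this change pins torch to 2.0.0 and adds pillow and torchvision to the images extra alongside clip/ftfy/regex. A quick smoke test you could run after `pip install "embedchain[images]"` to confirm the extra resolved (imports only, nothing embedchain-specific; assumes the packages listed in the extra):

    import clip  # pinned git revision of openai/CLIP
    import ftfy
    import regex
    import torch
    import torchvision
    from PIL import Image

    print("torch", torch.__version__, "| torchvision", torchvision.__version__)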

View File

@@ -19,11 +19,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_chunks_with_default_config(self):
@@ -37,11 +39,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):
def test_word_count(self): def test_word_count(self):

View File

@@ -1,29 +1,23 @@
import tempfile
import unittest
import os import os
import tempfile
import urllib import urllib
from PIL import Image from PIL import Image
from embedchain.models.clip_processor import ClipProcessor from embedchain.models.clip_processor import ClipProcessor
class ClipProcessorTest(unittest.TestCase): class TestClipProcessor:
def test_load_model(self): def test_load_model(self):
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly. # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
model, preprocess = ClipProcessor.load_model() model, preprocess = ClipProcessor.load_model()
assert model is not None
# Assert that the model is not None. assert preprocess is not None
self.assertIsNotNone(model)
# Assert that the preprocess is not None.
self.assertIsNotNone(preprocess)
def test_get_image_features(self): def test_get_image_features(self):
# Clone the image to a temporary folder. # Clone the image to a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
urllib.request.urlretrieve( urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
"image.jpg")
image = Image.open("image.jpg") image = Image.open("image.jpg")
image.save(os.path.join(tmp_dir, "image.jpg")) image.save(os.path.join(tmp_dir, "image.jpg"))
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
# Delete the temporary file. # Delete the temporary file.
os.remove(os.path.join(tmp_dir, "image.jpg")) os.remove(os.path.join(tmp_dir, "image.jpg"))
# Assert that the test passes.
self.assertTrue(True)
def test_get_text_features(self): def test_get_text_features(self):
# Test that the `get_text_features()` method returns a list containing the text embedding. # Test that the `get_text_features()` method returns a list containing the text embedding.
query = "This is a text query." query = "This is a text query."
@@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase):
text_features = ClipProcessor.get_text_features(query) text_features = ClipProcessor.get_text_features(query)
# Assert that the text embedding is not None. # Assert that the text embedding is not None.
self.assertIsNotNone(text_features) assert text_features is not None
# Assert that the text embedding is a list of floats. # Assert that the text embedding is a list of floats.
self.assertIsInstance(text_features, list) assert isinstance(text_features, list)
# Assert that the text embedding has the correct length. # Assert that the text embedding has the correct length.
self.assertEqual(len(text_features), 512) assert len(text_features) == 512

View File

@@ -197,21 +197,29 @@ class TestChromaDbCollection(unittest.TestCase):
         # Collection should be empty when created
         self.assertEqual(self.app_with_settings.db.count(), 0)
 
-        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        self.app_with_settings.db.add(
+            embeddings=[[0, 0, 0]],
+            documents=["document"],
+            metadatas=[{"value": "somevalue"}],
+            ids=["id"],
+            skip_embedding=True,
+        )
 
         # After adding, should contain one item
         self.assertEqual(self.app_with_settings.db.count(), 1)
 
         # Validate if the get utility of the database is working as expected
         data = self.app_with_settings.db.get(["id"], limit=1)
-        expected_value = {'documents': ['document'],
-                          'embeddings': None,
-                          'ids': ['id'],
-                          'metadatas': [{'value': 'somevalue'}]}
+        expected_value = {
+            "documents": ["document"],
+            "embeddings": None,
+            "ids": ["id"],
+            "metadatas": [{"value": "somevalue"}],
+        }
         self.assertEqual(data, expected_value)
 
         # Validate if the query utility of the database is working as expected
         data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-        expected_value = ['document']
+        expected_value = ["document"]
         self.assertEqual(data, expected_value)
 
     def test_collections_are_persistent(self):
def test_collections_are_persistent(self): def test_collections_are_persistent(self):

View File

@@ -4,11 +4,11 @@ from unittest.mock import patch
 
 from embedchain import App
 from embedchain.config import AppConfig, ElasticsearchDBConfig
-from embedchain.vectordb.elasticsearch import ElasticsearchDB
 from embedchain.embedder.gpt4all import GPT4AllEmbedder
+from embedchain.vectordb.elasticsearch import ElasticsearchDB
 
 
 class TestEsDB(unittest.TestCase):
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_setUp(self, mock_client):
         self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
@@ -37,17 +37,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
 
-        search_response = {"hits":
-                           {"hits":
-                            [
-                             {
-                              "_source": {"text": "This is a document."},
-                              "_score": 0.9
-                             },
-                             {
-                              "_source": {"text": "This is another document."},
-                              "_score": 0.8
-                             }
-                            ]
-                           }
-                          }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }
@@ -80,17 +74,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
 
-        search_response = {"hits":
-                           {"hits":
-                            [
-                             {
-                              "_source": {"text": "This is a document."},
-                              "_score": 0.9
-                             },
-                             {
-                              "_source": {"text": "This is another document."},
-                              "_score": 0.8
-                             }
-                            ]
-                           }
-                          }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }