[bugfix] Fix issue when llm config is not defined (#763)
.github/workflows/ci.yml
@@ -34,4 +34,4 @@ jobs:
           file: coverage.xml
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
Makefile
@@ -9,6 +9,9 @@ PROJECT_NAME := embedchain
 install:
 	poetry install
 
 install_all:
 	poetry install --all-extras
 
+install_es:
+	poetry install --extras elasticsearch
+
@@ -67,7 +67,7 @@ class BaseLlmConfig(BaseConfig):
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: Dict[str, Any] = None,
-        query_type: Optional[str] = None
+        query_type: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
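
For context on the bugfix this commit is named after: instances of this config class always define `query_type` (it defaults to None), so the `hasattr` guard added further down exists for config objects created without the field. A minimal sketch of the fixed constructor tail; only the parameters visible in the hunk come from the source, and the one-line body is a stand-in:

    from typing import Any, Dict, Optional


    class BaseLlmConfig:
        def __init__(
            self,
            deployment_name: Optional[str] = None,
            system_prompt: Optional[str] = None,
            where: Dict[str, Any] = None,
            query_type: Optional[str] = None,
        ):
            # Stand-in body (assumption): the real class stores more fields.
            self.query_type = query_type


    config = BaseLlmConfig()
    assert config.query_type is None  # the attribute always exists on this class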
@@ -1,8 +1,8 @@
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
-from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.images import ImagesChunker
+from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -392,8 +392,13 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
-                    ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
+        self.db.add(
+            embeddings=embeddings_data.get("embeddings", None),
+            documents=documents,
+            metadatas=metadatas,
+            ids=ids,
+            skip_embedding=(chunker.data_type == DataType.IMAGES),
+        )
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
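
The `skip_embedding` flag is the interesting part of this call: for image data the CLIP embeddings are computed at load time, so the vector database should store them as-is instead of re-embedding the documents. A minimal runnable sketch of that contract, using a hypothetical `FakeDB` stand-in for the real ChromaDB/Elasticsearch classes below:

    from typing import List, Optional


    class FakeDB:
        """Hypothetical stand-in for the vector DB classes in this commit."""

        def add(self, embeddings: Optional[List[List[float]]], documents: List[str], skip_embedding: bool) -> None:
            if skip_embedding:
                # Image pipeline: CLIP vectors were already computed by the loader.
                assert embeddings is not None, "precomputed embeddings are required"
                self.vectors = embeddings
            else:
                # Text pipeline: the database embeds the documents itself.
                self.vectors = [[float(len(doc))] for doc in documents]  # toy embedder


    db = FakeDB()
    db.add(embeddings=[[0.1, 0.2]], documents=["cat.jpg"], skip_embedding=True)
    db.add(embeddings=None, documents=["hello world"], skip_embedding=False)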
@@ -437,17 +442,18 @@ class EmbedChain(JSONSerializable):
         # We cannot query the database with the input query in case of an image search. This is because we need
         # to bring down both the image and text to the same dimension to be able to compare them.
         db_query = input_query
-        if config.query_type == "Images":
+        if hasattr(config, "query_type") and config.query_type == "Images":
             # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
             # image dataset is not being used
             from embedchain.models.clip_processor import ClipProcessor
 
             db_query = ClipProcessor.get_text_features(query=input_query)
 
         contents = self.db.query(
             input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
-            skip_embedding = (config.query_type == "Images")
+            skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
 
         return contents
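
This hunk is the actual bugfix: when the app is built without an LLM config, the config object reaching this method may not carry a `query_type` attribute, and the bare attribute access raised `AttributeError`. The guard degrades to a text query instead. A minimal sketch of the failure and the fix, with hypothetical `LegacyConfig`/`NewConfig` stand-ins:

    class LegacyConfig:
        pass  # no query_type attribute, like configs created before the field existed


    class NewConfig:
        def __init__(self, query_type=None):
            self.query_type = query_type


    def is_image_query(config) -> bool:
        # Before the fix this read `config.query_type == "Images"` and
        # crashed on LegacyConfig; the guard short-circuits to False instead.
        return hasattr(config, "query_type") and config.query_type == "Images"


    assert is_image_query(LegacyConfig()) is False  # no AttributeError
    assert is_image_query(NewConfig("Images")) is True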
@@ -22,7 +22,7 @@ class GPT4ALLLlm(BaseLlm):
             from gpt4all import GPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
-                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
+                "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
         return GPT4All(model_name=model)
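
The pattern above, isolated: import the optional dependency inside the method and re-raise with an install hint, so the base package stays importable without the extra, and `from None` suppresses the noisy chained traceback. A sketch assuming only what the hunk shows (the helper name is ours):

    def _load_gpt4all(model: str):
        try:
            from gpt4all import GPT4All  # optional extra, not a base dependency
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. "
                "Please install it with `pip install --upgrade embedchain[opensource]`"
            ) from None  # hide the original traceback; the hint is the message
        return GPT4All(model_name=model)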
@@ -1,11 +1,11 @@
-import os
-import logging
 import hashlib
+import logging
+import os
 
 from embedchain.loaders.base_loader import BaseLoader
 
 
 class ImagesLoader(BaseLoader):
 
     def load_data(self, image_url):
         """
         Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
@@ -15,6 +15,7 @@ class ImagesLoader(BaseLoader):
         """
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor
+
         model, preprocess = ClipProcessor.load_model()
         if os.path.isfile(image_url):
             data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
@@ -28,8 +29,11 @@ class ImagesLoader(BaseLoader):
                 # Log the file that was not loaded
                 logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
         # Get the metadata like Size, Last Modified and Last Created timestamps
-        image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
-                               str(os.path.getctime(image_url))]
+        image_path_metadata = [
+            str(os.path.getsize(image_url)),
+            str(os.path.getmtime(image_url)),
+            str(os.path.getctime(image_url)),
+        ]
         doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
         return {
             "doc_id": doc_id,
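
The reformatted block derives a document id from filesystem metadata, so the id changes whenever the file is rewritten or moved. A self-contained sketch of the same scheme (the temporary file is purely illustrative):

    import hashlib
    import os
    import tempfile

    # Create a throwaway file standing in for the image.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
        f.write(b"fake image bytes")
        image_url = f.name

    # Same recipe as the loader: size + mtime + ctime + path, hashed with SHA-256.
    image_path_metadata = [
        str(os.path.getsize(image_url)),
        str(os.path.getmtime(image_url)),
        str(os.path.getctime(image_url)),
    ]
    doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
    print(doc_id)  # 64 hex characters, stable until the file changes
    os.remove(image_url)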
@@ -1,6 +1,6 @@
 try:
-    import torch
     import clip
+    import torch
     from PIL import Image, UnidentifiedImageError
 except ImportError:
     raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
@@ -39,14 +39,8 @@ class ClipProcessor:
         image_features /= image_features.norm(dim=-1, keepdim=True)
 
         image_features = image_features.cpu().detach().numpy().tolist()[0]
-        meta_data = {
-            "url": image_url
-        }
-        return {
-            "content": image_url,
-            "embedding": image_features,
-            "meta_data": meta_data
-        }
+        meta_data = {"url": image_url}
+        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
 
     @staticmethod
     def get_text_features(query):
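
The surviving context lines above L2-normalize the CLIP features before converting them to a plain list; with unit-length vectors, cosine similarity reduces to a dot product. A sketch with a random tensor standing in for real CLIP output (requires torch):

    import torch

    image_features = torch.randn(1, 512)  # stand-in for CLIP's 512-dim output
    image_features /= image_features.norm(dim=-1, keepdim=True)
    assert abs(image_features.norm().item() - 1.0) < 1e-5  # now unit length

    # Same conversion the class uses before building the payload dict.
    as_list = image_features.cpu().detach().numpy().tolist()[0]
    assert len(as_list) == 512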
@@ -115,8 +115,14 @@ class ChromaDB(BaseVectorDB):
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         Add vectors to chroma database
@@ -184,7 +190,7 @@ class ChromaDB(BaseVectorDB):
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
             ) from None
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
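
The re-raised message describes a dimension mismatch between the embedding function used at add time and at query time, which matters here because image collections store 512-dim CLIP vectors. A toy sketch of the failure mode it warns about (the check itself is illustrative, not Chroma's code):

    stored_dim = 512            # e.g. CLIP vectors stored with skip_embedding=True
    query_vector = [0.0] * 384  # e.g. produced by a different text embedder

    if len(query_vector) != stored_dim:
        # This is the situation the exception message above describes.
        print(f"dimension mismatch: {len(query_vector)} != {stored_dim}; "
              "use the same embedding function for add and query")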
@@ -100,8 +100,14 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return {"ids": set(ids)}
 
-    def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
-            ids: List[str], skip_embedding: bool) -> Any:
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ) -> Any:
         """
         add data in vector database
         :param documents: list of texts to add
@@ -94,7 +94,7 @@ pytube = "^15.0.0"
 duckduckgo-search = "^3.8.5"
 llama-hub = { version = "^0.0.29", optional = true }
 sentence-transformers = { version = "^2.2.2", optional = true }
-torch = { version = ">=2.0.0, !=2.0.1", optional = true }
+torch = { version = "2.0.0", optional = true }
 # Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974)
 gpt4all = { version = "1.0.8", optional = true }
 # 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394)
@@ -107,6 +107,8 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
 clip = {git = "https://github.com/openai/CLIP.git#a1d0717", optional = true}
+pillow = { version = "10.0.1", optional = true }
+torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
@@ -131,7 +133,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
-images = ["torch", "ftfy", "regex", "clip"]
+images = ["torch", "ftfy", "regex", "clip", "pillow", "torchvision"]
 
 [tool.poetry.group.docs.dependencies]
@@ -19,11 +19,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_chunks_with_default_config(self):
@@ -37,11 +39,13 @@ class TestImageChunker(unittest.TestCase):
         image_path = "./tmp/image.jpeg"
         result = chunker.create_chunks(MockLoader(), image_path)
 
-        expected_chunks = {'doc_id': '123',
-                           'documents': [image_path],
-                           'embeddings': ['embedding'],
-                           'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
-                           'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
+        expected_chunks = {
+            "doc_id": "123",
+            "documents": [image_path],
+            "embeddings": ["embedding"],
+            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
+            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+        }
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):
@@ -1,29 +1,23 @@
-import tempfile
-import unittest
 import os
+import tempfile
 import urllib
 
 from PIL import Image
 
 from embedchain.models.clip_processor import ClipProcessor
 
 
-class ClipProcessorTest(unittest.TestCase):
-
+class TestClipProcessor:
     def test_load_model(self):
         # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
         model, preprocess = ClipProcessor.load_model()
 
-        # Assert that the model is not None.
-        self.assertIsNotNone(model)
-
-        # Assert that the preprocess is not None.
-        self.assertIsNotNone(preprocess)
+        assert model is not None
+        assert preprocess is not None
 
     def test_get_image_features(self):
         # Clone the image to a temporary folder.
         with tempfile.TemporaryDirectory() as tmp_dir:
-            urllib.request.urlretrieve(
-                'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
-                "image.jpg")
+            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
 
             image = Image.open("image.jpg")
             image.save(os.path.join(tmp_dir, "image.jpg"))
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
             # Delete the temporary file.
             os.remove(os.path.join(tmp_dir, "image.jpg"))
 
-            # Assert that the test passes.
-            self.assertTrue(True)
-
     def test_get_text_features(self):
         # Test that the `get_text_features()` method returns a list containing the text embedding.
         query = "This is a text query."
@@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase):
         text_features = ClipProcessor.get_text_features(query)
 
         # Assert that the text embedding is not None.
-        self.assertIsNotNone(text_features)
+        assert text_features is not None
 
         # Assert that the text embedding is a list of floats.
-        self.assertIsInstance(text_features, list)
+        assert isinstance(text_features, list)
 
         # Assert that the text embedding has the correct length.
-        self.assertEqual(len(text_features), 512)
+        assert len(text_features) == 512
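
The three rewrites above are a mechanical unittest-to-pytest translation; pytest's assertion rewriting keeps the failure diagnostics. The mapping in one runnable snippet (the list is a stand-in for the real ClipProcessor output):

    text_features = [0.0] * 512  # stand-in for ClipProcessor.get_text_features(...)

    assert text_features is not None        # was: self.assertIsNotNone(text_features)
    assert isinstance(text_features, list)  # was: self.assertIsInstance(text_features, list)
    assert len(text_features) == 512        # was: self.assertEqual(len(text_features), 512)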
@@ -197,21 +197,29 @@ class TestChromaDbCollection(unittest.TestCase):
         # Collection should be empty when created
         self.assertEqual(self.app_with_settings.db.count(), 0)
 
-        self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
+        self.app_with_settings.db.add(
+            embeddings=[[0, 0, 0]],
+            documents=["document"],
+            metadatas=[{"value": "somevalue"}],
+            ids=["id"],
+            skip_embedding=True,
+        )
         # After adding, should contain one item
         self.assertEqual(self.app_with_settings.db.count(), 1)
 
         # Validate if the get utility of the database is working as expected
         data = self.app_with_settings.db.get(["id"], limit=1)
-        expected_value = {'documents': ['document'],
-                          'embeddings': None,
-                          'ids': ['id'],
-                          'metadatas': [{'value': 'somevalue'}]}
+        expected_value = {
+            "documents": ["document"],
+            "embeddings": None,
+            "ids": ["id"],
+            "metadatas": [{"value": "somevalue"}],
+        }
         self.assertEqual(data, expected_value)
 
         # Validate if the query utility of the database is working as expected
         data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-        expected_value = ['document']
+        expected_value = ["document"]
         self.assertEqual(data, expected_value)
 
     def test_collections_are_persistent(self):
@@ -4,11 +4,11 @@ from unittest.mock import patch
 
 from embedchain import App
 from embedchain.config import AppConfig, ElasticsearchDBConfig
-from embedchain.vectordb.elasticsearch import ElasticsearchDB
 from embedchain.embedder.gpt4all import GPT4AllEmbedder
+from embedchain.vectordb.elasticsearch import ElasticsearchDB
 
 
 class TestEsDB(unittest.TestCase):
 
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_setUp(self, mock_client):
         self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
@@ -37,17 +37,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
 
-        search_response = {"hits":
-            {"hits":
-                [
-                    {
-                        "_source": {"text": "This is a document."},
-                        "_score": 0.9
-                    },
-                    {
-                        "_source": {"text": "This is another document."},
-                        "_score": 0.8
-                    }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }
@@ -80,17 +74,11 @@ class TestEsDB(unittest.TestCase):
         # Add the data to the database.
         self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
 
-        search_response = {"hits":
-            {"hits":
-                [
-                    {
-                        "_source": {"text": "This is a document."},
-                        "_score": 0.9
-                    },
-                    {
-                        "_source": {"text": "This is another document."},
-                        "_score": 0.8
-                    }
+        search_response = {
+            "hits": {
+                "hits": [
+                    {"_source": {"text": "This is a document."}, "_score": 0.9},
+                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                ]
+            }
+        }
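
For reference, the mocked payload mirrors Elasticsearch's search response shape. A small sketch of pulling the documents back out ordered by `_score` (the data is copied from the mock above; the sorting step is ours):

    search_response = {
        "hits": {
            "hits": [
                {"_source": {"text": "This is a document."}, "_score": 0.9},
                {"_source": {"text": "This is another document."}, "_score": 0.8},
            ]
        }
    }

    hits = sorted(search_response["hits"]["hits"], key=lambda hit: hit["_score"], reverse=True)
    contents = [hit["_source"]["text"] for hit in hits]
    assert contents == ["This is a document.", "This is another document."]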