[bugfix] Fix issue when llm config is not defined (#763)

This commit is contained in:
Deshraj Yadav
2023-10-04 12:08:21 -07:00
committed by GitHub
parent d0af018b8d
commit 87d0b5c76f
15 changed files with 100 additions and 88 deletions

View File

@@ -67,7 +67,7 @@ class BaseLlmConfig(BaseConfig):
deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None,
where: Dict[str, Any] = None,
query_type: Optional[str] = None
query_type: Optional[str] = None,
):
"""
Initializes a configuration class instance for the LLM.

View File

@@ -1,8 +1,8 @@
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.images import ImagesChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker

View File

@@ -392,8 +392,13 @@ class EmbedChain(JSONSerializable):
# Count before, to calculate a delta in the end.
chunks_before_addition = self.db.count()
self.db.add(embeddings=embeddings_data.get("embeddings", None), documents=documents, metadatas=metadatas,
ids=ids, skip_embedding = (chunker.data_type == DataType.IMAGES))
self.db.add(
embeddings=embeddings_data.get("embeddings", None),
documents=documents,
metadatas=metadatas,
ids=ids,
skip_embedding=(chunker.data_type == DataType.IMAGES),
)
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
@@ -437,17 +442,18 @@ class EmbedChain(JSONSerializable):
# We cannot query the database with the input query in case of an image search. This is because we need
# to bring down both the image and text to the same dimension to be able to compare them.
db_query = input_query
if config.query_type == "Images":
if hasattr(config, "query_type") and config.query_type == "Images":
# We import the clip processor here to make sure the package is not dependent on clip dependency even if the
# image dataset is not being used
from embedchain.models.clip_processor import ClipProcessor
db_query = ClipProcessor.get_text_features(query=input_query)
contents = self.db.query(
input_query=db_query,
n_results=query_config.number_documents,
where=where,
skip_embedding = (config.query_type == "Images")
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
)
return contents

View File

@@ -22,7 +22,7 @@ class GPT4ALLLlm(BaseLlm):
from gpt4all import GPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
) from None
return GPT4All(model_name=model)

View File

@@ -1,11 +1,11 @@
import os
import logging
import hashlib
import logging
import os
from embedchain.loaders.base_loader import BaseLoader
class ImagesLoader(BaseLoader):
def load_data(self, image_url):
"""
Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
@@ -15,6 +15,7 @@ class ImagesLoader(BaseLoader):
"""
# load model and image preprocessing
from embedchain.models.clip_processor import ClipProcessor
model, preprocess = ClipProcessor.load_model()
if os.path.isfile(image_url):
data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
@@ -28,8 +29,11 @@ class ImagesLoader(BaseLoader):
# Log the file that was not loaded
logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
# Get the metadata like Size, Last Modified and Last Created timestamps
image_path_metadata = [str(os.path.getsize(image_url)), str(os.path.getmtime(image_url)),
str(os.path.getctime(image_url))]
image_path_metadata = [
str(os.path.getsize(image_url)),
str(os.path.getmtime(image_url)),
str(os.path.getctime(image_url)),
]
doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
return {
"doc_id": doc_id,

View File

@@ -1,6 +1,6 @@
try:
import torch
import clip
import torch
from PIL import Image, UnidentifiedImageError
except ImportError:
raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
@@ -39,14 +39,8 @@ class ClipProcessor:
image_features /= image_features.norm(dim=-1, keepdim=True)
image_features = image_features.cpu().detach().numpy().tolist()[0]
meta_data = {
"url": image_url
}
return {
"content": image_url,
"embedding": image_features,
"meta_data": meta_data
}
meta_data = {"url": image_url}
return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
@staticmethod
def get_text_features(query):

View File

@@ -115,8 +115,14 @@ class ChromaDB(BaseVectorDB):
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
ids: List[str], skip_embedding: bool) -> Any:
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
) -> Any:
"""
Add vectors to chroma database
@@ -184,7 +190,7 @@ class ChromaDB(BaseVectorDB):
except InvalidDimensionException as e:
raise InvalidDimensionException(
e.message()
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
) from None
results_formatted = self._format_result(result)
contents = [result[0].page_content for result in results_formatted]

View File

@@ -100,8 +100,14 @@ class ElasticsearchDB(BaseVectorDB):
ids = [doc["_id"] for doc in docs]
return {"ids": set(ids)}
def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
ids: List[str], skip_embedding: bool) -> Any:
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
) -> Any:
"""
add data in vector database
:param documents: list of texts to add