[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

This commit is contained in:
Sandra Serrano
2024-01-08 19:47:46 +01:00
committed by GitHub
parent 62c0c52e31
commit 2496ed133e
41 changed files with 133 additions and 103 deletions

View File

@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
from embedchain.cache import (adapt, get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback)
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -84,7 +81,7 @@ class EmbedChain(JSONSerializable):
# Attributes that aren't subclass related.
self.user_asks = []
self.chunker: ChunkerConfig = None
self.chunker: Optional[ChunkerConfig] = None
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -290,7 +287,7 @@ class EmbedChain(JSONSerializable):
# Or it's different, then it will be added as a new text.
return None
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
# These types have a indirect source reference
# These types have an indirect source reference
# As long as the reference is the same, they can be updated.
where = {"url": src}
if chunker.data_type == DataType.JSON and is_valid_json_string(src):
@@ -442,10 +439,11 @@ class EmbedChain(JSONSerializable):
)
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):
@staticmethod
def _format_result(results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(