[Refactor] Improve logging package wide (#1315)

This commit is contained in:
Deshraj Yadav
2024-03-13 17:13:30 -07:00
committed by GitHub
parent ef69c91b60
commit 3616eaadb4
54 changed files with 263 additions and 231 deletions

View File

@@ -91,7 +91,6 @@ keys = console
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =

View File

@@ -32,6 +32,8 @@ from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
logger = logging.getLogger(__name__)
@register_deserializable
class App(EmbedChain):
@@ -50,10 +52,10 @@ class App(EmbedChain):
embedding_model: BaseEmbedder = None,
llm: BaseLlm = None,
config_data: dict = None,
log_level=logging.WARN,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
cache_config: CacheConfig = None,
log_level: int = logging.WARN,
):
"""
Initialize a new `App` instance.
@@ -68,8 +70,6 @@ class App(EmbedChain):
:type llm: BaseLlm, optional
:param config_data: Config dictionary, defaults to None
:type config_data: dict, optional
:param log_level: Log level to use, defaults to logging.WARN
:type log_level: int, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
@@ -83,13 +83,12 @@ class App(EmbedChain):
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
# logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
logger.debug("4.0")
# Initialize the metadata db for the app
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
logger.debug("4.0")
self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -119,6 +118,7 @@ class App(EmbedChain):
self.llm = llm or OpenAILlm()
self._init_db()
logger.debug("4.1")
# Session for the metadata db
self.db_session = get_session()
@@ -126,6 +126,7 @@ class App(EmbedChain):
if self.cache_config is not None:
self._init_cache()
logger.debug("4.2")
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -238,7 +239,7 @@ class App(EmbedChain):
response.raise_for_status()
return response.status_code == 200
except Exception as e:
self.logger.exception(f"Error occurred during file upload: {str(e)}")
logger.exception(f"Error occurred during file upload: {str(e)}")
print("❌ Error occurred during file upload!")
return False
@@ -272,7 +273,7 @@ class App(EmbedChain):
metadata = {"file_path": data_value, "s3_key": s3_key}
data_value = presigned_url
else:
self.logger.error(f"File upload failed for hash: {data_hash}")
logger.error(f"File upload failed for hash: {data_hash}")
return False
else:
if data_type == "qna_pair":
@@ -336,6 +337,7 @@ class App(EmbedChain):
:return: An instance of the App class.
:rtype: App
"""
logger.debug("6")
# Backward compatibility for yaml_path
if yaml_path and not config_path:
config_path = yaml_path
@@ -357,15 +359,13 @@ class App(EmbedChain):
elif config and isinstance(config, dict):
config_data = config
else:
logging.error(
logger.error(
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
)
config_data = {}
try:
validate_config(config_data)
except Exception as e:
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
# Validate the config
validate_config(config_data)
app_config_data = config_data.get("app", {}).get("config", {})
vector_db_config_data = config_data.get("vectordb", {})
@@ -477,12 +477,12 @@ class App(EmbedChain):
EvalMetric.GROUNDEDNESS.value,
]
logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
dataset = []
for q, a, c in zip(queries, answers, contexts):
dataset.append(EvalData(question=q, answer=a, contexts=c))
logging.info(f"Evaluating {len(dataset)} data points...")
logger.info(f"Evaluating {len(dataset)} data points...")
result = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}

View File

@@ -17,6 +17,8 @@ except ModuleNotFoundError:
) from None
logger = logging.getLogger(__name__)
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@@ -37,7 +39,7 @@ class DiscordBot(BaseBot):
self.add(data)
response = f"Added data from: {data}"
except Exception:
logging.exception(f"Failed to add data {data}.")
logger.exception(f"Failed to add data {data}.")
response = "Some error occurred while adding data."
return response
@@ -45,7 +47,7 @@ class DiscordBot(BaseBot):
try:
response = self.query(message)
except Exception:
logging.exception(f"Failed to query {message}.")
logger.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"
return response
@@ -60,7 +62,7 @@ class DiscordBot(BaseBot):
async def query_command(interaction: discord.Interaction, question: str):
await interaction.response.defer()
member = client.guilds[0].get_member(client.user.id)
logging.info(f"User: {member}, Query: {question}")
logger.info(f"User: {member}, Query: {question}")
try:
answer = discord_bot.ask_bot(question)
if args.include_question:
@@ -70,20 +72,20 @@ async def query_command(interaction: discord.Interaction, question: str):
await interaction.followup.send(response)
except Exception as e:
await interaction.followup.send("An error occurred. Please try again!")
logging.error("Error occurred during 'query' command:", e)
logger.error("Error occurred during 'query' command:", e)
@tree.command(name="add", description="add new content to the embedchain database")
async def add_command(interaction: discord.Interaction, url_or_text: str):
await interaction.response.defer()
member = client.guilds[0].get_member(client.user.id)
logging.info(f"User: {member}, Add: {url_or_text}")
logger.info(f"User: {member}, Add: {url_or_text}")
try:
response = discord_bot.add_data(url_or_text)
await interaction.followup.send(response)
except Exception as e:
await interaction.followup.send("An error occurred. Please try again!")
logging.error("Error occurred during 'add' command:", e)
logger.error("Error occurred during 'add' command:", e)
@tree.command(name="ping", description="Simple ping pong command")
@@ -96,7 +98,7 @@ async def on_app_command_error(interaction: discord.Interaction, error: discord.
if isinstance(error, commands.CommandNotFound):
await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.")
else:
logging.error("Error occurred during command execution:", error)
logger.error("Error occurred during command execution:", error)
@client.event
@@ -104,8 +106,8 @@ async def on_ready():
# TODO: Sync in admin command, to not hit rate limits.
# This might be overkill for most users, and it would require to set a guild or user id, where sync is allowed.
await tree.sync()
logging.debug("Command tree synced")
logging.info(f"Logged in as {client.user.name}")
logger.debug("Command tree synced")
logger.info(f"Logged in as {client.user.name}")
def start_command():

View File

@@ -19,6 +19,8 @@ except ModuleNotFoundError:
) from None
logger = logging.getLogger(__name__)
SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
@@ -42,10 +44,10 @@ class SlackBot(BaseBot):
try:
response = self.chat_bot.chat(question)
self.send_slack_message(message["channel"], response)
logging.info("Query answered successfully!")
logger.info("Query answered successfully!")
except Exception as e:
self.send_slack_message(message["channel"], "An error occurred. Please try again!")
logging.error("Error occurred during 'query' command:", e)
logger.error("Error occurred during 'query' command:", e)
elif text.startswith("add"):
_, data_type, url_or_text = text.split(" ", 2)
if url_or_text.startswith("<") and url_or_text.endswith(">"):
@@ -55,10 +57,10 @@ class SlackBot(BaseBot):
self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}")
except ValueError as e:
self.send_slack_message(message["channel"], f"Error: {str(e)}")
logging.error("Error occurred during 'add' command:", e)
logger.error("Error occurred during 'add' command:", e)
except Exception as e:
self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}")
logging.error("Error occurred during 'add' command:", e)
logger.error("Error occurred during 'add' command:", e)
def send_slack_message(self, channel, message):
response = self.client.chat_postMessage(channel=channel, text=message)
@@ -68,7 +70,7 @@ class SlackBot(BaseBot):
app = Flask(__name__)
def signal_handler(sig, frame):
logging.info("\nGracefully shutting down the SlackBot...")
logger.info("\nGracefully shutting down the SlackBot...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)

View File

@@ -8,6 +8,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot
logger = logging.getLogger(__name__)
@register_deserializable
class WhatsAppBot(BaseBot):
@@ -35,7 +37,7 @@ class WhatsAppBot(BaseBot):
self.add(data)
response = f"Added data from: {data}"
except Exception:
logging.exception(f"Failed to add data {data}.")
logger.exception(f"Failed to add data {data}.")
response = "Some error occurred while adding data."
return response
@@ -43,7 +45,7 @@ class WhatsAppBot(BaseBot):
try:
response = self.query(message)
except Exception:
logging.exception(f"Failed to query {message}.")
logger.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"
return response
@@ -51,7 +53,7 @@ class WhatsAppBot(BaseBot):
app = self.flask.Flask(__name__)
def signal_handler(sig, frame):
logging.info("\nGracefully shutting down the WhatsAppBot...")
logger.info("\nGracefully shutting down the WhatsAppBot...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)

View File

@@ -14,6 +14,8 @@ from gptcache.similarity_evaluation.distance import \
from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401
logger = logging.getLogger(__name__)
def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
return data["input_query"]
@@ -24,12 +26,12 @@ def gptcache_data_manager(vector_dimension):
def gptcache_data_convert(cache_data):
logging.info("[Cache] Cache hit, returning cache data...")
logger.info("[Cache] Cache hit, returning cache data...")
return cache_data
def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
logging.info("[Cache] Cache missed, updating cache...")
logger.info("[Cache] Cache missed, updating cache...")
update_cache_func(Answer(llm_data, CacheDataType.STR))
return llm_data

View File

@@ -6,6 +6,8 @@ from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType
logger = logging.getLogger(__name__)
class BaseChunker(JSONSerializable):
def __init__(self, text_splitter):
@@ -27,7 +29,7 @@ class BaseChunker(JSONSerializable):
chunk_ids = []
id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]

View File

@@ -7,6 +7,8 @@ import requests
from embedchain.constants import CONFIG_DIR, CONFIG_FILE
logger = logging.getLogger(__name__)
class Client:
def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
@@ -24,7 +26,7 @@ class Client:
else:
if "api_key" in self.config_data:
self.api_key = self.config_data["api_key"]
logging.info("API key loaded successfully!")
logger.info("API key loaded successfully!")
else:
raise ValueError(
"You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
@@ -64,7 +66,7 @@ class Client:
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
logging.info("API key saved successfully!")
logger.info("API key saved successfully!")
def clear(self):
if "api_key" in self.config_data:
@@ -72,17 +74,17 @@ class Client:
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
self.api_key = None
logging.info("API key deleted successfully!")
logger.info("API key deleted successfully!")
else:
logging.warning("API key not found in the configuration file.")
logger.warning("API key not found in the configuration file.")
def update(self, api_key):
if self.check(api_key):
self.api_key = api_key
self.save()
logging.info("API key updated successfully!")
logger.info("API key updated successfully!")
else:
logging.warning("Invalid API key provided. API key not updated.")
logger.warning("Invalid API key provided. API key not updated.")
def check(self, api_key):
validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
@@ -90,8 +92,8 @@ class Client:
if response.status_code == 200:
return True
else:
logging.warning(f"Response from API: {response.text}")
logging.warning("Invalid API key. Unable to validate.")
logger.warning(f"Response from API: {response.text}")
logger.warning("Invalid API key. Unable to validate.")
return False
def get(self):

View File

@@ -5,6 +5,8 @@ from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.vectordb.base import BaseVectorDB
logger = logging.getLogger(__name__)
class BaseAppConfig(BaseConfig, JSONSerializable):
"""
@@ -42,15 +44,15 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
if db:
self._db = db
logging.warning(
logger.warning(
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
"Such as `app(config=config, db=db)`."
)
if collection_name:
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
return
def _setup_logging(self, log_level):
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
self.logger = logging.getLogger(__name__)
logger.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
self.logger = logger.getLogger(__name__)

View File

@@ -6,6 +6,8 @@ from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
logger = logging.getLogger(__name__)
DEFAULT_PROMPT = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:
@@ -147,7 +149,7 @@ class BaseLlmConfig(BaseConfig):
:raises ValueError: Stream is not boolean
"""
if template is not None:
logging.warning(
logger.warning(
"The `template` argument is deprecated and will be removed in a future version. "
+ "Please use `prompt` instead."
)

View File

@@ -25,6 +25,8 @@ from embedchain.vectordb.base import BaseVectorDB
load_dotenv()
logger = logging.getLogger(__name__)
class EmbedChain(JSONSerializable):
def __init__(
@@ -143,10 +145,10 @@ class EmbedChain(JSONSerializable):
try:
DataType(source)
logging.warning(
logger.warning(
f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501
)
logging.warning(
logger.warning(
"Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501
)
source, data_type = data_type, source
@@ -157,7 +159,7 @@ class EmbedChain(JSONSerializable):
try:
data_type = DataType(data_type)
except ValueError:
logging.info(
logger.info(
f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`" # noqa: E501
)
data_type = DataType.CUSTOM
@@ -190,12 +192,12 @@ class EmbedChain(JSONSerializable):
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error adding data source: {e}")
logger.error(f"Error adding data source: {e}")
self.db_session.rollback()
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
logging.debug(f"Dry run info : {data_chunks_info}")
logger.debug(f"Dry run info : {data_chunks_info}")
return data_chunks_info
# Send anonymous telemetry
@@ -490,7 +492,7 @@ class EmbedChain(JSONSerializable):
contexts_data_for_llm_query = contexts
if self.cache_config is not None:
logging.info("Cache enabled. Checking cache...")
logger.info("Cache enabled. Checking cache...")
answer = adapt(
llm_handler=self.llm.query,
cache_data_convert=gptcache_data_convert,
@@ -562,7 +564,7 @@ class EmbedChain(JSONSerializable):
self.llm.update_history(app_id=self.config.id, session_id=session_id)
if self.cache_config is not None:
logging.info("Cache enabled. Checking cache...")
logger.debug("Cache enabled. Checking cache...")
cache_id = f"{session_id}--{self.config.id}"
answer = adapt(
llm_handler=self.llm.chat,
@@ -575,6 +577,7 @@ class EmbedChain(JSONSerializable):
dry_run=dry_run,
)
else:
logger.debug("Cache disabled. Running chat without cache.")
answer = self.llm.chat(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
@@ -652,7 +655,7 @@ class EmbedChain(JSONSerializable):
self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
self.db_session.commit()
except Exception as e:
logging.error(f"Error deleting data sources: {e}")
logger.error(f"Error deleting data sources: {e}")
self.db_session.rollback()
return None
self.db.reset()
@@ -694,11 +697,11 @@ class EmbedChain(JSONSerializable):
self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
self.db_session.commit()
except Exception as e:
logging.error(f"Error deleting data sources: {e}")
logger.error(f"Error deleting data sources: {e}")
self.db_session.rollback()
return None
self.db.delete(where={"hash": source_id})
logging.info(f"Successfully deleted {source_id}")
logger.info(f"Successfully deleted {source_id}")
# Send anonymous telemetry
if self.config.collect_metrics:
self.telemetry.capture(event_name="delete", properties=self._telemetry_props)

View File

@@ -8,6 +8,8 @@ from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
logger = logging.getLogger(__name__)
class NvidiaEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
@@ -17,7 +19,7 @@ class NvidiaEmbedder(BaseEmbedder):
super().__init__(config=config)
model = self.config.model or "nvolveqa_40k"
logging.info(f"Using NVIDIA embedding model: {model}")
logger.info(f"Using NVIDIA embedding model: {model}")
embedder = NVIDIAEmbeddings(model=model)
embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
self.set_embedding_fn(embedding_fn=embedding_fn)

View File

@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import AnswerRelevanceConfig
from embedchain.evaluation.base import BaseMetric
from embedchain.utils.evaluation import EvalData, EvalMetric
logger = logging.getLogger(__name__)
class AnswerRelevance(BaseMetric):
"""
@@ -88,6 +90,6 @@ class AnswerRelevance(BaseMetric):
try:
results.append(future.result())
except Exception as e:
logging.error(f"Error evaluating answer relevancy for {data}: {e}")
logger.error(f"Error evaluating answer relevancy for {data}: {e}")
return np.mean(results) if results else 0.0

View File

@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import GroundednessConfig
from embedchain.evaluation.base import BaseMetric
from embedchain.utils.evaluation import EvalData, EvalMetric
logger = logging.getLogger(__name__)
class Groundedness(BaseMetric):
"""
@@ -97,6 +99,6 @@ class Groundedness(BaseMetric):
score = future.result()
results.append(score)
except Exception as e:
logging.error(f"Error while evaluating groundedness for data point {data}: {e}")
logger.error(f"Error while evaluating groundedness for data point {data}: {e}")
return np.mean(results) if results else 0.0

View File

@@ -8,6 +8,8 @@ T = TypeVar("T", bound="JSONSerializable")
# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
logger = logging.getLogger(__name__)
def register_deserializable(cls: Type[T]) -> Type[T]:
"""
@@ -57,7 +59,7 @@ class JSONSerializable:
try:
return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
except Exception as e:
logging.error(f"Serialization error: {e}")
logger.error(f"Serialization error: {e}")
return "{}"
@classmethod
@@ -79,7 +81,7 @@ class JSONSerializable:
try:
return json.loads(json_str, object_hook=cls._auto_decoder)
except Exception as e:
logging.error(f"Deserialization error: {e}")
logger.error(f"Deserialization error: {e}")
# Return a default instance in case of failure
return cls()

View File

@@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
logger = logging.getLogger(__name__)
@register_deserializable
class AnthropicLlm(BaseLlm):
@@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm):
)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
logger.warning("Config option `max_tokens` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

View File

@@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm):
}
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
callbacks = [StreamingStdOutCallbackHandler()]
llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)

View File

@@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
logger = logging.getLogger(__name__)
@register_deserializable
class AzureOpenAILlm(BaseLlm):
@@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm):
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
logger.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

View File

@@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
logger = logging.getLogger(__name__)
class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable):
)
else:
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
logging.warning(
logger.warning(
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
@@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable):
'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
logger.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
@staticmethod
@@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable):
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
logger.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
"""
@@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable):
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
logger.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
logger.info(f"Answer: {answer}")
return answer
else:
return self._stream_response(answer)
@@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable):
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
logger.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
logger.info(f"Answer: {answer}")
return answer
else:
# this is a streamed response and needs to be handled differently.

View File

@@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
logger = logging.getLogger(__name__)
@register_deserializable
class GoogleLlm(BaseLlm):
@@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm):
def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
model_name = self.config.model or "gemini-pro"
logging.info(f"Using Google LLM model: {model_name}")
logger.info(f"Using Google LLM model: {model_name}")
model = genai.GenerativeModel(model_name=model_name)
generation_config_params = {

View File

@@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
logger = logging.getLogger(__name__)
@register_deserializable
class HuggingFaceLlm(BaseLlm):
@@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm):
raise ValueError("`top_p` must be > 0.0 and < 1.0")
model = config.model
logging.info(f"Using HuggingFaceHub with model {model}")
logger.info(f"Using HuggingFaceHub with model {model}")
llm = HuggingFaceHub(
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
repo_id=model,

View File

@@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm):
messages: list[BaseMessage],
) -> str:
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.function_calling import \
convert_to_openai_tool
openai_tools = [convert_to_openai_tool(tools)]
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

View File

@@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
logger = logging.getLogger(__name__)
@register_deserializable
class VertexAILlm(BaseLlm):
@@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm):
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
logger.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

View File

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable
logger = logging.getLogger(__name__)
@register_deserializable
class BeehiivLoader(BaseLoader):
@@ -90,9 +92,9 @@ class BeehiivLoader(BaseLoader):
if is_readable(data):
return data
else:
logging.warning(f"Page is not readable (too many invalid characters): {link}")
logger.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}")
logger.error(f"Failed to parse {link}: {e}")
return None
for link in links:

View File

@@ -10,6 +10,8 @@ from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.text_file import TextFileLoader
from embedchain.utils.misc import detect_datatype
logger = logging.getLogger(__name__)
@register_deserializable
class DirectoryLoader(BaseLoader):
@@ -27,12 +29,12 @@ class DirectoryLoader(BaseLoader):
if not directory_path.is_dir():
raise ValueError(f"Invalid path: {path}")
logging.info(f"Loading data from directory: {path}")
logger.info(f"Loading data from directory: {path}")
data_list = self._process_directory(directory_path)
doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
for error in self.errors:
logging.warning(error)
logger.warning(error)
return {"doc_id": doc_id, "data": data_list}
@@ -46,7 +48,7 @@ class DirectoryLoader(BaseLoader):
loader = self._predict_loader(file_path)
data_list.extend(loader.load_data(str(file_path))["data"])
elif file_path.is_dir():
logging.info(f"Loading data from directory: {file_path}")
logger.info(f"Loading data from directory: {file_path}")
return data_list
def _predict_loader(self, file_path: Path) -> BaseLoader:

View File

@@ -5,6 +5,8 @@ import os
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
@register_deserializable
class DiscordLoader(BaseLoader):
@@ -102,7 +104,7 @@ class DiscordLoader(BaseLoader):
class DiscordClient(discord.Client):
async def on_ready(self) -> None:
logging.info("Logged on as {0}!".format(self.user))
logger.info("Logged on as {0}!".format(self.user))
try:
channel = self.get_channel(int(channel_id))
if not isinstance(channel, discord.TextChannel):
@@ -121,7 +123,7 @@ class DiscordLoader(BaseLoader):
messages.append(DiscordLoader._format_message(thread_message))
except Exception as e:
logging.error(e)
logger.error(e)
await self.close()
finally:
await self.close()

View File

@@ -8,6 +8,8 @@ import requests
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class DiscourseLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -35,7 +37,7 @@ class DiscourseLoader(BaseLoader):
try:
response.raise_for_status()
except Exception as e:
logging.error(f"Failed to load post {post_id}: {e}")
logger.error(f"Failed to load post {post_id}: {e}")
return
response_data = response.json()
post_contents = clean_string(response_data.get("raw"))
@@ -56,7 +58,7 @@ class DiscourseLoader(BaseLoader):
self._check_query(query)
data = []
data_contents = []
logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
search_url = f"{self.domain}search.json?q={query}"
response = requests.get(search_url)
try:

View File

@@ -15,6 +15,8 @@ except ImportError:
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
@register_deserializable
class DocsSiteLoader(BaseLoader):
@@ -28,7 +30,7 @@ class DocsSiteLoader(BaseLoader):
response = requests.get(url)
if response.status_code != 200:
logging.info(f"Failed to fetch the website: {response.status_code}")
logger.info(f"Failed to fetch the website: {response.status_code}")
return
soup = BeautifulSoup(response.text, "html.parser")
@@ -53,7 +55,7 @@ class DocsSiteLoader(BaseLoader):
def _load_data_from_url(url: str) -> list:
response = requests.get(url)
if response.status_code != 200:
logging.info(f"Failed to fetch the website: {response.status_code}")
logger.info(f"Failed to fetch the website: {response.status_code}")
return []
soup = BeautifulSoup(response.content, "html.parser")

View File

@@ -22,6 +22,8 @@ except ImportError:
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class GmailReader:
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
@@ -114,7 +116,7 @@ class GmailLoader(BaseLoader):
def load_data(self, query: str):
reader = GmailReader(query=query)
emails = reader.load_emails()
logging.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
data = []
for email in emails:

View File

@@ -5,6 +5,8 @@ from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class MySQLLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]]):
@@ -32,7 +34,7 @@ class MySQLLoader(BaseLoader):
self.connection = sqlconnector.connection.MySQLConnection(**config)
self.cursor = self.connection.cursor()
except (sqlconnector.Error, IOError) as err:
logging.info(f"Connection failed: {err}")
logger.info(f"Connection failed: {err}")
raise ValueError(
f"Unable to connect with the given config: {config}.",
"Please provide the correct configuration to load data from you MySQL DB. \

View File

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class NotionDocument:
"""
@@ -98,7 +100,7 @@ class NotionLoader(BaseLoader):
id = source[-32:]
formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
logging.debug(f"Extracted notion page id as: {formatted_id}")
logger.debug(f"Extracted notion page id as: {formatted_id}")
integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
reader = NotionPageLoader(integration_token=integration_token)

View File

@@ -4,6 +4,8 @@ from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
class PostgresLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -32,7 +34,7 @@ class PostgresLoader(BaseLoader):
conn_params.append(f"{key}={value}")
config_info = " ".join(conn_params)
logging.info(f"Connecting to postrgres sql: {config_info}")
logger.info(f"Connecting to postrgres sql: {config_info}")
self.connection = psycopg.connect(conninfo=config_info)
self.cursor = self.connection.cursor()

View File

@@ -19,6 +19,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader
logger = logging.getLogger(__name__)
@register_deserializable
class SitemapLoader(BaseLoader):
@@ -41,7 +43,7 @@ class SitemapLoader(BaseLoader):
response.raise_for_status()
soup = BeautifulSoup(response.text, "xml")
except requests.RequestException as e:
logging.error(f"Error fetching sitemap from URL: {e}")
logger.error(f"Error fetching sitemap from URL: {e}")
return
elif os.path.isfile(sitemap_source):
with open(sitemap_source, "r") as file:
@@ -60,7 +62,7 @@ class SitemapLoader(BaseLoader):
loader_data = web_page_loader.load_data(link)
return loader_data.get("data")
except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}")
logger.error(f"Failed to parse {link}: {e}")
return None
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -72,6 +74,6 @@ class SitemapLoader(BaseLoader):
if data:
output.extend(data)
except Exception as e:
logging.error(f"Error loading page {link}: {e}")
logger.error(f"Error loading page {link}: {e}")
return {"doc_id": doc_id, "data": output}

View File

@@ -11,6 +11,8 @@ from embedchain.utils.misc import clean_string
SLACK_API_BASE_URL = "https://www.slack.com/api/"
logger = logging.getLogger(__name__)
class SlackLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -38,7 +40,7 @@ class SlackLoader(BaseLoader):
"SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
)
logging.info(f"Creating Slack Loader with config: {config}")
logger.info(f"Creating Slack Loader with config: {config}")
# get slack client config params
slack_bot_token = os.getenv("SLACK_USER_TOKEN")
ssl_cert = ssl.create_default_context(cafile=certifi.where())
@@ -54,7 +56,7 @@ class SlackLoader(BaseLoader):
headers=headers,
team_id=team_id,
)
logging.info("Slack Loader setup successful!")
logger.info("Slack Loader setup successful!")
@staticmethod
def _check_query(query):
@@ -69,7 +71,7 @@ class SlackLoader(BaseLoader):
data = []
data_content = []
logging.info(f"Searching slack conversations for query: {query}")
logger.info(f"Searching slack conversations for query: {query}")
results = self.client.search_messages(
query=query,
sort="timestamp",
@@ -79,7 +81,7 @@ class SlackLoader(BaseLoader):
messages = results.get("messages")
num_message = len(messages)
logging.info(f"Found {num_message} messages for query: {query}")
logger.info(f"Found {num_message} messages for query: {query}")
matches = messages.get("matches", [])
for message in matches:
@@ -107,7 +109,7 @@ class SlackLoader(BaseLoader):
"data": data,
}
except Exception as e:
logging.warning(f"Error in loading slack data: {e}")
logger.warning(f"Error in loading slack data: {e}")
raise ValueError(
f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
) from e

View File

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable
logger = logging.getLogger(__name__)
@register_deserializable
class SubstackLoader(BaseLoader):
@@ -90,9 +92,9 @@ class SubstackLoader(BaseLoader):
if is_readable(data):
return data
else:
logging.warning(f"Page is not readable (too many invalid characters): {link}")
logger.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}")
logger.error(f"Failed to parse {link}: {e}")
return None
for link in links:

View File

@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
@register_deserializable
class WebPageLoader(BaseLoader):
@@ -87,7 +89,7 @@ class WebPageLoader(BaseLoader):
cleaned_size = len(content)
if original_size != 0:
logging.info(
logger.info(
f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
)

View File

@@ -7,6 +7,8 @@ from tqdm import tqdm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader
logger = logging.getLogger(__name__)
class YoutubeChannelLoader(BaseLoader):
"""Loader for youtube channel."""
@@ -36,7 +38,7 @@ class YoutubeChannelLoader(BaseLoader):
videos = [entry["url"] for entry in info_dict["entries"]]
return videos
except Exception:
logging.error(f"Failed to fetch youtube videos for channel: {channel_name}")
logger.error(f"Failed to fetch youtube videos for channel: {channel_name}")
return []
def _load_yt_video(video_link):
@@ -45,12 +47,12 @@ class YoutubeChannelLoader(BaseLoader):
if each_load_data:
return each_load_data.get("data")
except Exception as e:
logging.error(f"Failed to load youtube video {video_link}: {e}")
logger.error(f"Failed to load youtube video {video_link}: {e}")
return None
def _add_youtube_channel():
video_links = _get_yt_video_links()
logging.info("Loading videos from youtube channel...")
logger.info("Loading videos from youtube channel...")
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submitting all tasks and storing the future object with the video link
future_to_video = {
@@ -67,7 +69,7 @@ class YoutubeChannelLoader(BaseLoader):
data.extend(results)
data_urls.extend([result.get("meta_data").get("url") for result in results])
except Exception as e:
logging.error(f"Failed to process youtube video {video}: {e}")
logger.error(f"Failed to process youtube video {video}: {e}")
_add_youtube_channel()
doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()

View File

@@ -8,6 +8,8 @@ from embedchain.core.db.models import ChatHistory as ChatHistoryModel
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
logger = logging.getLogger(__name__)
class ChatHistory:
def __init__(self) -> None:
@@ -31,11 +33,11 @@ class ChatHistory:
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error adding chat memory to db: {e}")
logger.error(f"Error adding chat memory to db: {e}")
self.db_session.rollback()
return None
logging.info(f"Added chat memory to db with id: {memory_id}")
logger.info(f"Added chat memory to db with id: {memory_id}")
return memory_id
def delete(self, app_id: str, session_id: Optional[str] = None):
@@ -55,7 +57,7 @@ class ChatHistory:
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error deleting chat history: {e}")
logger.error(f"Error deleting chat history: {e}")
self.db_session.rollback()
def get(

View File

@@ -3,6 +3,8 @@ from typing import Any, Optional
from embedchain.helpers.json_serializable import JSONSerializable
logger = logging.getLogger(__name__)
class BaseMessage(JSONSerializable):
"""
@@ -52,7 +54,7 @@ class ChatMessage(JSONSerializable):
def add_user_message(self, message: str, metadata: Optional[dict] = None):
if self.human_message:
logging.info(
logger.info(
"Human message already exists in the chat message,\
overwriting it with new message."
)
@@ -61,7 +63,7 @@ class ChatMessage(JSONSerializable):
def add_ai_message(self, message: str, metadata: Optional[dict] = None):
if self.ai_message:
logging.info(
logger.info(
"AI message already exists in the chat message,\
overwriting it with new message."
)

View File

@@ -157,7 +157,6 @@ class AIAssistant:
log_level=logging.INFO,
collect_metrics=True,
):
self.name = name or "AI Assistant"
self.data_sources = data_sources or []
self.log_level = log_level

View File

@@ -11,6 +11,8 @@ from tqdm import tqdm
from embedchain.models.data_type import DataType
logger = logging.getLogger(__name__)
def parse_content(content, type):
implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
@@ -61,7 +63,7 @@ def parse_content(content, type):
cleaned_size = len(content)
if original_size != 0:
logging.info(
logger.info(
f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
)
@@ -208,31 +210,31 @@ def detect_datatype(source: Any) -> DataType:
}
if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
return DataType.YOUTUBE_VIDEO
if url.netloc in {"notion.so", "notion.site"}:
logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
logger.debug(f"Source of `{formatted_source}` detected as `notion`.")
return DataType.NOTION
if url.path.endswith(".pdf"):
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
return DataType.PDF_FILE
if url.path.endswith(".xml"):
logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
return DataType.SITEMAP
if url.path.endswith(".csv"):
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
if url.path.endswith(".mdx") or url.path.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX
if url.path.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
if url.path.endswith(".yaml"):
@@ -242,14 +244,14 @@ def detect_datatype(source: Any) -> DataType:
try:
yaml_content = yaml.safe_load(response.text)
except yaml.YAMLError as exc:
logging.error(f"Error parsing YAML: {exc}")
logger.error(f"Error parsing YAML: {exc}")
raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
logger.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
@@ -258,35 +260,35 @@ def detect_datatype(source: Any) -> DataType:
make sure you have all the required fields in YAML config data"
)
except requests.exceptions.RequestException as e:
logging.error(f"Error fetching URL {formatted_source}: {e}")
logger.error(f"Error fetching URL {formatted_source}: {e}")
if url.path.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `json_file`.")
return DataType.JSON
if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
# `docs_site` detection via path is not accepted for local filesystem URIs,
# because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
return DataType.DOCS_SITE
if "github.com" in url.netloc:
logging.debug(f"Source of `{formatted_source}` detected as `github`.")
logger.debug(f"Source of `{formatted_source}` detected as `github`.")
return DataType.GITHUB
if is_google_drive_folder(url.netloc + url.path):
logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
return DataType.GOOGLE_DRIVE_FOLDER
# If none of the above conditions are met, it's a general web page
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
logger.debug(f"Source of `{formatted_source}` detected as `web_page`.")
return DataType.WEB_PAGE
elif not isinstance(source, str):
# For datatypes where source is not a string.
if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
return DataType.QNA_PAIR
# Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
@@ -300,37 +302,37 @@ def detect_datatype(source: Any) -> DataType:
# Note: checking for string is not necessary anymore.
if source.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
if source.endswith(".csv"):
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
if source.endswith(".xml"):
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
logger.debug(f"Source of `{formatted_source}` detected as `xml`.")
return DataType.XML
if source.endswith(".mdx") or source.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX
if source.endswith(".txt"):
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
logger.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT_FILE
if source.endswith(".pdf"):
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
return DataType.PDF_FILE
if source.endswith(".yaml"):
with open(source, "r") as file:
yaml_content = yaml.safe_load(file)
if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
logger.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
@@ -340,11 +342,11 @@ def detect_datatype(source: Any) -> DataType:
)
if source.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
logger.debug(f"Source of `{formatted_source}` detected as `json`.")
return DataType.JSON
if os.path.exists(source) and is_readable(open(source).read()):
logging.debug(f"Source of `{formatted_source}` detected as `text_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `text_file`.")
return DataType.TEXT_FILE
# If the source is a valid file, that's not detectable as a type, an error is raised.
@@ -360,11 +362,11 @@ def detect_datatype(source: Any) -> DataType:
# check if the source is valid json string
if is_valid_json_string(source):
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
logger.debug(f"Source of `{formatted_source}` detected as `json`.")
return DataType.JSON
# Use text as final fallback.
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
logger.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT

View File

@@ -22,6 +22,9 @@ except RuntimeError:
from chromadb.errors import InvalidDimensionException
logger = logging.getLogger(__name__)
@register_deserializable
class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB."""
@@ -47,7 +50,7 @@ class ChromaDB(BaseVectorDB):
setattr(self.settings, key, value)
if self.config.host and self.config.port:
logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
self.settings.chroma_server_host = self.config.host
self.settings.chroma_server_http_port = self.config.port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

View File

@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB
logger = logging.getLogger(__name__)
@register_deserializable
class ElasticsearchDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ElasticsearchDB(BaseVectorDB):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
logging.info(self.client.info())
logger.info(self.client.info())
index_settings = {
"mappings": {
"properties": {

View File

@@ -19,6 +19,8 @@ from embedchain.config import OpenSearchDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
logger = logging.getLogger(__name__)
@register_deserializable
class OpenSearchDB(BaseVectorDB):
@@ -43,12 +45,12 @@ class OpenSearchDB(BaseVectorDB):
**self.config.extra_params,
)
info = self.client.info()
logging.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
logger.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
# Remove auth credentials from config after successful connection
super().__init__(config=self.config)
def _initialize(self):
logging.info(self.client.info())
logger.info(self.client.info())
index_name = self._get_index()
if self.client.indices.exists(index=index_name):
print(f"Index '{index_name}' already exists.")

View File

@@ -16,6 +16,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB
logger = logging.getLogger(__name__)
@register_deserializable
class PineconeDB(BaseVectorDB):
@@ -49,7 +51,7 @@ class PineconeDB(BaseVectorDB):
# Setup BM25Encoder if sparse vectors are to be used
self.bm25_encoder = None
if self.config.hybrid_search:
logging.info("Initializing BM25Encoder for sparse vectors..")
logger.info("Initializing BM25Encoder for sparse vectors..")
self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
# Call parent init here because embedder is needed

View File

@@ -13,6 +13,8 @@ except ImportError:
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
) from None
logger = logging.getLogger(__name__)
@register_deserializable
class ZillizVectorDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ZillizVectorDB(BaseVectorDB):
:type name: str
"""
if utility.has_collection(name):
logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
logger.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
self.collection = Collection(name)
else:
fields = [