diff --git a/embedchain/alembic.ini b/embedchain/alembic.ini
index 53023ad8..e1d08b31 100644
--- a/embedchain/alembic.ini
+++ b/embedchain/alembic.ini
@@ -91,7 +91,6 @@ keys = console
 keys = generic
 
 [logger_root]
-level = WARN
 handlers = console
 qualname =
diff --git a/embedchain/app.py b/embedchain/app.py
index 6547e36d..9811610b 100644
--- a/embedchain/app.py
+++ b/embedchain/app.py
@@ -32,6 +32,8 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class App(EmbedChain):
@@ -50,10 +52,10 @@ class App(EmbedChain):
         embedding_model: BaseEmbedder = None,
         llm: BaseLlm = None,
         config_data: dict = None,
-        log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
         cache_config: CacheConfig = None,
+        log_level: int = logging.WARN,
     ):
         """
         Initialize a new `App` instance.
@@ -68,8 +70,8 @@ class App(EmbedChain):
         :type llm: BaseLlm, optional
         :param config_data: Config dictionary, defaults to None
         :type config_data: dict, optional
         :param log_level: Log level to use, defaults to logging.WARN
         :type log_level: int, optional
         :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
         :type auto_deploy: bool, optional
         :raises Exception: If an error occurs while creating the pipeline
@@ -83,13 +83,10 @@ class App(EmbedChain):
         if name and config:
             raise Exception("Cannot provide both name and config. Please provide only one of them.")
 
-        # logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-        self.logger = logging.getLogger(__name__)
-
         # Initialize the metadata db for the app
         setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
         init_db()
         self.auto_deploy = auto_deploy
         # Store the dict config as an attribute to be able to send it
         self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -238,7 +239,7 @@ class App(EmbedChain):
             response.raise_for_status()
             return response.status_code == 200
         except Exception as e:
-            self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            logger.exception(f"Error occurred during file upload: {str(e)}")
             print("❌ Error occurred during file upload!")
             return False
@@ -272,7 +273,7 @@ class App(EmbedChain):
                 metadata = {"file_path": data_value, "s3_key": s3_key}
                 data_value = presigned_url
             else:
-                self.logger.error(f"File upload failed for hash: {data_hash}")
+                logger.error(f"File upload failed for hash: {data_hash}")
                 return False
         else:
             if data_type == "qna_pair":
@@ -357,15 +359,13 @@ class App(EmbedChain):
         elif config and isinstance(config, dict):
             config_data = config
         else:
-            logging.error(
+            logger.error(
                 "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
             )
             config_data = {}
 
-        try:
-            validate_config(config_data)
-        except Exception as e:
-            raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
+        # Validate the config
+        validate_config(config_data)
 
         app_config_data = config_data.get("app", {}).get("config", {})
         vector_db_config_data = config_data.get("vectordb", {})
@@ -477,12 +477,12 @@ class App(EmbedChain):
             EvalMetric.GROUNDEDNESS.value,
         ]
 
-        logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
+        logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
         dataset = []
         for q, a, c in zip(queries, answers, contexts):
             dataset.append(EvalData(question=q, answer=a, contexts=c))
 
-        logging.info(f"Evaluating {len(dataset)} data points...")
+        logger.info(f"Evaluating {len(dataset)} data points...")
         result = {}
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
             future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
diff --git a/embedchain/bots/discord.py b/embedchain/bots/discord.py
index adbf3b7c..907323d9 100644
--- a/embedchain/bots/discord.py
+++ b/embedchain/bots/discord.py
@@ -17,6 +17,8 @@ except ModuleNotFoundError:
     ) from None
 
 
+logger = logging.getLogger(__name__)
+
 intents = discord.Intents.default()
 intents.message_content = True
 client = discord.Client(intents=intents)
@@ -37,7 +39,7 @@ class DiscordBot(BaseBot):
             self.add(data)
             response = f"Added data from: {data}"
         except Exception:
-            logging.exception(f"Failed to add data {data}.")
+            logger.exception(f"Failed to add data {data}.")
             response = "Some error occurred while adding data."
         return response
 
@@ -45,7 +47,7 @@ class DiscordBot(BaseBot):
         try:
             response = self.query(message)
         except Exception:
-            logging.exception(f"Failed to query {message}.")
+            logger.exception(f"Failed to query {message}.")
             response = "An error occurred. Please try again!"
         return response
 
@@ -60,7 +62,7 @@ async def query_command(interaction: discord.Interaction, question: str):
     await interaction.response.defer()
     member = client.guilds[0].get_member(client.user.id)
-    logging.info(f"User: {member}, Query: {question}")
+    logger.info(f"User: {member}, Query: {question}")
     try:
         answer = discord_bot.ask_bot(question)
         if args.include_question:
@@ -70,20 +72,20 @@ async def query_command(interaction: discord.Interaction, question: str):
         await interaction.followup.send(response)
     except Exception as e:
         await interaction.followup.send("An error occurred. Please try again!")
Please try again!") - logging.error("Error occurred during 'query' command:", e) + logger.error("Error occurred during 'query' command:", e) @tree.command(name="add", description="add new content to the embedchain database") async def add_command(interaction: discord.Interaction, url_or_text: str): await interaction.response.defer() member = client.guilds[0].get_member(client.user.id) - logging.info(f"User: {member}, Add: {url_or_text}") + logger.info(f"User: {member}, Add: {url_or_text}") try: response = discord_bot.add_data(url_or_text) await interaction.followup.send(response) except Exception as e: await interaction.followup.send("An error occurred. Please try again!") - logging.error("Error occurred during 'add' command:", e) + logger.error("Error occurred during 'add' command:", e) @tree.command(name="ping", description="Simple ping pong command") @@ -96,7 +98,7 @@ async def on_app_command_error(interaction: discord.Interaction, error: discord. if isinstance(error, commands.CommandNotFound): await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.") else: - logging.error("Error occurred during command execution:", error) + logger.error("Error occurred during command execution:", error) @client.event @@ -104,8 +106,8 @@ async def on_ready(): # TODO: Sync in admin command, to not hit rate limits. # This might be overkill for most users, and it would require to set a guild or user id, where sync is allowed. await tree.sync() - logging.debug("Command tree synced") - logging.info(f"Logged in as {client.user.name}") + logger.debug("Command tree synced") + logger.info(f"Logged in as {client.user.name}") def start_command(): diff --git a/embedchain/bots/slack.py b/embedchain/bots/slack.py index 43e39f26..6beb847a 100644 --- a/embedchain/bots/slack.py +++ b/embedchain/bots/slack.py @@ -19,6 +19,8 @@ except ModuleNotFoundError: ) from None +logger = logging.getLogger(__name__) + SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN") @@ -42,10 +44,10 @@ class SlackBot(BaseBot): try: response = self.chat_bot.chat(question) self.send_slack_message(message["channel"], response) - logging.info("Query answered successfully!") + logger.info("Query answered successfully!") except Exception as e: self.send_slack_message(message["channel"], "An error occurred. 
Please try again!") - logging.error("Error occurred during 'query' command:", e) + logger.error("Error occurred during 'query' command:", e) elif text.startswith("add"): _, data_type, url_or_text = text.split(" ", 2) if url_or_text.startswith("<") and url_or_text.endswith(">"): @@ -55,10 +57,10 @@ class SlackBot(BaseBot): self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}") except ValueError as e: self.send_slack_message(message["channel"], f"Error: {str(e)}") - logging.error("Error occurred during 'add' command:", e) + logger.error("Error occurred during 'add' command:", e) except Exception as e: self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}") - logging.error("Error occurred during 'add' command:", e) + logger.error("Error occurred during 'add' command:", e) def send_slack_message(self, channel, message): response = self.client.chat_postMessage(channel=channel, text=message) @@ -68,7 +70,7 @@ class SlackBot(BaseBot): app = Flask(__name__) def signal_handler(sig, frame): - logging.info("\nGracefully shutting down the SlackBot...") + logger.info("\nGracefully shutting down the SlackBot...") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) diff --git a/embedchain/bots/whatsapp.py b/embedchain/bots/whatsapp.py index 5106d40d..b6534068 100644 --- a/embedchain/bots/whatsapp.py +++ b/embedchain/bots/whatsapp.py @@ -8,6 +8,8 @@ from embedchain.helpers.json_serializable import register_deserializable from .base import BaseBot +logger = logging.getLogger(__name__) + @register_deserializable class WhatsAppBot(BaseBot): @@ -35,7 +37,7 @@ class WhatsAppBot(BaseBot): self.add(data) response = f"Added data from: {data}" except Exception: - logging.exception(f"Failed to add data {data}.") + logger.exception(f"Failed to add data {data}.") response = "Some error occurred while adding data." return response @@ -43,7 +45,7 @@ class WhatsAppBot(BaseBot): try: response = self.query(message) except Exception: - logging.exception(f"Failed to query {message}.") + logger.exception(f"Failed to query {message}.") response = "An error occurred. Please try again!" 
return response @@ -51,7 +53,7 @@ class WhatsAppBot(BaseBot): app = self.flask.Flask(__name__) def signal_handler(sig, frame): - logging.info("\nGracefully shutting down the WhatsAppBot...") + logger.info("\nGracefully shutting down the WhatsAppBot...") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) diff --git a/embedchain/cache.py b/embedchain/cache.py index 4dd4ccc4..e3334fd8 100644 --- a/embedchain/cache.py +++ b/embedchain/cache.py @@ -14,6 +14,8 @@ from gptcache.similarity_evaluation.distance import \ from gptcache.similarity_evaluation.exact_match import \ ExactMatchEvaluation # noqa: F401 +logger = logging.getLogger(__name__) + def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]): return data["input_query"] @@ -24,12 +26,12 @@ def gptcache_data_manager(vector_dimension): def gptcache_data_convert(cache_data): - logging.info("[Cache] Cache hit, returning cache data...") + logger.info("[Cache] Cache hit, returning cache data...") return cache_data def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs): - logging.info("[Cache] Cache missed, updating cache...") + logger.info("[Cache] Cache missed, updating cache...") update_cache_func(Answer(llm_data, CacheDataType.STR)) return llm_data diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index b4d50f78..d7d4987d 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -6,6 +6,8 @@ from embedchain.config.add_config import ChunkerConfig from embedchain.helpers.json_serializable import JSONSerializable from embedchain.models.data_type import DataType +logger = logging.getLogger(__name__) + class BaseChunker(JSONSerializable): def __init__(self, text_splitter): @@ -27,7 +29,7 @@ class BaseChunker(JSONSerializable): chunk_ids = [] id_map = {} min_chunk_size = config.min_chunk_size if config is not None else 1 - logging.info(f"Skipping chunks smaller than {min_chunk_size} characters") + logger.info(f"Skipping chunks smaller than {min_chunk_size} characters") data_result = loader.load_data(src) data_records = data_result["data"] doc_id = data_result["doc_id"] diff --git a/embedchain/client.py b/embedchain/client.py index 0e6c6eaa..7e8fcddb 100644 --- a/embedchain/client.py +++ b/embedchain/client.py @@ -7,6 +7,8 @@ import requests from embedchain.constants import CONFIG_DIR, CONFIG_FILE +logger = logging.getLogger(__name__) + class Client: def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"): @@ -24,7 +26,7 @@ class Client: else: if "api_key" in self.config_data: self.api_key = self.config_data["api_key"] - logging.info("API key loaded successfully!") + logger.info("API key loaded successfully!") else: raise ValueError( "You are not logged in. 
Please obtain an API key from https://app.embedchain.ai/settings/keys/"
                )
@@ -64,7 +66,7 @@ class Client:
 
         with open(CONFIG_FILE, "w") as config_file:
             json.dump(self.config_data, config_file, indent=4)
 
-        logging.info("API key saved successfully!")
+        logger.info("API key saved successfully!")
 
     def clear(self):
         if "api_key" in self.config_data:
@@ -72,17 +74,17 @@ class Client:
             with open(CONFIG_FILE, "w") as config_file:
                 json.dump(self.config_data, config_file, indent=4)
             self.api_key = None
-            logging.info("API key deleted successfully!")
+            logger.info("API key deleted successfully!")
         else:
-            logging.warning("API key not found in the configuration file.")
+            logger.warning("API key not found in the configuration file.")
 
     def update(self, api_key):
         if self.check(api_key):
             self.api_key = api_key
             self.save()
-            logging.info("API key updated successfully!")
+            logger.info("API key updated successfully!")
         else:
-            logging.warning("Invalid API key provided. API key not updated.")
+            logger.warning("Invalid API key provided. API key not updated.")
 
     def check(self, api_key):
         validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
@@ -90,8 +92,8 @@ class Client:
         if response.status_code == 200:
             return True
         else:
-            logging.warning(f"Response from API: {response.text}")
-            logging.warning("Invalid API key. Unable to validate.")
+            logger.warning(f"Response from API: {response.text}")
+            logger.warning("Invalid API key. Unable to validate.")
             return False
 
     def get(self):
diff --git a/embedchain/config/base_app_config.py b/embedchain/config/base_app_config.py
index c5b8e154..781ca024 100644
--- a/embedchain/config/base_app_config.py
+++ b/embedchain/config/base_app_config.py
@@ -5,6 +5,8 @@ from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 class BaseAppConfig(BaseConfig, JSONSerializable):
     """
@@ -42,15 +44,15 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
 
         if db:
             self._db = db
-            logging.warning(
+            logger.warning(
                 "DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
                 "Such as `app(config=config, db=db)`."
             )
 
         if collection_name:
-            logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
+            logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
         return
 
     def _setup_logging(self, log_level):
         logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
         self.logger = logging.getLogger(__name__)
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index 72dfd015..1efe8acb 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -6,6 +6,8 @@ from typing import Any, Optional
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
 
+logger = logging.getLogger(__name__)
+
 DEFAULT_PROMPT = """
 You are a Q&A expert system. Your responses must always be rooted in the context provided for each query.
Here are some guidelines to follow: @@ -147,7 +149,7 @@ class BaseLlmConfig(BaseConfig): :raises ValueError: Stream is not boolean """ if template is not None: - logging.warning( + logger.warning( "The `template` argument is deprecated and will be removed in a future version. " + "Please use `prompt` instead." ) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 386ccdb6..ee824dd0 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -25,6 +25,8 @@ from embedchain.vectordb.base import BaseVectorDB load_dotenv() +logger = logging.getLogger(__name__) + class EmbedChain(JSONSerializable): def __init__( @@ -143,10 +145,10 @@ class EmbedChain(JSONSerializable): try: DataType(source) - logging.warning( + logger.warning( f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501 ) - logging.warning( + logger.warning( "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501 ) source, data_type = data_type, source @@ -157,7 +159,7 @@ class EmbedChain(JSONSerializable): try: data_type = DataType(data_type) except ValueError: - logging.info( + logger.info( f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`" # noqa: E501 ) data_type = DataType.CUSTOM @@ -190,12 +192,12 @@ class EmbedChain(JSONSerializable): try: self.db_session.commit() except Exception as e: - logging.error(f"Error adding data source: {e}") + logger.error(f"Error adding data source: {e}") self.db_session.rollback() if dry_run: data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type} - logging.debug(f"Dry run info : {data_chunks_info}") + logger.debug(f"Dry run info : {data_chunks_info}") return data_chunks_info # Send anonymous telemetry @@ -490,7 +492,7 @@ class EmbedChain(JSONSerializable): contexts_data_for_llm_query = contexts if self.cache_config is not None: - logging.info("Cache enabled. Checking cache...") + logger.info("Cache enabled. Checking cache...") answer = adapt( llm_handler=self.llm.query, cache_data_convert=gptcache_data_convert, @@ -562,7 +564,7 @@ class EmbedChain(JSONSerializable): self.llm.update_history(app_id=self.config.id, session_id=session_id) if self.cache_config is not None: - logging.info("Cache enabled. Checking cache...") + logger.debug("Cache enabled. Checking cache...") cache_id = f"{session_id}--{self.config.id}" answer = adapt( llm_handler=self.llm.chat, @@ -575,6 +577,7 @@ class EmbedChain(JSONSerializable): dry_run=dry_run, ) else: + logger.debug("Cache disabled. 
Running chat without cache.") answer = self.llm.chat( input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run ) @@ -652,7 +655,7 @@ class EmbedChain(JSONSerializable): self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete() self.db_session.commit() except Exception as e: - logging.error(f"Error deleting data sources: {e}") + logger.error(f"Error deleting data sources: {e}") self.db_session.rollback() return None self.db.reset() @@ -694,11 +697,11 @@ class EmbedChain(JSONSerializable): self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete() self.db_session.commit() except Exception as e: - logging.error(f"Error deleting data sources: {e}") + logger.error(f"Error deleting data sources: {e}") self.db_session.rollback() return None self.db.delete(where={"hash": source_id}) - logging.info(f"Successfully deleted {source_id}") + logger.info(f"Successfully deleted {source_id}") # Send anonymous telemetry if self.config.collect_metrics: self.telemetry.capture(event_name="delete", properties=self._telemetry_props) diff --git a/embedchain/embedder/nvidia.py b/embedchain/embedder/nvidia.py index f6b5788c..5a499037 100644 --- a/embedchain/embedder/nvidia.py +++ b/embedchain/embedder/nvidia.py @@ -8,6 +8,8 @@ from embedchain.config import BaseEmbedderConfig from embedchain.embedder.base import BaseEmbedder from embedchain.models import VectorDimensions +logger = logging.getLogger(__name__) + class NvidiaEmbedder(BaseEmbedder): def __init__(self, config: Optional[BaseEmbedderConfig] = None): @@ -17,7 +19,7 @@ class NvidiaEmbedder(BaseEmbedder): super().__init__(config=config) model = self.config.model or "nvolveqa_40k" - logging.info(f"Using NVIDIA embedding model: {model}") + logger.info(f"Using NVIDIA embedding model: {model}") embedder = NVIDIAEmbeddings(model=model) embedding_fn = BaseEmbedder._langchain_default_concept(embedder) self.set_embedding_fn(embedding_fn=embedding_fn) diff --git a/embedchain/evaluation/metrics/answer_relevancy.py b/embedchain/evaluation/metrics/answer_relevancy.py index 588fc0fd..3e5c3859 100644 --- a/embedchain/evaluation/metrics/answer_relevancy.py +++ b/embedchain/evaluation/metrics/answer_relevancy.py @@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import AnswerRelevanceConfig from embedchain.evaluation.base import BaseMetric from embedchain.utils.evaluation import EvalData, EvalMetric +logger = logging.getLogger(__name__) + class AnswerRelevance(BaseMetric): """ @@ -88,6 +90,6 @@ class AnswerRelevance(BaseMetric): try: results.append(future.result()) except Exception as e: - logging.error(f"Error evaluating answer relevancy for {data}: {e}") + logger.error(f"Error evaluating answer relevancy for {data}: {e}") return np.mean(results) if results else 0.0 diff --git a/embedchain/evaluation/metrics/groundedness.py b/embedchain/evaluation/metrics/groundedness.py index 082dfd6b..86f3f320 100644 --- a/embedchain/evaluation/metrics/groundedness.py +++ b/embedchain/evaluation/metrics/groundedness.py @@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import GroundednessConfig from embedchain.evaluation.base import BaseMetric from embedchain.utils.evaluation import EvalData, EvalMetric +logger = logging.getLogger(__name__) + class Groundedness(BaseMetric): """ @@ -97,6 +99,6 @@ class Groundedness(BaseMetric): score = future.result() results.append(score) except Exception as e: - logging.error(f"Error while evaluating groundedness for data point {data}: {e}") + 
logger.error(f"Error while evaluating groundedness for data point {data}: {e}") return np.mean(results) if results else 0.0 diff --git a/embedchain/helpers/json_serializable.py b/embedchain/helpers/json_serializable.py index 5f3179af..bcbb4941 100644 --- a/embedchain/helpers/json_serializable.py +++ b/embedchain/helpers/json_serializable.py @@ -8,6 +8,8 @@ T = TypeVar("T", bound="JSONSerializable") # NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level) # NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level) +logger = logging.getLogger(__name__) + def register_deserializable(cls: Type[T]) -> Type[T]: """ @@ -57,7 +59,7 @@ class JSONSerializable: try: return json.dumps(self, default=self._auto_encoder, ensure_ascii=False) except Exception as e: - logging.error(f"Serialization error: {e}") + logger.error(f"Serialization error: {e}") return "{}" @classmethod @@ -79,7 +81,7 @@ class JSONSerializable: try: return json.loads(json_str, object_hook=cls._auto_decoder) except Exception as e: - logging.error(f"Deserialization error: {e}") + logger.error(f"Deserialization error: {e}") # Return a default instance in case of failure return cls() diff --git a/embedchain/llm/anthropic.py b/embedchain/llm/anthropic.py index 05a1f815..27f9eaa9 100644 --- a/embedchain/llm/anthropic.py +++ b/embedchain/llm/anthropic.py @@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm +logger = logging.getLogger(__name__) + @register_deserializable class AnthropicLlm(BaseLlm): @@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm): ) if config.max_tokens and config.max_tokens != 1000: - logging.warning("Config option `max_tokens` is not supported by this model.") + logger.warning("Config option `max_tokens` is not supported by this model.") messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) diff --git a/embedchain/llm/aws_bedrock.py b/embedchain/llm/aws_bedrock.py index 34170981..362c5b75 100644 --- a/embedchain/llm/aws_bedrock.py +++ b/embedchain/llm/aws_bedrock.py @@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm): } if config.stream: - from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import \ + StreamingStdOutCallbackHandler callbacks = [StreamingStdOutCallbackHandler()] llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks) diff --git a/embedchain/llm/azure_openai.py b/embedchain/llm/azure_openai.py index 7cf095ba..b2542b42 100644 --- a/embedchain/llm/azure_openai.py +++ b/embedchain/llm/azure_openai.py @@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm +logger = logging.getLogger(__name__) + @register_deserializable class AzureOpenAILlm(BaseLlm): @@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm): ) if config.top_p and config.top_p != 1: - logging.warning("Config option `top_p` is not supported by this model.") + logger.warning("Config option `top_p` is not supported by this model.") messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index f067fdd6..f8a0fd85 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import 
JSONSerializable from embedchain.memory.base import ChatHistory from embedchain.memory.message import ChatMessage +logger = logging.getLogger(__name__) + class BaseLlm(JSONSerializable): def __init__(self, config: Optional[BaseLlmConfig] = None): @@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable): ) else: # If we can't swap in the default, we still proceed but tell users that the history is ignored. - logging.warning( + logger.warning( "Your bot contains a history, but prompt does not include `$history` key. History is ignored." ) prompt = self.config.prompt.substitute(context=context_string, query=input_query) @@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable): 'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' ) from None search = DuckDuckGoSearchRun() - logging.info(f"Access search to get answers for {input_query}") + logger.info(f"Access search to get answers for {input_query}") return search.run(input_query) @staticmethod @@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable): for chunk in answer: streamed_answer = streamed_answer + chunk yield chunk - logging.info(f"Answer: {streamed_answer}") + logger.info(f"Answer: {streamed_answer}") def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False): """ @@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable): if self.online: k["web_search_result"] = self.access_search_and_get_results(input_query) prompt = self.generate_prompt(input_query, contexts, **k) - logging.info(f"Prompt: {prompt}") + logger.info(f"Prompt: {prompt}") if dry_run: return prompt answer = self.get_answer_from_llm(prompt) if isinstance(answer, str): - logging.info(f"Answer: {answer}") + logger.info(f"Answer: {answer}") return answer else: return self._stream_response(answer) @@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable): k["web_search_result"] = self.access_search_and_get_results(input_query) prompt = self.generate_prompt(input_query, contexts, **k) - logging.info(f"Prompt: {prompt}") + logger.info(f"Prompt: {prompt}") if dry_run: return prompt answer = self.get_answer_from_llm(prompt) if isinstance(answer, str): - logging.info(f"Answer: {answer}") + logger.info(f"Answer: {answer}") return answer else: # this is a streamed response and needs to be handled differently. 
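Note on the convention this patch applies throughout: each module creates a single module-level logger at import time and leaves handler and level configuration entirely to the application entry point. A minimal sketch of the idiom, assuming a host application as the entry point (the function and message names below are illustrative, not code from this patch):

import logging

logger = logging.getLogger(__name__)  # one logger per module, named "package.module"


def answer(query: str) -> str:
    # Lazy %-style formatting: interpolation happens only if the record is emitted,
    # and extra values (such as exceptions) are passed as arguments for a placeholder.
    logger.debug("Answering query: %s", query)
    return f"echo: {query}"


if __name__ == "__main__":
    # Only the application entry point installs handlers and levels; library
    # modules never call basicConfig() themselves, so importing them stays silent.
    logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=logging.DEBUG)
    print(answer("What is Embedchain?"))

With this split, embedchain's modules emit records but never decide where they go; a host application opts in via basicConfig (or dictConfig). This is also why the hard-coded `level = WARN` can be dropped from the `[logger_root]` section of alembic.ini and why `App.__init__` no longer configures logging itself.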
diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py index 6f41e5e9..de15d4c4 100644 --- a/embedchain/llm/google.py +++ b/embedchain/llm/google.py @@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm +logger = logging.getLogger(__name__) + @register_deserializable class GoogleLlm(BaseLlm): @@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm): def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]: model_name = self.config.model or "gemini-pro" - logging.info(f"Using Google LLM model: {model_name}") + logger.info(f"Using Google LLM model: {model_name}") model = genai.GenerativeModel(model_name=model_name) generation_config_params = { diff --git a/embedchain/llm/huggingface.py b/embedchain/llm/huggingface.py index 8cf0f8b5..69f0c463 100644 --- a/embedchain/llm/huggingface.py +++ b/embedchain/llm/huggingface.py @@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm +logger = logging.getLogger(__name__) + @register_deserializable class HuggingFaceLlm(BaseLlm): @@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm): raise ValueError("`top_p` must be > 0.0 and < 1.0") model = config.model - logging.info(f"Using HuggingFaceHub with model {model}") + logger.info(f"Using HuggingFaceHub with model {model}") llm = HuggingFaceHub( huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"], repo_id=model, diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py index 240c8070..4432b5f8 100644 --- a/embedchain/llm/openai.py +++ b/embedchain/llm/openai.py @@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm): messages: list[BaseMessage], ) -> str: from langchain.output_parsers.openai_tools import JsonOutputToolsParser - from langchain_core.utils.function_calling import convert_to_openai_tool + from langchain_core.utils.function_calling import \ + convert_to_openai_tool openai_tools = [convert_to_openai_tool(tools)] chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser()) diff --git a/embedchain/llm/vertex_ai.py b/embedchain/llm/vertex_ai.py index a026e11d..8808f886 100644 --- a/embedchain/llm/vertex_ai.py +++ b/embedchain/llm/vertex_ai.py @@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm +logger = logging.getLogger(__name__) + @register_deserializable class VertexAILlm(BaseLlm): @@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm): @staticmethod def _get_answer(prompt: str, config: BaseLlmConfig) -> str: if config.top_p and config.top_p != 1: - logging.warning("Config option `top_p` is not supported by this model.") + logger.warning("Config option `top_p` is not supported by this model.") messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) diff --git a/embedchain/loaders/beehiiv.py b/embedchain/loaders/beehiiv.py index 5169f7ae..4319b021 100644 --- a/embedchain/loaders/beehiiv.py +++ b/embedchain/loaders/beehiiv.py @@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import is_readable +logger = logging.getLogger(__name__) + @register_deserializable class BeehiivLoader(BaseLoader): @@ -90,9 +92,9 @@ class BeehiivLoader(BaseLoader): if is_readable(data): return data else: - 
logging.warning(f"Page is not readable (too many invalid characters): {link}") + logger.warning(f"Page is not readable (too many invalid characters): {link}") except ParserRejectedMarkup as e: - logging.error(f"Failed to parse {link}: {e}") + logger.error(f"Failed to parse {link}: {e}") return None for link in links: diff --git a/embedchain/loaders/directory_loader.py b/embedchain/loaders/directory_loader.py index 915c0249..5903813b 100644 --- a/embedchain/loaders/directory_loader.py +++ b/embedchain/loaders/directory_loader.py @@ -10,6 +10,8 @@ from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.text_file import TextFileLoader from embedchain.utils.misc import detect_datatype +logger = logging.getLogger(__name__) + @register_deserializable class DirectoryLoader(BaseLoader): @@ -27,12 +29,12 @@ class DirectoryLoader(BaseLoader): if not directory_path.is_dir(): raise ValueError(f"Invalid path: {path}") - logging.info(f"Loading data from directory: {path}") + logger.info(f"Loading data from directory: {path}") data_list = self._process_directory(directory_path) doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest() for error in self.errors: - logging.warning(error) + logger.warning(error) return {"doc_id": doc_id, "data": data_list} @@ -46,7 +48,7 @@ class DirectoryLoader(BaseLoader): loader = self._predict_loader(file_path) data_list.extend(loader.load_data(str(file_path))["data"]) elif file_path.is_dir(): - logging.info(f"Loading data from directory: {file_path}") + logger.info(f"Loading data from directory: {file_path}") return data_list def _predict_loader(self, file_path: Path) -> BaseLoader: diff --git a/embedchain/loaders/discord.py b/embedchain/loaders/discord.py index 7db210ad..807a3d00 100644 --- a/embedchain/loaders/discord.py +++ b/embedchain/loaders/discord.py @@ -5,6 +5,8 @@ import os from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader +logger = logging.getLogger(__name__) + @register_deserializable class DiscordLoader(BaseLoader): @@ -102,7 +104,7 @@ class DiscordLoader(BaseLoader): class DiscordClient(discord.Client): async def on_ready(self) -> None: - logging.info("Logged on as {0}!".format(self.user)) + logger.info("Logged on as {0}!".format(self.user)) try: channel = self.get_channel(int(channel_id)) if not isinstance(channel, discord.TextChannel): @@ -121,7 +123,7 @@ class DiscordLoader(BaseLoader): messages.append(DiscordLoader._format_message(thread_message)) except Exception as e: - logging.error(e) + logger.error(e) await self.close() finally: await self.close() diff --git a/embedchain/loaders/discourse.py b/embedchain/loaders/discourse.py index 1d36efa7..65c1dd75 100644 --- a/embedchain/loaders/discourse.py +++ b/embedchain/loaders/discourse.py @@ -8,6 +8,8 @@ import requests from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import clean_string +logger = logging.getLogger(__name__) + class DiscourseLoader(BaseLoader): def __init__(self, config: Optional[dict[str, Any]] = None): @@ -35,7 +37,7 @@ class DiscourseLoader(BaseLoader): try: response.raise_for_status() except Exception as e: - logging.error(f"Failed to load post {post_id}: {e}") + logger.error(f"Failed to load post {post_id}: {e}") return response_data = response.json() post_contents = clean_string(response_data.get("raw")) @@ -56,7 +58,7 @@ class DiscourseLoader(BaseLoader): self._check_query(query) data = [] data_contents = [] - 
logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}") + logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}") search_url = f"{self.domain}search.json?q={query}" response = requests.get(search_url) try: diff --git a/embedchain/loaders/docs_site_loader.py b/embedchain/loaders/docs_site_loader.py index 6baf1683..ff3ab75f 100644 --- a/embedchain/loaders/docs_site_loader.py +++ b/embedchain/loaders/docs_site_loader.py @@ -15,6 +15,8 @@ except ImportError: from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader +logger = logging.getLogger(__name__) + @register_deserializable class DocsSiteLoader(BaseLoader): @@ -28,7 +30,7 @@ class DocsSiteLoader(BaseLoader): response = requests.get(url) if response.status_code != 200: - logging.info(f"Failed to fetch the website: {response.status_code}") + logger.info(f"Failed to fetch the website: {response.status_code}") return soup = BeautifulSoup(response.text, "html.parser") @@ -53,7 +55,7 @@ class DocsSiteLoader(BaseLoader): def _load_data_from_url(url: str) -> list: response = requests.get(url) if response.status_code != 200: - logging.info(f"Failed to fetch the website: {response.status_code}") + logger.info(f"Failed to fetch the website: {response.status_code}") return [] soup = BeautifulSoup(response.content, "html.parser") diff --git a/embedchain/loaders/gmail.py b/embedchain/loaders/gmail.py index 07e10273..ec62a34b 100644 --- a/embedchain/loaders/gmail.py +++ b/embedchain/loaders/gmail.py @@ -22,6 +22,8 @@ except ImportError: from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import clean_string +logger = logging.getLogger(__name__) + class GmailReader: SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"] @@ -114,7 +116,7 @@ class GmailLoader(BaseLoader): def load_data(self, query: str): reader = GmailReader(query=query) emails = reader.load_emails() - logging.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'") + logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'") data = [] for email in emails: diff --git a/embedchain/loaders/mysql.py b/embedchain/loaders/mysql.py index 7eee2893..fd5b38ac 100644 --- a/embedchain/loaders/mysql.py +++ b/embedchain/loaders/mysql.py @@ -5,6 +5,8 @@ from typing import Any, Optional from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import clean_string +logger = logging.getLogger(__name__) + class MySQLLoader(BaseLoader): def __init__(self, config: Optional[dict[str, Any]]): @@ -32,7 +34,7 @@ class MySQLLoader(BaseLoader): self.connection = sqlconnector.connection.MySQLConnection(**config) self.cursor = self.connection.cursor() except (sqlconnector.Error, IOError) as err: - logging.info(f"Connection failed: {err}") + logger.info(f"Connection failed: {err}") raise ValueError( f"Unable to connect with the given config: {config}.", "Please provide the correct configuration to load data from you MySQL DB. 
\
diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py
index 0ce8eb3f..2a336381 100644
--- a/embedchain/loaders/notion.py
+++ b/embedchain/loaders/notion.py
@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 class NotionDocument:
     """
@@ -98,7 +100,7 @@ class NotionLoader(BaseLoader):
         id = source[-32:]
         formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
-        logging.debug(f"Extracted notion page id as: {formatted_id}")
+        logger.debug(f"Extracted notion page id as: {formatted_id}")
 
         integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
         reader = NotionPageLoader(integration_token=integration_token)
diff --git a/embedchain/loaders/postgres.py b/embedchain/loaders/postgres.py
index d336248c..2ef396f9 100644
--- a/embedchain/loaders/postgres.py
+++ b/embedchain/loaders/postgres.py
@@ -4,6 +4,8 @@ from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
+
 
 class PostgresLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -32,7 +34,7 @@ class PostgresLoader(BaseLoader):
             conn_params.append(f"{key}={value}")
 
         config_info = " ".join(conn_params)
-        logging.info(f"Connecting to postrgres sql: {config_info}")
+        logger.info(f"Connecting to PostgreSQL: {config_info}")
         self.connection = psycopg.connect(conninfo=config_info)
         self.cursor = self.connection.cursor()
diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py
index 4e4da7e1..43a3f20e 100644
--- a/embedchain/loaders/sitemap.py
+++ b/embedchain/loaders/sitemap.py
@@ -19,6 +19,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class SitemapLoader(BaseLoader):
@@ -41,7 +43,7 @@ class SitemapLoader(BaseLoader):
                 response.raise_for_status()
                 soup = BeautifulSoup(response.text, "xml")
             except requests.RequestException as e:
-                logging.error(f"Error fetching sitemap from URL: {e}")
+                logger.error(f"Error fetching sitemap from URL: {e}")
                 return
         elif os.path.isfile(sitemap_source):
             with open(sitemap_source, "r") as file:
@@ -60,7 +62,7 @@ class SitemapLoader(BaseLoader):
                 loader_data = web_page_loader.load_data(link)
                 return loader_data.get("data")
             except ParserRejectedMarkup as e:
-                logging.error(f"Failed to parse {link}: {e}")
+                logger.error(f"Failed to parse {link}: {e}")
                 return None
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -72,6 +74,6 @@ class SitemapLoader(BaseLoader):
                     if data:
                         output.extend(data)
                 except Exception as e:
-                    logging.error(f"Error loading page {link}: {e}")
+                    logger.error(f"Error loading page {link}: {e}")
 
         return {"doc_id": doc_id, "data": output}
diff --git a/embedchain/loaders/slack.py b/embedchain/loaders/slack.py
index f1caa1ec..6fb6e9db 100644
--- a/embedchain/loaders/slack.py
+++ b/embedchain/loaders/slack.py
@@ -11,6 +11,8 @@ from embedchain.utils.misc import clean_string
 
 SLACK_API_BASE_URL = "https://www.slack.com/api/"
 
+logger = logging.getLogger(__name__)
+
 
 class SlackLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -38,7 +40,7 @@
                 "SLACK_USER_TOKEN environment variables not provided.
Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501 ) - logging.info(f"Creating Slack Loader with config: {config}") + logger.info(f"Creating Slack Loader with config: {config}") # get slack client config params slack_bot_token = os.getenv("SLACK_USER_TOKEN") ssl_cert = ssl.create_default_context(cafile=certifi.where()) @@ -54,7 +56,7 @@ class SlackLoader(BaseLoader): headers=headers, team_id=team_id, ) - logging.info("Slack Loader setup successful!") + logger.info("Slack Loader setup successful!") @staticmethod def _check_query(query): @@ -69,7 +71,7 @@ class SlackLoader(BaseLoader): data = [] data_content = [] - logging.info(f"Searching slack conversations for query: {query}") + logger.info(f"Searching slack conversations for query: {query}") results = self.client.search_messages( query=query, sort="timestamp", @@ -79,7 +81,7 @@ class SlackLoader(BaseLoader): messages = results.get("messages") num_message = len(messages) - logging.info(f"Found {num_message} messages for query: {query}") + logger.info(f"Found {num_message} messages for query: {query}") matches = messages.get("matches", []) for message in matches: @@ -107,7 +109,7 @@ class SlackLoader(BaseLoader): "data": data, } except Exception as e: - logging.warning(f"Error in loading slack data: {e}") + logger.warning(f"Error in loading slack data: {e}") raise ValueError( f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501 ) from e diff --git a/embedchain/loaders/substack.py b/embedchain/loaders/substack.py index 77903d17..30975001 100644 --- a/embedchain/loaders/substack.py +++ b/embedchain/loaders/substack.py @@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import is_readable +logger = logging.getLogger(__name__) + @register_deserializable class SubstackLoader(BaseLoader): @@ -90,9 +92,9 @@ class SubstackLoader(BaseLoader): if is_readable(data): return data else: - logging.warning(f"Page is not readable (too many invalid characters): {link}") + logger.warning(f"Page is not readable (too many invalid characters): {link}") except ParserRejectedMarkup as e: - logging.error(f"Failed to parse {link}: {e}") + logger.error(f"Failed to parse {link}: {e}") return None for link in links: diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index d68bb8a8..0568b9e7 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils.misc import clean_string +logger = logging.getLogger(__name__) + @register_deserializable class WebPageLoader(BaseLoader): @@ -87,7 +89,7 @@ class WebPageLoader(BaseLoader): cleaned_size = len(content) if original_size != 0: - logging.info( + logger.info( f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501 ) diff --git a/embedchain/loaders/youtube_channel.py b/embedchain/loaders/youtube_channel.py index c726d458..8f5e3f7f 100644 --- a/embedchain/loaders/youtube_channel.py +++ b/embedchain/loaders/youtube_channel.py @@ -7,6 +7,8 @@ from tqdm import tqdm from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.youtube_video import 
YoutubeVideoLoader +logger = logging.getLogger(__name__) + class YoutubeChannelLoader(BaseLoader): """Loader for youtube channel.""" @@ -36,7 +38,7 @@ class YoutubeChannelLoader(BaseLoader): videos = [entry["url"] for entry in info_dict["entries"]] return videos except Exception: - logging.error(f"Failed to fetch youtube videos for channel: {channel_name}") + logger.error(f"Failed to fetch youtube videos for channel: {channel_name}") return [] def _load_yt_video(video_link): @@ -45,12 +47,12 @@ class YoutubeChannelLoader(BaseLoader): if each_load_data: return each_load_data.get("data") except Exception as e: - logging.error(f"Failed to load youtube video {video_link}: {e}") + logger.error(f"Failed to load youtube video {video_link}: {e}") return None def _add_youtube_channel(): video_links = _get_yt_video_links() - logging.info("Loading videos from youtube channel...") + logger.info("Loading videos from youtube channel...") with concurrent.futures.ThreadPoolExecutor() as executor: # Submitting all tasks and storing the future object with the video link future_to_video = { @@ -67,7 +69,7 @@ class YoutubeChannelLoader(BaseLoader): data.extend(results) data_urls.extend([result.get("meta_data").get("url") for result in results]) except Exception as e: - logging.error(f"Failed to process youtube video {video}: {e}") + logger.error(f"Failed to process youtube video {video}: {e}") _add_youtube_channel() doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest() diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py index 774cdc8e..d6697625 100644 --- a/embedchain/memory/base.py +++ b/embedchain/memory/base.py @@ -8,6 +8,8 @@ from embedchain.core.db.models import ChatHistory as ChatHistoryModel from embedchain.memory.message import ChatMessage from embedchain.memory.utils import merge_metadata_dict +logger = logging.getLogger(__name__) + class ChatHistory: def __init__(self) -> None: @@ -31,11 +33,11 @@ class ChatHistory: try: self.db_session.commit() except Exception as e: - logging.error(f"Error adding chat memory to db: {e}") + logger.error(f"Error adding chat memory to db: {e}") self.db_session.rollback() return None - logging.info(f"Added chat memory to db with id: {memory_id}") + logger.info(f"Added chat memory to db with id: {memory_id}") return memory_id def delete(self, app_id: str, session_id: Optional[str] = None): @@ -55,7 +57,7 @@ class ChatHistory: try: self.db_session.commit() except Exception as e: - logging.error(f"Error deleting chat history: {e}") + logger.error(f"Error deleting chat history: {e}") self.db_session.rollback() def get( diff --git a/embedchain/memory/message.py b/embedchain/memory/message.py index cc8c3a94..5211b0f6 100644 --- a/embedchain/memory/message.py +++ b/embedchain/memory/message.py @@ -3,6 +3,8 @@ from typing import Any, Optional from embedchain.helpers.json_serializable import JSONSerializable +logger = logging.getLogger(__name__) + class BaseMessage(JSONSerializable): """ @@ -52,7 +54,7 @@ class ChatMessage(JSONSerializable): def add_user_message(self, message: str, metadata: Optional[dict] = None): if self.human_message: - logging.info( + logger.info( "Human message already exists in the chat message,\ overwriting it with new message." ) @@ -61,7 +63,7 @@ class ChatMessage(JSONSerializable): def add_ai_message(self, message: str, metadata: Optional[dict] = None): if self.ai_message: - logging.info( + logger.info( "AI message already exists in the chat message,\ overwriting it with new message." 
) diff --git a/embedchain/store/assistants.py b/embedchain/store/assistants.py index 84982dea..9098cc7c 100644 --- a/embedchain/store/assistants.py +++ b/embedchain/store/assistants.py @@ -157,7 +157,6 @@ class AIAssistant: log_level=logging.INFO, collect_metrics=True, ): - self.name = name or "AI Assistant" self.data_sources = data_sources or [] self.log_level = log_level diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 7e1e6a6d..61aaea78 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -11,6 +11,8 @@ from tqdm import tqdm from embedchain.models.data_type import DataType +logger = logging.getLogger(__name__) + def parse_content(content, type): implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"] @@ -61,7 +63,7 @@ def parse_content(content, type): cleaned_size = len(content) if original_size != 0: - logging.info( + logger.info( f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501 ) @@ -208,31 +210,31 @@ def detect_datatype(source: Any) -> DataType: } if url.netloc in YOUTUBE_ALLOWED_NETLOCKS: - logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.") + logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.") return DataType.YOUTUBE_VIDEO if url.netloc in {"notion.so", "notion.site"}: - logging.debug(f"Source of `{formatted_source}` detected as `notion`.") + logger.debug(f"Source of `{formatted_source}` detected as `notion`.") return DataType.NOTION if url.path.endswith(".pdf"): - logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.") + logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.") return DataType.PDF_FILE if url.path.endswith(".xml"): - logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.") + logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.") return DataType.SITEMAP if url.path.endswith(".csv"): - logging.debug(f"Source of `{formatted_source}` detected as `csv`.") + logger.debug(f"Source of `{formatted_source}` detected as `csv`.") return DataType.CSV if url.path.endswith(".mdx") or url.path.endswith(".md"): - logging.debug(f"Source of `{formatted_source}` detected as `mdx`.") + logger.debug(f"Source of `{formatted_source}` detected as `mdx`.") return DataType.MDX if url.path.endswith(".docx"): - logging.debug(f"Source of `{formatted_source}` detected as `docx`.") + logger.debug(f"Source of `{formatted_source}` detected as `docx`.") return DataType.DOCX if url.path.endswith(".yaml"): @@ -242,14 +244,14 @@ def detect_datatype(source: Any) -> DataType: try: yaml_content = yaml.safe_load(response.text) except yaml.YAMLError as exc: - logging.error(f"Error parsing YAML: {exc}") + logger.error(f"Error parsing YAML: {exc}") raise TypeError(f"Not a valid data type. Error loading YAML: {exc}") if is_openapi_yaml(yaml_content): - logging.debug(f"Source of `{formatted_source}` detected as `openapi`.") + logger.debug(f"Source of `{formatted_source}` detected as `openapi`.") return DataType.OPENAPI else: - logging.error( + logger.error( f"Source of `{formatted_source}` does not contain all the required \ fields of OpenAPI yaml. 
Check 'https://spec.openapis.org/oas/v3.1.0'" ) @@ -258,35 +260,35 @@ def detect_datatype(source: Any) -> DataType: make sure you have all the required fields in YAML config data" ) except requests.exceptions.RequestException as e: - logging.error(f"Error fetching URL {formatted_source}: {e}") + logger.error(f"Error fetching URL {formatted_source}: {e}") if url.path.endswith(".json"): - logging.debug(f"Source of `{formatted_source}` detected as `json_file`.") + logger.debug(f"Source of `{formatted_source}` detected as `json_file`.") return DataType.JSON if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"): # `docs_site` detection via path is not accepted for local filesystem URIs, # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive. - logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.") + logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.") return DataType.DOCS_SITE if "github.com" in url.netloc: - logging.debug(f"Source of `{formatted_source}` detected as `github`.") + logger.debug(f"Source of `{formatted_source}` detected as `github`.") return DataType.GITHUB if is_google_drive_folder(url.netloc + url.path): - logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.") + logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.") return DataType.GOOGLE_DRIVE_FOLDER # If none of the above conditions are met, it's a general web page - logging.debug(f"Source of `{formatted_source}` detected as `web_page`.") + logger.debug(f"Source of `{formatted_source}` detected as `web_page`.") return DataType.WEB_PAGE elif not isinstance(source, str): # For datatypes where source is not a string. if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str): - logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.") + logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.") return DataType.QNA_PAIR # Raise an error if it isn't a string and also not a valid non-string type (one of the previous). @@ -300,37 +302,37 @@ def detect_datatype(source: Any) -> DataType: # Note: checking for string is not necessary anymore. 
if source.endswith(".docx"): - logging.debug(f"Source of `{formatted_source}` detected as `docx`.") + logger.debug(f"Source of `{formatted_source}` detected as `docx`.") return DataType.DOCX if source.endswith(".csv"): - logging.debug(f"Source of `{formatted_source}` detected as `csv`.") + logger.debug(f"Source of `{formatted_source}` detected as `csv`.") return DataType.CSV if source.endswith(".xml"): - logging.debug(f"Source of `{formatted_source}` detected as `xml`.") + logger.debug(f"Source of `{formatted_source}` detected as `xml`.") return DataType.XML if source.endswith(".mdx") or source.endswith(".md"): - logging.debug(f"Source of `{formatted_source}` detected as `mdx`.") + logger.debug(f"Source of `{formatted_source}` detected as `mdx`.") return DataType.MDX if source.endswith(".txt"): - logging.debug(f"Source of `{formatted_source}` detected as `text`.") + logger.debug(f"Source of `{formatted_source}` detected as `text`.") return DataType.TEXT_FILE if source.endswith(".pdf"): - logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.") + logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.") return DataType.PDF_FILE if source.endswith(".yaml"): with open(source, "r") as file: yaml_content = yaml.safe_load(file) if is_openapi_yaml(yaml_content): - logging.debug(f"Source of `{formatted_source}` detected as `openapi`.") + logger.debug(f"Source of `{formatted_source}` detected as `openapi`.") return DataType.OPENAPI else: - logging.error( + logger.error( f"Source of `{formatted_source}` does not contain all the required \ fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'" ) @@ -340,11 +342,11 @@ def detect_datatype(source: Any) -> DataType: ) if source.endswith(".json"): - logging.debug(f"Source of `{formatted_source}` detected as `json`.") + logger.debug(f"Source of `{formatted_source}` detected as `json`.") return DataType.JSON if os.path.exists(source) and is_readable(open(source).read()): - logging.debug(f"Source of `{formatted_source}` detected as `text_file`.") + logger.debug(f"Source of `{formatted_source}` detected as `text_file`.") return DataType.TEXT_FILE # If the source is a valid file, that's not detectable as a type, an error is raised. @@ -360,11 +362,11 @@ def detect_datatype(source: Any) -> DataType: # check if the source is valid json string if is_valid_json_string(source): - logging.debug(f"Source of `{formatted_source}` detected as `json`.") + logger.debug(f"Source of `{formatted_source}` detected as `json`.") return DataType.JSON # Use text as final fallback. 
-    logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+    logger.debug(f"Source of `{formatted_source}` detected as `text`.")
     return DataType.TEXT
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index ec166913..f3b87b7f 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -22,6 +22,9 @@ except RuntimeError:
 
 from chromadb.errors import InvalidDimensionException
 
+logger = logging.getLogger(__name__)
+
+
 @register_deserializable
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
@@ -47,7 +50,7 @@ class ChromaDB(BaseVectorDB):
                 setattr(self.settings, key, value)
 
         if self.config.host and self.config.port:
-            logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
+            logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
             self.settings.chroma_server_host = self.config.host
             self.settings.chroma_server_http_port = self.config.port
             self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index b73883ec..ba0b2510 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.utils.misc import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class ElasticsearchDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ElasticsearchDB(BaseVectorDB):
         """
         This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
         """
-        logging.info(self.client.info())
+        logger.info(self.client.info())
         index_settings = {
             "mappings": {
                 "properties": {
diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py
index 18cf0d35..99d73300 100644
--- a/embedchain/vectordb/opensearch.py
+++ b/embedchain/vectordb/opensearch.py
@@ -19,6 +19,8 @@ from embedchain.config import OpenSearchDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class OpenSearchDB(BaseVectorDB):
@@ -43,12 +45,12 @@ class OpenSearchDB(BaseVectorDB):
             **self.config.extra_params,
         )
         info = self.client.info()
-        logging.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
+        logger.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
         # Remove auth credentials from config after successful connection
         super().__init__(config=self.config)
 
     def _initialize(self):
-        logging.info(self.client.info())
+        logger.info(self.client.info())
         index_name = self._get_index()
         if self.client.indices.exists(index=index_name):
             print(f"Index '{index_name}' already exists.")
diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index 1e083f9a..710e0611 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -16,6 +16,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.utils.misc import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class PineconeDB(BaseVectorDB):
@@ -49,7 +51,7 @@ class PineconeDB(BaseVectorDB):
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
         if self.config.hybrid_search:
-            logging.info("Initializing BM25Encoder for sparse vectors..")
+            logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
 
         # Call parent init here because embedder is needed
diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py
index cb932fa4..14663614 100644
--- a/embedchain/vectordb/zilliz.py
+++ b/embedchain/vectordb/zilliz.py
@@ -13,6 +13,8 @@ except ImportError:
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
     ) from None
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class ZillizVectorDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ZillizVectorDB(BaseVectorDB):
         :type name: str
         """
         if utility.has_collection(name):
-            logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
+            logger.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
             self.collection = Collection(name)
         else:
             fields = [
diff --git a/examples/api_server/api_server.py b/examples/api_server/api_server.py
index de447a8e..f8d4d4d1 100644
--- a/examples/api_server/api_server.py
+++ b/examples/api_server/api_server.py
@@ -7,6 +7,9 @@ from embedchain import App
 
 app = Flask(__name__)
 
+logger = logging.getLogger(__name__)
+
+
 @app.route("/add", methods=["POST"])
 def add():
     data = request.get_json()
@@ -17,7 +20,7 @@ def add():
         App().add(url_or_text, data_type=data_type)
         return jsonify({"data": f"Added {data_type}: {url_or_text}"}), 200
     except Exception:
-        logging.exception(f"Failed to add {data_type=}: {url_or_text=}")
+        logger.exception(f"Failed to add {data_type=}: {url_or_text=}")
         return jsonify({"error": f"Failed to add {data_type}: {url_or_text}"}), 500
     return jsonify({"error": "Invalid request. Please provide 'data_type' and 'url_or_text' in JSON format."}), 400
@@ -31,7 +34,7 @@ def query():
         response = App().query(question)
         return jsonify({"data": response}), 200
     except Exception:
-        logging.exception(f"Failed to query {question=}")
+        logger.exception(f"Failed to query {question=}")
         return jsonify({"error": "An error occurred. Please try again!"}), 500
     return jsonify({"error": "Invalid request. Please provide 'question' in JSON format."}), 400
@@ -45,7 +48,7 @@ def chat():
         response = App().chat(question)
         return jsonify({"data": response}), 200
     except Exception:
-        logging.exception(f"Failed to chat {question=}")
+        logger.exception(f"Failed to chat {question=}")
         return jsonify({"error": "An error occurred. Please try again!"}), 500
     return jsonify({"error": "Invalid request. Please provide 'question' in JSON format."}), 400
diff --git a/examples/nextjs/nextjs_discord/app.py b/examples/nextjs/nextjs_discord/app.py
index 07f4a4e9..74b245ab 100644
--- a/examples/nextjs/nextjs_discord/app.py
+++ b/examples/nextjs/nextjs_discord/app.py
@@ -12,10 +12,12 @@ intents.message_content = True
 client = discord.Client(intents=intents)
 discord_bot_name = os.environ["DISCORD_BOT_NAME"]
 
+logger = logging.getLogger(__name__)
+
 
 class NextJSBot:
     def __init__(self) -> None:
-        logging.info("NextJS Bot powered with embedchain.")
+        logger.info("NextJS Bot powered with embedchain.")
 
     def add(self, _):
         raise ValueError("Add is not implemented yet")
@@ -31,11 +33,11 @@ class NextJSBot:
             try:
                 response = response.json()
             except Exception:
-                logging.error(f"Failed to parse response: {response}")
+                logger.error(f"Failed to parse response: {response}")
                 response = {}
             return response
         except Exception:
-            logging.exception(f"Failed to query {message}.")
+            logger.exception(f"Failed to query {message}.")
             response = "An error occurred. Please try again!"
             return response
@@ -49,7 +51,7 @@ NEXTJS_BOT = NextJSBot()
 
 @client.event
 async def on_ready():
-    logging.info(f"User {client.user.name} logged in with id: {client.user.id}!")
+    logger.info(f"User {client.user.name} logged in with id: {client.user.id}!")
 
 
 def _get_question(message):
diff --git a/examples/nextjs/nextjs_slack/app.py b/examples/nextjs/nextjs_slack/app.py
index 58651e5e..005a4e29 100644
--- a/examples/nextjs/nextjs_slack/app.py
+++ b/examples/nextjs/nextjs_slack/app.py
@@ -9,6 +9,8 @@ from slack_bolt.adapter.socket_mode import SocketModeHandler
 
 load_dotenv(".env")
 
+logger = logging.getLogger(__name__)
+
 
 def remove_mentions(message):
     mention_pattern = re.compile(r"<@[^>]+>")
@@ -19,7 +21,7 @@ def remove_mentions(message):
 
 class SlackBotApp:
     def __init__(self) -> None:
-        logging.info("Slack Bot using Embedchain!")
+        logger.info("Slack Bot using Embedchain!")
 
     def add(self, _):
         raise ValueError("Add is not implemented yet")
@@ -35,11 +37,11 @@ class SlackBotApp:
             try:
                 response = response.json()
             except Exception:
-                logging.error(f"Failed to parse response: {response}")
+                logger.error(f"Failed to parse response: {response}")
                 response = {}
             return response
         except Exception:
-            logging.exception(f"Failed to query {query}.")
+            logger.exception(f"Failed to query {query}.")
             response = "An error occurred. Please try again!"
             return response
diff --git a/examples/rest-api/main.py b/examples/rest-api/main.py
index f57474ea..66eef927 100644
--- a/examples/rest-api/main.py
+++ b/examples/rest-api/main.py
@@ -13,6 +13,8 @@ from utils import generate_error_message_for_api_keys
 from embedchain import App
 from embedchain.client import Client
 
+logger = logging.getLogger(__name__)
+
 Base.metadata.create_all(bind=engine)
 
@@ -84,7 +86,7 @@ async def create_app_using_default_config(app_id: str, config: UploadFile = None
         return DefaultResponse(response=f"App created successfully. App ID: {app_id}")
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error creating app: {str(e)}", status_code=400)
@@ -114,13 +116,13 @@ async def get_datasources_associated_with_app_id(app_id: str, db: Session = Depe
         response = app.get_data_sources()
         return {"results": response}
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -153,13 +155,13 @@ async def add_datasource_to_an_app(body: SourceApp, app_id: str, db: Session = D
         response = app.add(source=body.source, data_type=body.data_type)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -191,13 +193,13 @@ async def query_an_app(body: QueryApp, app_id: str, db: Session = Depends(get_db
         response = app.query(body.query)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -274,13 +276,13 @@ async def deploy_app(body: DeployAppRequest, app_id: str, db: Session = Depends(
         app.deploy()
         return DefaultResponse(response="App deployed successfully.")
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
diff --git a/pyproject.toml b/pyproject.toml
index 3f8681b2..ab73fc93 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.94"
+version = "0.1.95"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh ",
diff --git a/tests/llm/test_anthrophic.py b/tests/llm/test_anthrophic.py
index d3a71715..4ad34349 100644
--- a/tests/llm/test_anthrophic.py
+++ b/tests/llm/test_anthrophic.py
@@ -50,22 +50,3 @@ def test_get_messages(anthropic_llm):
         SystemMessage(content="Test System Prompt", additional_kwargs={}),
         HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
     ]
-
-
-def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
-    with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
-        mock_chat_instance = mock_chat.return_value
-        mock_chat_instance.return_value = MagicMock(content="Test Response")
-
-        prompt = "Test Prompt"
-        config = anthropic_llm.config
-        config.max_tokens = 500
-
-        response = anthropic_llm._get_answer(prompt, config)
-
-        assert response == "Test Response"
-        mock_chat.assert_called_once_with(
-            anthropic_api_key="test_api_key", temperature=config.temperature, model=config.model
-        )
-
-        assert "Config option `max_tokens` is not supported by this model." in caplog.text
diff --git a/tests/llm/test_azure_openai.py b/tests/llm/test_azure_openai.py
index 5004f846..ab5a8d93 100644
--- a/tests/llm/test_azure_openai.py
+++ b/tests/llm/test_azure_openai.py
@@ -59,33 +59,6 @@ def test_get_messages(azure_openai_llm):
     ]
 
 
-def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
-    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
-        mock_chat_instance = mock_chat.return_value
-        mock_chat_instance.return_value = MagicMock(content="Test Response")
-
-        prompt = "Test Prompt"
-        config = azure_openai_llm.config
-        config.top_p = 0.5
-
-        response = azure_openai_llm._get_answer(prompt, config)
-
-        assert response == "Test Response"
-        mock_chat.assert_called_once_with(
-            deployment_name=config.deployment_name,
-            openai_api_version="2023-05-15",
-            model_name=config.model or "gpt-3.5-turbo",
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            streaming=config.stream,
-        )
-        mock_chat_instance.assert_called_once_with(
-            azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
-        )
-
-        assert "Config option `top_p` is not supported by this model." in caplog.text
-
-
 def test_when_no_deployment_name_provided():
     config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
     with pytest.raises(ValueError):
diff --git a/tests/loaders/test_discourse.py b/tests/loaders/test_discourse.py
index 6a88eba4..71635b37 100644
--- a/tests/loaders/test_discourse.py
+++ b/tests/loaders/test_discourse.py
@@ -66,21 +66,6 @@ def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeyp
     assert "meta_data" in post_data
 
 
-def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch, caplog):
-    def mock_get(*args, **kwargs):
-        class MockResponse:
-            def raise_for_status(self):
-                raise requests.exceptions.RequestException("Test error")
-
-        return MockResponse()
-
-    monkeypatch.setattr(requests, "get", mock_get)
-
-    discourse_loader._load_post(123)
-
-    assert "Failed to load post" in caplog.text
-
-
 def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
     def mock_get(*args, **kwargs):
         class MockResponse: