[Refactor] Improve logging package wide (#1315)
@@ -91,7 +91,6 @@ keys = console
 keys = generic
 
 [logger_root]
 level = WARN
 handlers = console
 qualname =
 
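For context, the `[logger_root]` block above is the stdlib `logging.config` ini format; this hunk only drops a stray blank line. A minimal sketch of how such a file is applied (the file name below is an assumption, not part of this commit):

    import logging
    import logging.config

    # assumed file name; loads the ini-style sections shown above
    logging.config.fileConfig("logging.ini", disable_existing_loggers=False)
    logging.getLogger(__name__).warning("visible: [logger_root] sets the root level to WARN")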
@@ -32,6 +32,8 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class App(EmbedChain):
@@ -50,10 +52,10 @@ class App(EmbedChain):
 embedding_model: BaseEmbedder = None,
 llm: BaseLlm = None,
 config_data: dict = None,
-log_level=logging.WARN,
 auto_deploy: bool = False,
 chunker: ChunkerConfig = None,
 cache_config: CacheConfig = None,
+log_level: int = logging.WARN,
 ):
 """
 Initialize a new `App` instance.
@@ -68,8 +70,6 @@ class App(EmbedChain):
 :type llm: BaseLlm, optional
 :param config_data: Config dictionary, defaults to None
 :type config_data: dict, optional
-:param log_level: Log level to use, defaults to logging.WARN
-:type log_level: int, optional
 :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
 :type auto_deploy: bool, optional
 :raises Exception: If an error occurs while creating the pipeline
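The pattern applied across this commit is one module-level logger per file, named after the module, instead of calls on the root `logging` module. A minimal sketch of the idea (module and message are illustrative only):

    import logging

    # the logger name follows the import path, e.g. "embedchain.app",
    # so callers can raise or lower verbosity per subsystem
    logger = logging.getLogger(__name__)

    logger.debug("only emitted when the embedchain hierarchy allows DEBUG")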
@@ -83,13 +83,12 @@ class App(EmbedChain):
 if name and config:
 raise Exception("Cannot provide both name and config. Please provide only one of them.")
 
-# logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-self.logger = logging.getLogger(__name__)
 
+logger.debug("4.0")
 # Initialize the metadata db for the app
 setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
 init_db()
 
+logger.debug("4.0")
 self.auto_deploy = auto_deploy
 # Store the dict config as an attribute to be able to send it
 self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -119,6 +118,7 @@ class App(EmbedChain):
 self.llm = llm or OpenAILlm()
 self._init_db()
 
+logger.debug("4.1")
 # Session for the metadata db
 self.db_session = get_session()
 
@@ -126,6 +126,7 @@ class App(EmbedChain):
 if self.cache_config is not None:
 self._init_cache()
 
+logger.debug("4.2")
 # Send anonymous telemetry
 self._telemetry_props = {"class": self.__class__.__name__}
 self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -238,7 +239,7 @@ class App(EmbedChain):
 response.raise_for_status()
 return response.status_code == 200
 except Exception as e:
-self.logger.exception(f"Error occurred during file upload: {str(e)}")
+logger.exception(f"Error occurred during file upload: {str(e)}")
 print("❌ Error occurred during file upload!")
 return False
 
@@ -272,7 +273,7 @@ class App(EmbedChain):
 metadata = {"file_path": data_value, "s3_key": s3_key}
 data_value = presigned_url
 else:
-self.logger.error(f"File upload failed for hash: {data_hash}")
+logger.error(f"File upload failed for hash: {data_hash}")
 return False
 else:
 if data_type == "qna_pair":
@@ -336,6 +337,7 @@ class App(EmbedChain):
 :return: An instance of the App class.
 :rtype: App
 """
+logger.debug("6")
 # Backward compatibility for yaml_path
 if yaml_path and not config_path:
 config_path = yaml_path
@@ -357,15 +359,13 @@ class App(EmbedChain):
 elif config and isinstance(config, dict):
 config_data = config
 else:
-logging.error(
+logger.error(
 "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
 )
 config_data = {}
 
 try:
 validate_config(config_data)
 except Exception as e:
 raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
 # Validate the config
 validate_config(config_data)
 
 app_config_data = config_data.get("app", {}).get("config", {})
 vector_db_config_data = config_data.get("vectordb", {})
@@ -477,12 +477,12 @@ class App(EmbedChain):
 EvalMetric.GROUNDEDNESS.value,
 ]
 
-logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
+logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
 dataset = []
 for q, a, c in zip(queries, answers, contexts):
 dataset.append(EvalData(question=q, answer=a, contexts=c))
 
-logging.info(f"Evaluating {len(dataset)} data points...")
+logger.info(f"Evaluating {len(dataset)} data points...")
 result = {}
 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
 future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
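With the commented-out `basicConfig` call and the per-instance `self.logger` removed, the library no longer configures logging itself; the embedding application decides what gets shown. A hedged sketch of what a caller might do to surface the new `logger.debug` breadcrumbs (the levels and import shown are assumptions, not part of this diff):

    import logging

    from embedchain import App

    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("embedchain").setLevel(logging.DEBUG)

    app = App()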
@@ -17,6 +17,8 @@ except ModuleNotFoundError:
 ) from None
 
 
+logger = logging.getLogger(__name__)
 
 intents = discord.Intents.default()
 intents.message_content = True
 client = discord.Client(intents=intents)
@@ -37,7 +39,7 @@ class DiscordBot(BaseBot):
 self.add(data)
 response = f"Added data from: {data}"
 except Exception:
-logging.exception(f"Failed to add data {data}.")
+logger.exception(f"Failed to add data {data}.")
 response = "Some error occurred while adding data."
 return response
 
@@ -45,7 +47,7 @@ class DiscordBot(BaseBot):
 try:
 response = self.query(message)
 except Exception:
-logging.exception(f"Failed to query {message}.")
+logger.exception(f"Failed to query {message}.")
 response = "An error occurred. Please try again!"
 return response
 
@@ -60,7 +62,7 @@ class DiscordBot(BaseBot):
 async def query_command(interaction: discord.Interaction, question: str):
 await interaction.response.defer()
 member = client.guilds[0].get_member(client.user.id)
-logging.info(f"User: {member}, Query: {question}")
+logger.info(f"User: {member}, Query: {question}")
 try:
 answer = discord_bot.ask_bot(question)
 if args.include_question:
@@ -70,20 +72,20 @@ async def query_command(interaction: discord.Interaction, question: str):
 await interaction.followup.send(response)
 except Exception as e:
 await interaction.followup.send("An error occurred. Please try again!")
-logging.error("Error occurred during 'query' command:", e)
+logger.error("Error occurred during 'query' command:", e)
 
 
 @tree.command(name="add", description="add new content to the embedchain database")
 async def add_command(interaction: discord.Interaction, url_or_text: str):
 await interaction.response.defer()
 member = client.guilds[0].get_member(client.user.id)
-logging.info(f"User: {member}, Add: {url_or_text}")
+logger.info(f"User: {member}, Add: {url_or_text}")
 try:
 response = discord_bot.add_data(url_or_text)
 await interaction.followup.send(response)
 except Exception as e:
 await interaction.followup.send("An error occurred. Please try again!")
-logging.error("Error occurred during 'add' command:", e)
+logger.error("Error occurred during 'add' command:", e)
 
 
 @tree.command(name="ping", description="Simple ping pong command")
@@ -96,7 +98,7 @@ async def on_app_command_error(interaction: discord.Interaction, error: discord.
 if isinstance(error, commands.CommandNotFound):
 await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.")
 else:
-logging.error("Error occurred during command execution:", error)
+logger.error("Error occurred during command execution:", error)
 
 
 @client.event
@@ -104,8 +106,8 @@ async def on_ready():
 # TODO: Sync in admin command, to not hit rate limits.
 # This might be overkill for most users, and it would require to set a guild or user id, where sync is allowed.
 await tree.sync()
-logging.debug("Command tree synced")
-logging.info(f"Logged in as {client.user.name}")
+logger.debug("Command tree synced")
+logger.info(f"Logged in as {client.user.name}")
 
 
 def start_command():
@@ -19,6 +19,8 @@ except ModuleNotFoundError:
 ) from None
 
 
+logger = logging.getLogger(__name__)
 
 SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
 
 
@@ -42,10 +44,10 @@ class SlackBot(BaseBot):
 try:
 response = self.chat_bot.chat(question)
 self.send_slack_message(message["channel"], response)
-logging.info("Query answered successfully!")
+logger.info("Query answered successfully!")
 except Exception as e:
 self.send_slack_message(message["channel"], "An error occurred. Please try again!")
-logging.error("Error occurred during 'query' command:", e)
+logger.error("Error occurred during 'query' command:", e)
 elif text.startswith("add"):
 _, data_type, url_or_text = text.split(" ", 2)
 if url_or_text.startswith("<") and url_or_text.endswith(">"):
@@ -55,10 +57,10 @@ class SlackBot(BaseBot):
 self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}")
 except ValueError as e:
 self.send_slack_message(message["channel"], f"Error: {str(e)}")
-logging.error("Error occurred during 'add' command:", e)
+logger.error("Error occurred during 'add' command:", e)
 except Exception as e:
 self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}")
-logging.error("Error occurred during 'add' command:", e)
+logger.error("Error occurred during 'add' command:", e)
 
 def send_slack_message(self, channel, message):
 response = self.client.chat_postMessage(channel=channel, text=message)
@@ -68,7 +70,7 @@ class SlackBot(BaseBot):
 app = Flask(__name__)
 
 def signal_handler(sig, frame):
-logging.info("\nGracefully shutting down the SlackBot...")
+logger.info("\nGracefully shutting down the SlackBot...")
 sys.exit(0)
 
 signal.signal(signal.SIGINT, signal_handler)
@@ -8,6 +8,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class WhatsAppBot(BaseBot):
@@ -35,7 +37,7 @@ class WhatsAppBot(BaseBot):
 self.add(data)
 response = f"Added data from: {data}"
 except Exception:
-logging.exception(f"Failed to add data {data}.")
+logger.exception(f"Failed to add data {data}.")
 response = "Some error occurred while adding data."
 return response
 
@@ -43,7 +45,7 @@ class WhatsAppBot(BaseBot):
 try:
 response = self.query(message)
 except Exception:
-logging.exception(f"Failed to query {message}.")
+logger.exception(f"Failed to query {message}.")
 response = "An error occurred. Please try again!"
 return response
 
@@ -51,7 +53,7 @@ class WhatsAppBot(BaseBot):
 app = self.flask.Flask(__name__)
 
 def signal_handler(sig, frame):
-logging.info("\nGracefully shutting down the WhatsAppBot...")
+logger.info("\nGracefully shutting down the WhatsAppBot...")
 sys.exit(0)
 
 signal.signal(signal.SIGINT, signal_handler)
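The Discord, Slack and WhatsApp handlers above keep the same shape: catch, log, return a friendly message. A small illustrative sketch of the logging call they rely on (the function and names are hypothetical, not from this diff):

    import logging

    logger = logging.getLogger(__name__)

    def add_data_safely(bot, data: str) -> str:
        try:
            bot.add(data)
            return f"Added data from: {data}"
        except Exception:
            # logger.exception logs at ERROR level and appends the current traceback,
            # which a plain logger.error call does not do by default
            logger.exception("Failed to add data %s.", data)
            return "Some error occurred while adding data."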
@@ -14,6 +14,8 @@ from gptcache.similarity_evaluation.distance import \
 from gptcache.similarity_evaluation.exact_match import \
 ExactMatchEvaluation # noqa: F401
 
+logger = logging.getLogger(__name__)
 
 
 def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
 return data["input_query"]
@@ -24,12 +26,12 @@ def gptcache_data_manager(vector_dimension):
 
 
 def gptcache_data_convert(cache_data):
-logging.info("[Cache] Cache hit, returning cache data...")
+logger.info("[Cache] Cache hit, returning cache data...")
 return cache_data
 
 
 def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
-logging.info("[Cache] Cache missed, updating cache...")
+logger.info("[Cache] Cache missed, updating cache...")
 update_cache_func(Answer(llm_data, CacheDataType.STR))
 return llm_data
 
@@ -6,6 +6,8 @@ from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
+logger = logging.getLogger(__name__)
 
 
 class BaseChunker(JSONSerializable):
 def __init__(self, text_splitter):
@@ -27,7 +29,7 @@ class BaseChunker(JSONSerializable):
 chunk_ids = []
 id_map = {}
 min_chunk_size = config.min_chunk_size if config is not None else 1
-logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
+logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
 data_result = loader.load_data(src)
 data_records = data_result["data"]
 doc_id = data_result["doc_id"]
@@ -7,6 +7,8 @@ import requests
 
 from embedchain.constants import CONFIG_DIR, CONFIG_FILE
 
+logger = logging.getLogger(__name__)
 
 
 class Client:
 def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
@@ -24,7 +26,7 @@ class Client:
 else:
 if "api_key" in self.config_data:
 self.api_key = self.config_data["api_key"]
-logging.info("API key loaded successfully!")
+logger.info("API key loaded successfully!")
 else:
 raise ValueError(
 "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
@@ -64,7 +66,7 @@ class Client:
 with open(CONFIG_FILE, "w") as config_file:
 json.dump(self.config_data, config_file, indent=4)
 
-logging.info("API key saved successfully!")
+logger.info("API key saved successfully!")
 
 def clear(self):
 if "api_key" in self.config_data:
@@ -72,17 +74,17 @@ class Client:
 with open(CONFIG_FILE, "w") as config_file:
 json.dump(self.config_data, config_file, indent=4)
 self.api_key = None
-logging.info("API key deleted successfully!")
+logger.info("API key deleted successfully!")
 else:
-logging.warning("API key not found in the configuration file.")
+logger.warning("API key not found in the configuration file.")
 
 def update(self, api_key):
 if self.check(api_key):
 self.api_key = api_key
 self.save()
-logging.info("API key updated successfully!")
+logger.info("API key updated successfully!")
 else:
-logging.warning("Invalid API key provided. API key not updated.")
+logger.warning("Invalid API key provided. API key not updated.")
 
 def check(self, api_key):
 validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
@@ -90,8 +92,8 @@ class Client:
 if response.status_code == 200:
 return True
 else:
-logging.warning(f"Response from API: {response.text}")
-logging.warning("Invalid API key. Unable to validate.")
+logger.warning(f"Response from API: {response.text}")
+logger.warning("Invalid API key. Unable to validate.")
 return False
 
 def get(self):
@@ -5,6 +5,8 @@ from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
 
 
 class BaseAppConfig(BaseConfig, JSONSerializable):
 """
@@ -42,15 +44,15 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
 
 if db:
 self._db = db
-logging.warning(
+logger.warning(
 "DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
 "Such as `app(config=config, db=db)`."
 )
 
 if collection_name:
-logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
+logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
 return
 
 def _setup_logging(self, log_level):
-logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
-self.logger = logging.getLogger(__name__)
+logger.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
+self.logger = logger.getLogger(__name__)
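For reference, the format string in `_setup_logging` above is the usual argument to the module-level `logging.basicConfig` function (a `Logger` instance itself has no `basicConfig` or `getLogger` methods). A minimal, assumption-labelled sketch of applying it at application start-up:

    import logging

    def setup_logging(log_level: int = logging.WARN) -> None:
        # sketch only: same format string as _setup_logging above
        logging.basicConfig(
            format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s",
            level=log_level,
        )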
@@ -6,6 +6,8 @@ from typing import Any, Optional
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
 
+logger = logging.getLogger(__name__)
 
 DEFAULT_PROMPT = """
 You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:
 
@@ -147,7 +149,7 @@ class BaseLlmConfig(BaseConfig):
 :raises ValueError: Stream is not boolean
 """
 if template is not None:
-logging.warning(
+logger.warning(
 "The `template` argument is deprecated and will be removed in a future version. "
 + "Please use `prompt` instead."
 )
@@ -25,6 +25,8 @@ from embedchain.vectordb.base import BaseVectorDB
 
 load_dotenv()
 
+logger = logging.getLogger(__name__)
 
 
 class EmbedChain(JSONSerializable):
 def __init__(
@@ -143,10 +145,10 @@ class EmbedChain(JSONSerializable):
 
 try:
 DataType(source)
-logging.warning(
+logger.warning(
 f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501
 )
-logging.warning(
+logger.warning(
 "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501
 )
 source, data_type = data_type, source
@@ -157,7 +159,7 @@ class EmbedChain(JSONSerializable):
 try:
 data_type = DataType(data_type)
 except ValueError:
-logging.info(
+logger.info(
 f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`" # noqa: E501
 )
 data_type = DataType.CUSTOM
@@ -190,12 +192,12 @@ class EmbedChain(JSONSerializable):
 try:
 self.db_session.commit()
 except Exception as e:
-logging.error(f"Error adding data source: {e}")
+logger.error(f"Error adding data source: {e}")
 self.db_session.rollback()
 
 if dry_run:
 data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
-logging.debug(f"Dry run info : {data_chunks_info}")
+logger.debug(f"Dry run info : {data_chunks_info}")
 return data_chunks_info
 
 # Send anonymous telemetry
@@ -490,7 +492,7 @@ class EmbedChain(JSONSerializable):
 contexts_data_for_llm_query = contexts
 
 if self.cache_config is not None:
-logging.info("Cache enabled. Checking cache...")
+logger.info("Cache enabled. Checking cache...")
 answer = adapt(
 llm_handler=self.llm.query,
 cache_data_convert=gptcache_data_convert,
@@ -562,7 +564,7 @@ class EmbedChain(JSONSerializable):
 self.llm.update_history(app_id=self.config.id, session_id=session_id)
 
 if self.cache_config is not None:
-logging.info("Cache enabled. Checking cache...")
+logger.debug("Cache enabled. Checking cache...")
 cache_id = f"{session_id}--{self.config.id}"
 answer = adapt(
 llm_handler=self.llm.chat,
@@ -575,6 +577,7 @@ class EmbedChain(JSONSerializable):
 dry_run=dry_run,
 )
 else:
+logger.debug("Cache disabled. Running chat without cache.")
 answer = self.llm.chat(
 input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
 )
@@ -652,7 +655,7 @@ class EmbedChain(JSONSerializable):
 self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
 self.db_session.commit()
 except Exception as e:
-logging.error(f"Error deleting data sources: {e}")
+logger.error(f"Error deleting data sources: {e}")
 self.db_session.rollback()
 return None
 self.db.reset()
@@ -694,11 +697,11 @@ class EmbedChain(JSONSerializable):
 self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
 self.db_session.commit()
 except Exception as e:
-logging.error(f"Error deleting data sources: {e}")
+logger.error(f"Error deleting data sources: {e}")
 self.db_session.rollback()
 return None
 self.db.delete(where={"hash": source_id})
-logging.info(f"Successfully deleted {source_id}")
+logger.info(f"Successfully deleted {source_id}")
 # Send anonymous telemetry
 if self.config.collect_metrics:
 self.telemetry.capture(event_name="delete", properties=self._telemetry_props)
@@ -8,6 +8,8 @@ from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
 
+logger = logging.getLogger(__name__)
 
 
 class NvidiaEmbedder(BaseEmbedder):
 def __init__(self, config: Optional[BaseEmbedderConfig] = None):
@@ -17,7 +19,7 @@ class NvidiaEmbedder(BaseEmbedder):
 super().__init__(config=config)
 
 model = self.config.model or "nvolveqa_40k"
-logging.info(f"Using NVIDIA embedding model: {model}")
+logger.info(f"Using NVIDIA embedding model: {model}")
 embedder = NVIDIAEmbeddings(model=model)
 embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
 self.set_embedding_fn(embedding_fn=embedding_fn)
@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import AnswerRelevanceConfig
 from embedchain.evaluation.base import BaseMetric
 from embedchain.utils.evaluation import EvalData, EvalMetric
 
+logger = logging.getLogger(__name__)
 
 
 class AnswerRelevance(BaseMetric):
 """
@@ -88,6 +90,6 @@ class AnswerRelevance(BaseMetric):
 try:
 results.append(future.result())
 except Exception as e:
-logging.error(f"Error evaluating answer relevancy for {data}: {e}")
+logger.error(f"Error evaluating answer relevancy for {data}: {e}")
 
 return np.mean(results) if results else 0.0
 
@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import GroundednessConfig
 from embedchain.evaluation.base import BaseMetric
 from embedchain.utils.evaluation import EvalData, EvalMetric
 
+logger = logging.getLogger(__name__)
 
 
 class Groundedness(BaseMetric):
 """
@@ -97,6 +99,6 @@ class Groundedness(BaseMetric):
 score = future.result()
 results.append(score)
 except Exception as e:
-logging.error(f"Error while evaluating groundedness for data point {data}: {e}")
+logger.error(f"Error while evaluating groundedness for data point {data}: {e}")
 
 return np.mean(results) if results else 0.0
@@ -8,6 +8,8 @@ T = TypeVar("T", bound="JSONSerializable")
|
||||
# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
|
||||
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_deserializable(cls: Type[T]) -> Type[T]:
|
||||
"""
|
||||
@@ -57,7 +59,7 @@ class JSONSerializable:
|
||||
try:
|
||||
return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Serialization error: {e}")
|
||||
logger.error(f"Serialization error: {e}")
|
||||
return "{}"
|
||||
|
||||
@classmethod
|
||||
@@ -79,7 +81,7 @@ class JSONSerializable:
|
||||
try:
|
||||
return json.loads(json_str, object_hook=cls._auto_decoder)
|
||||
except Exception as e:
|
||||
logging.error(f"Deserialization error: {e}")
|
||||
logger.error(f"Deserialization error: {e}")
|
||||
# Return a default instance in case of failure
|
||||
return cls()
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class AnthropicLlm(BaseLlm):
@@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm):
 )
 
 if config.max_tokens and config.max_tokens != 1000:
-logging.warning("Config option `max_tokens` is not supported by this model.")
+logger.warning("Config option `max_tokens` is not supported by this model.")
 
 messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 
 
@@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm):
 }
 
 if config.stream:
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.streaming_stdout import \
+StreamingStdOutCallbackHandler
 
 callbacks = [StreamingStdOutCallbackHandler()]
 llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
@@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class AzureOpenAILlm(BaseLlm):
@@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm):
 )
 
 if config.top_p and config.top_p != 1:
-logging.warning("Config option `top_p` is not supported by this model.")
+logger.warning("Config option `top_p` is not supported by this model.")
 
 messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
@@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
 
+logger = logging.getLogger(__name__)
 
 
 class BaseLlm(JSONSerializable):
 def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable):
 )
 else:
 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
-logging.warning(
+logger.warning(
 "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
 )
 prompt = self.config.prompt.substitute(context=context_string, query=input_query)
@@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable):
 'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
 ) from None
 search = DuckDuckGoSearchRun()
-logging.info(f"Access search to get answers for {input_query}")
+logger.info(f"Access search to get answers for {input_query}")
 return search.run(input_query)
 
 @staticmethod
@@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable):
 for chunk in answer:
 streamed_answer = streamed_answer + chunk
 yield chunk
-logging.info(f"Answer: {streamed_answer}")
+logger.info(f"Answer: {streamed_answer}")
 
 def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
 """
@@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable):
 if self.online:
 k["web_search_result"] = self.access_search_and_get_results(input_query)
 prompt = self.generate_prompt(input_query, contexts, **k)
-logging.info(f"Prompt: {prompt}")
+logger.info(f"Prompt: {prompt}")
 if dry_run:
 return prompt
 
 answer = self.get_answer_from_llm(prompt)
 if isinstance(answer, str):
-logging.info(f"Answer: {answer}")
+logger.info(f"Answer: {answer}")
 return answer
 else:
 return self._stream_response(answer)
@@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable):
 k["web_search_result"] = self.access_search_and_get_results(input_query)
 
 prompt = self.generate_prompt(input_query, contexts, **k)
-logging.info(f"Prompt: {prompt}")
+logger.info(f"Prompt: {prompt}")
 
 if dry_run:
 return prompt
 
 answer = self.get_answer_from_llm(prompt)
 if isinstance(answer, str):
-logging.info(f"Answer: {answer}")
+logger.info(f"Answer: {answer}")
 return answer
 else:
 # this is a streamed response and needs to be handled differently.
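The prompt and answer logging above interpolates with f-strings; the stdlib also accepts %-style arguments, which defers formatting until a handler actually emits the record. Illustrative only, not part of this diff:

    import logging

    logger = logging.getLogger(__name__)

    prompt = "example prompt"
    logger.info(f"Prompt: {prompt}")   # formatted even if INFO is disabled
    logger.info("Prompt: %s", prompt)  # formatted only when the record is emitted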
@@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class GoogleLlm(BaseLlm):
@@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm):
 
 def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
 model_name = self.config.model or "gemini-pro"
-logging.info(f"Using Google LLM model: {model_name}")
+logger.info(f"Using Google LLM model: {model_name}")
 model = genai.GenerativeModel(model_name=model_name)
 
 generation_config_params = {
 
@@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class HuggingFaceLlm(BaseLlm):
@@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm):
 raise ValueError("`top_p` must be > 0.0 and < 1.0")
 
 model = config.model
-logging.info(f"Using HuggingFaceHub with model {model}")
+logger.info(f"Using HuggingFaceHub with model {model}")
 llm = HuggingFaceHub(
 huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
 repo_id=model,
 
@@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm):
 messages: list[BaseMessage],
 ) -> str:
 from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.function_calling import \
+convert_to_openai_tool
 
 openai_tools = [convert_to_openai_tool(tools)]
 chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
@@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class VertexAILlm(BaseLlm):
@@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm):
 @staticmethod
 def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
 if config.top_p and config.top_p != 1:
-logging.warning("Config option `top_p` is not supported by this model.")
+logger.warning("Config option `top_p` is not supported by this model.")
 
 messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import is_readable
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class BeehiivLoader(BaseLoader):
@@ -90,9 +92,9 @@ class BeehiivLoader(BaseLoader):
 if is_readable(data):
 return data
 else:
-logging.warning(f"Page is not readable (too many invalid characters): {link}")
+logger.warning(f"Page is not readable (too many invalid characters): {link}")
 except ParserRejectedMarkup as e:
-logging.error(f"Failed to parse {link}: {e}")
+logger.error(f"Failed to parse {link}: {e}")
 return None
 
 for link in links:
 
@@ -10,6 +10,8 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.text_file import TextFileLoader
 from embedchain.utils.misc import detect_datatype
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class DirectoryLoader(BaseLoader):
@@ -27,12 +29,12 @@ class DirectoryLoader(BaseLoader):
 if not directory_path.is_dir():
 raise ValueError(f"Invalid path: {path}")
 
-logging.info(f"Loading data from directory: {path}")
+logger.info(f"Loading data from directory: {path}")
 data_list = self._process_directory(directory_path)
 doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
 
 for error in self.errors:
-logging.warning(error)
+logger.warning(error)
 
 return {"doc_id": doc_id, "data": data_list}
 
@@ -46,7 +48,7 @@ class DirectoryLoader(BaseLoader):
 loader = self._predict_loader(file_path)
 data_list.extend(loader.load_data(str(file_path))["data"])
 elif file_path.is_dir():
-logging.info(f"Loading data from directory: {file_path}")
+logger.info(f"Loading data from directory: {file_path}")
 return data_list
 
 def _predict_loader(self, file_path: Path) -> BaseLoader:
@@ -5,6 +5,8 @@ import os
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class DiscordLoader(BaseLoader):
@@ -102,7 +104,7 @@ class DiscordLoader(BaseLoader):
 
 class DiscordClient(discord.Client):
 async def on_ready(self) -> None:
-logging.info("Logged on as {0}!".format(self.user))
+logger.info("Logged on as {0}!".format(self.user))
 try:
 channel = self.get_channel(int(channel_id))
 if not isinstance(channel, discord.TextChannel):
@@ -121,7 +123,7 @@ class DiscordLoader(BaseLoader):
 messages.append(DiscordLoader._format_message(thread_message))
 
 except Exception as e:
-logging.error(e)
+logger.error(e)
 await self.close()
 finally:
 await self.close()
 
@@ -8,6 +8,8 @@ import requests
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
 
 
 class DiscourseLoader(BaseLoader):
 def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -35,7 +37,7 @@ class DiscourseLoader(BaseLoader):
 try:
 response.raise_for_status()
 except Exception as e:
-logging.error(f"Failed to load post {post_id}: {e}")
+logger.error(f"Failed to load post {post_id}: {e}")
 return
 response_data = response.json()
 post_contents = clean_string(response_data.get("raw"))
@@ -56,7 +58,7 @@ class DiscourseLoader(BaseLoader):
 self._check_query(query)
 data = []
 data_contents = []
-logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
+logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
 search_url = f"{self.domain}search.json?q={query}"
 response = requests.get(search_url)
 try:
 
@@ -15,6 +15,8 @@ except ImportError:
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class DocsSiteLoader(BaseLoader):
@@ -28,7 +30,7 @@ class DocsSiteLoader(BaseLoader):
 
 response = requests.get(url)
 if response.status_code != 200:
-logging.info(f"Failed to fetch the website: {response.status_code}")
+logger.info(f"Failed to fetch the website: {response.status_code}")
 return
 
 soup = BeautifulSoup(response.text, "html.parser")
@@ -53,7 +55,7 @@ class DocsSiteLoader(BaseLoader):
 def _load_data_from_url(url: str) -> list:
 response = requests.get(url)
 if response.status_code != 200:
-logging.info(f"Failed to fetch the website: {response.status_code}")
+logger.info(f"Failed to fetch the website: {response.status_code}")
 return []
 
 soup = BeautifulSoup(response.content, "html.parser")
@@ -22,6 +22,8 @@ except ImportError:
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
 
 
 class GmailReader:
 SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
@@ -114,7 +116,7 @@ class GmailLoader(BaseLoader):
 def load_data(self, query: str):
 reader = GmailReader(query=query)
 emails = reader.load_emails()
-logging.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
+logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
 
 data = []
 for email in emails:
 
@@ -5,6 +5,8 @@ from typing import Any, Optional
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
 
 
 class MySQLLoader(BaseLoader):
 def __init__(self, config: Optional[dict[str, Any]]):
@@ -32,7 +34,7 @@ class MySQLLoader(BaseLoader):
 self.connection = sqlconnector.connection.MySQLConnection(**config)
 self.cursor = self.connection.cursor()
 except (sqlconnector.Error, IOError) as err:
-logging.info(f"Connection failed: {err}")
+logger.info(f"Connection failed: {err}")
 raise ValueError(
 f"Unable to connect with the given config: {config}.",
 "Please provide the correct configuration to load data from you MySQL DB. \
 
@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
 
 
 class NotionDocument:
 """
@@ -98,7 +100,7 @@ class NotionLoader(BaseLoader):
 
 id = source[-32:]
 formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
-logging.debug(f"Extracted notion page id as: {formatted_id}")
+logger.debug(f"Extracted notion page id as: {formatted_id}")
 
 integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
 reader = NotionPageLoader(integration_token=integration_token)
 
@@ -4,6 +4,8 @@ from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
 
 
 class PostgresLoader(BaseLoader):
 def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -32,7 +34,7 @@ class PostgresLoader(BaseLoader):
 conn_params.append(f"{key}={value}")
 config_info = " ".join(conn_params)
 
-logging.info(f"Connecting to postrgres sql: {config_info}")
+logger.info(f"Connecting to postrgres sql: {config_info}")
 self.connection = psycopg.connect(conninfo=config_info)
 self.cursor = self.connection.cursor()
@@ -19,6 +19,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class SitemapLoader(BaseLoader):
@@ -41,7 +43,7 @@ class SitemapLoader(BaseLoader):
 response.raise_for_status()
 soup = BeautifulSoup(response.text, "xml")
 except requests.RequestException as e:
-logging.error(f"Error fetching sitemap from URL: {e}")
+logger.error(f"Error fetching sitemap from URL: {e}")
 return
 elif os.path.isfile(sitemap_source):
 with open(sitemap_source, "r") as file:
@@ -60,7 +62,7 @@ class SitemapLoader(BaseLoader):
 loader_data = web_page_loader.load_data(link)
 return loader_data.get("data")
 except ParserRejectedMarkup as e:
-logging.error(f"Failed to parse {link}: {e}")
+logger.error(f"Failed to parse {link}: {e}")
 return None
 
 with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -72,6 +74,6 @@ class SitemapLoader(BaseLoader):
 if data:
 output.extend(data)
 except Exception as e:
-logging.error(f"Error loading page {link}: {e}")
+logger.error(f"Error loading page {link}: {e}")
 
 return {"doc_id": doc_id, "data": output}
 
@@ -11,6 +11,8 @@ from embedchain.utils.misc import clean_string
 
 SLACK_API_BASE_URL = "https://www.slack.com/api/"
 
+logger = logging.getLogger(__name__)
 
 
 class SlackLoader(BaseLoader):
 def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -38,7 +40,7 @@ class SlackLoader(BaseLoader):
 "SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
 )
 
-logging.info(f"Creating Slack Loader with config: {config}")
+logger.info(f"Creating Slack Loader with config: {config}")
 # get slack client config params
 slack_bot_token = os.getenv("SLACK_USER_TOKEN")
 ssl_cert = ssl.create_default_context(cafile=certifi.where())
@@ -54,7 +56,7 @@ class SlackLoader(BaseLoader):
 headers=headers,
 team_id=team_id,
 )
-logging.info("Slack Loader setup successful!")
+logger.info("Slack Loader setup successful!")
 
 @staticmethod
 def _check_query(query):
@@ -69,7 +71,7 @@ class SlackLoader(BaseLoader):
 data = []
 data_content = []
 
-logging.info(f"Searching slack conversations for query: {query}")
+logger.info(f"Searching slack conversations for query: {query}")
 results = self.client.search_messages(
 query=query,
 sort="timestamp",
@@ -79,7 +81,7 @@ class SlackLoader(BaseLoader):
 
 messages = results.get("messages")
 num_message = len(messages)
-logging.info(f"Found {num_message} messages for query: {query}")
+logger.info(f"Found {num_message} messages for query: {query}")
 
 matches = messages.get("matches", [])
 for message in matches:
@@ -107,7 +109,7 @@ class SlackLoader(BaseLoader):
 "data": data,
 }
 except Exception as e:
-logging.warning(f"Error in loading slack data: {e}")
+logger.warning(f"Error in loading slack data: {e}")
 raise ValueError(
 f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
 ) from e
@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import is_readable
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class SubstackLoader(BaseLoader):
@@ -90,9 +92,9 @@ class SubstackLoader(BaseLoader):
 if is_readable(data):
 return data
 else:
-logging.warning(f"Page is not readable (too many invalid characters): {link}")
+logger.warning(f"Page is not readable (too many invalid characters): {link}")
 except ParserRejectedMarkup as e:
-logging.error(f"Failed to parse {link}: {e}")
+logger.error(f"Failed to parse {link}: {e}")
 return None
 
 for link in links:
 
@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
 
 
 @register_deserializable
 class WebPageLoader(BaseLoader):
@@ -87,7 +89,7 @@ class WebPageLoader(BaseLoader):
 
 cleaned_size = len(content)
 if original_size != 0:
-logging.info(
+logger.info(
 f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
 )
@@ -7,6 +7,8 @@ from tqdm import tqdm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 
+logger = logging.getLogger(__name__)
 
 
 class YoutubeChannelLoader(BaseLoader):
 """Loader for youtube channel."""
@@ -36,7 +38,7 @@ class YoutubeChannelLoader(BaseLoader):
 videos = [entry["url"] for entry in info_dict["entries"]]
 return videos
 except Exception:
-logging.error(f"Failed to fetch youtube videos for channel: {channel_name}")
+logger.error(f"Failed to fetch youtube videos for channel: {channel_name}")
 return []
 
 def _load_yt_video(video_link):
@@ -45,12 +47,12 @@ class YoutubeChannelLoader(BaseLoader):
 if each_load_data:
 return each_load_data.get("data")
 except Exception as e:
-logging.error(f"Failed to load youtube video {video_link}: {e}")
+logger.error(f"Failed to load youtube video {video_link}: {e}")
 return None
 
 def _add_youtube_channel():
 video_links = _get_yt_video_links()
-logging.info("Loading videos from youtube channel...")
+logger.info("Loading videos from youtube channel...")
 with concurrent.futures.ThreadPoolExecutor() as executor:
 # Submitting all tasks and storing the future object with the video link
 future_to_video = {
@@ -67,7 +69,7 @@ class YoutubeChannelLoader(BaseLoader):
 data.extend(results)
 data_urls.extend([result.get("meta_data").get("url") for result in results])
 except Exception as e:
-logging.error(f"Failed to process youtube video {video}: {e}")
+logger.error(f"Failed to process youtube video {video}: {e}")
 
 _add_youtube_channel()
 doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()
@@ -8,6 +8,8 @@ from embedchain.core.db.models import ChatHistory as ChatHistoryModel
 from embedchain.memory.message import ChatMessage
 from embedchain.memory.utils import merge_metadata_dict
 
+logger = logging.getLogger(__name__)
 
 
 class ChatHistory:
 def __init__(self) -> None:
@@ -31,11 +33,11 @@ class ChatHistory:
 try:
 self.db_session.commit()
 except Exception as e:
-logging.error(f"Error adding chat memory to db: {e}")
+logger.error(f"Error adding chat memory to db: {e}")
 self.db_session.rollback()
 return None
 
-logging.info(f"Added chat memory to db with id: {memory_id}")
+logger.info(f"Added chat memory to db with id: {memory_id}")
 return memory_id
 
 def delete(self, app_id: str, session_id: Optional[str] = None):
@@ -55,7 +57,7 @@ class ChatHistory:
 try:
 self.db_session.commit()
 except Exception as e:
-logging.error(f"Error deleting chat history: {e}")
+logger.error(f"Error deleting chat history: {e}")
 self.db_session.rollback()
 
 def get(
 
@@ -3,6 +3,8 @@ from typing import Any, Optional
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
+logger = logging.getLogger(__name__)
 
 
 class BaseMessage(JSONSerializable):
 """
@@ -52,7 +54,7 @@ class ChatMessage(JSONSerializable):
 
 def add_user_message(self, message: str, metadata: Optional[dict] = None):
 if self.human_message:
-logging.info(
+logger.info(
 "Human message already exists in the chat message,\
 overwriting it with new message."
 )
@@ -61,7 +63,7 @@ class ChatMessage(JSONSerializable):
 
 def add_ai_message(self, message: str, metadata: Optional[dict] = None):
 if self.ai_message:
-logging.info(
+logger.info(
 "AI message already exists in the chat message,\
 overwriting it with new message."
 )
@@ -157,7 +157,6 @@ class AIAssistant:
 log_level=logging.INFO,
 collect_metrics=True,
 ):
 
 self.name = name or "AI Assistant"
 self.data_sources = data_sources or []
 self.log_level = log_level
@@ -11,6 +11,8 @@ from tqdm import tqdm
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_content(content, type):
|
||||
implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
|
||||
@@ -61,7 +63,7 @@ def parse_content(content, type):
|
||||
|
||||
cleaned_size = len(content)
|
||||
if original_size != 0:
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
|
||||
)
|
||||
|
||||
@@ -208,31 +210,31 @@ def detect_datatype(source: Any) -> DataType:
|
||||
}
|
||||
|
||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||
return DataType.YOUTUBE_VIDEO
|
||||
|
||||
if url.netloc in {"notion.so", "notion.site"}:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `notion`.")
|
||||
return DataType.NOTION
|
||||
|
||||
if url.path.endswith(".pdf"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
|
||||
return DataType.PDF_FILE
|
||||
|
||||
if url.path.endswith(".xml"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
|
||||
return DataType.SITEMAP
|
||||
|
||||
if url.path.endswith(".csv"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
|
||||
return DataType.CSV

if url.path.endswith(".mdx") or url.path.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX

if url.path.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX

if url.path.endswith(".yaml"):
@@ -242,14 +244,14 @@ def detect_datatype(source: Any) -> DataType:
try:
yaml_content = yaml.safe_load(response.text)
except yaml.YAMLError as exc:
logging.error(f"Error parsing YAML: {exc}")
logger.error(f"Error parsing YAML: {exc}")
raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")

if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
logger.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
@@ -258,35 +260,35 @@ def detect_datatype(source: Any) -> DataType:
make sure you have all the required fields in YAML config data"
)
except requests.exceptions.RequestException as e:
logging.error(f"Error fetching URL {formatted_source}: {e}")
logger.error(f"Error fetching URL {formatted_source}: {e}")

if url.path.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `json_file`.")
return DataType.JSON

if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
# `docs_site` detection via path is not accepted for local filesystem URIs,
# because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
return DataType.DOCS_SITE

if "github.com" in url.netloc:
logging.debug(f"Source of `{formatted_source}` detected as `github`.")
logger.debug(f"Source of `{formatted_source}` detected as `github`.")
return DataType.GITHUB

if is_google_drive_folder(url.netloc + url.path):
logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
return DataType.GOOGLE_DRIVE_FOLDER

# If none of the above conditions are met, it's a general web page
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
logger.debug(f"Source of `{formatted_source}` detected as `web_page`.")
return DataType.WEB_PAGE

elif not isinstance(source, str):
# For datatypes where source is not a string.

if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
return DataType.QNA_PAIR

# Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
@@ -300,37 +302,37 @@ def detect_datatype(source: Any) -> DataType:
# Note: checking for string is not necessary anymore.

if source.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX

if source.endswith(".csv"):
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV

if source.endswith(".xml"):
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
logger.debug(f"Source of `{formatted_source}` detected as `xml`.")
return DataType.XML

if source.endswith(".mdx") or source.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX

if source.endswith(".txt"):
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
logger.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT_FILE

if source.endswith(".pdf"):
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
return DataType.PDF_FILE

if source.endswith(".yaml"):
with open(source, "r") as file:
yaml_content = yaml.safe_load(file)
if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
logger.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
@@ -340,11 +342,11 @@ def detect_datatype(source: Any) -> DataType:
)

if source.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
logger.debug(f"Source of `{formatted_source}` detected as `json`.")
return DataType.JSON

if os.path.exists(source) and is_readable(open(source).read()):
logging.debug(f"Source of `{formatted_source}` detected as `text_file`.")
logger.debug(f"Source of `{formatted_source}` detected as `text_file`.")
return DataType.TEXT_FILE

# If the source is a valid file, that's not detectable as a type, an error is raised.
@@ -360,11 +362,11 @@ def detect_datatype(source: Any) -> DataType:

# check if the source is valid json string
if is_valid_json_string(source):
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
logger.debug(f"Source of `{formatted_source}` detected as `json`.")
return DataType.JSON

# Use text as final fallback.
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
logger.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT

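The hunks above replace calls on the root `logging` module with calls on a module-level `logger`. A minimal sketch of that pattern, assuming an illustrative module and message (not taken verbatim from the diff):

import logging

# One logger per module, named after the module itself (e.g. "embedchain.utils.misc").
logger = logging.getLogger(__name__)


def detect_example(source: str) -> str:
    # Before the refactor this would be logging.debug(...), which always goes
    # through the root logger; a module-level logger lets applications filter
    # or format output per module.
    logger.debug("Source of `%s` detected as `text`.", source)
    return "text"
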
@@ -22,6 +22,9 @@ except RuntimeError:
from chromadb.errors import InvalidDimensionException


logger = logging.getLogger(__name__)


@register_deserializable
class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB."""
@@ -47,7 +50,7 @@ class ChromaDB(BaseVectorDB):
setattr(self.settings, key, value)

if self.config.host and self.config.port:
logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
self.settings.chroma_server_host = self.config.host
self.settings.chroma_server_http_port = self.config.port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB

logger = logging.getLogger(__name__)


@register_deserializable
class ElasticsearchDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ElasticsearchDB(BaseVectorDB):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
logging.info(self.client.info())
logger.info(self.client.info())
index_settings = {
"mappings": {
"properties": {

@@ -19,6 +19,8 @@ from embedchain.config import OpenSearchDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

logger = logging.getLogger(__name__)


@register_deserializable
class OpenSearchDB(BaseVectorDB):
@@ -43,12 +45,12 @@ class OpenSearchDB(BaseVectorDB):
**self.config.extra_params,
)
info = self.client.info()
logging.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
logger.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
# Remove auth credentials from config after successful connection
super().__init__(config=self.config)

def _initialize(self):
logging.info(self.client.info())
logger.info(self.client.info())
index_name = self._get_index()
if self.client.indices.exists(index=index_name):
print(f"Index '{index_name}' already exists.")

@@ -16,6 +16,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB

logger = logging.getLogger(__name__)


@register_deserializable
class PineconeDB(BaseVectorDB):
@@ -49,7 +51,7 @@ class PineconeDB(BaseVectorDB):
# Setup BM25Encoder if sparse vectors are to be used
self.bm25_encoder = None
if self.config.hybrid_search:
logging.info("Initializing BM25Encoder for sparse vectors..")
logger.info("Initializing BM25Encoder for sparse vectors..")
self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()

# Call parent init here because embedder is needed

@@ -13,6 +13,8 @@ except ImportError:
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
) from None

logger = logging.getLogger(__name__)


@register_deserializable
class ZillizVectorDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ZillizVectorDB(BaseVectorDB):
:type name: str
"""
if utility.has_collection(name):
logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
logger.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
self.collection = Collection(name)
else:
fields = [
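Because each module now obtains its logger via `logging.getLogger(__name__)`, what actually gets emitted is decided by the host application's logging configuration rather than by the library. A minimal usage sketch, assuming the `embedchain` package layout visible in the imports above (the level choices are illustrative, not part of this change):

import logging

# Show everything, including the logger.debug() calls added throughout the diff.
logging.basicConfig(level=logging.DEBUG)

# Or quiet just the vector database modules while keeping the rest at DEBUG.
logging.getLogger("embedchain.vectordb").setLevel(logging.WARNING)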