[Improvements] Improve logging and fix insertion in data_sources table (#1337)

This commit is contained in:
Deshraj Yadav
2024-04-11 15:00:04 -07:00
committed by GitHub
parent f8619870ad
commit 536f85b78a
4 changed files with 15 additions and 16 deletions

View File

@@ -9,14 +9,9 @@ import requests
 import yaml
 from tqdm import tqdm
-from embedchain.cache import (
-    Config,
-    ExactMatchEvaluation,
-    SearchDistanceEvaluation,
-    cache,
-    gptcache_data_manager,
-    gptcache_pre_function,
-)
+from embedchain.cache import (Config, ExactMatchEvaluation,
+                              SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.core.db.database import get_session, init_db, setup_engine
@@ -25,7 +20,8 @@ from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.evaluation.base import BaseMetric
-from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
+from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
+                                           Groundedness)
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm

View File

@@ -179,6 +179,10 @@ class EmbedChain(JSONSerializable):
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True

+        # Convert the source to a string if it is not already
+        if not isinstance(source, str):
+            source = str(source)
+
         # Insert the data into the 'ec_data_sources' table
         self.db_session.add(
             DataSource(
@@ -310,12 +314,12 @@ class EmbedChain(JSONSerializable):
         new_doc_id = embeddings_data["doc_id"]
         if existing_doc_id and existing_doc_id == new_doc_id:
-            print("Doc content has not changed. Skipping creating chunks and embeddings")
+            logger.info("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0

         # this means that doc content has changed.
         if existing_doc_id and existing_doc_id != new_doc_id:
-            print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
+            logger.info("Doc content has changed. Recomputing chunks and embeddings intelligently.")
             self.db.delete({"doc_id": existing_doc_id})
@@ -341,7 +345,7 @@ class EmbedChain(JSONSerializable):
             src_copy = src
             if len(src_copy) > 50:
                 src_copy = src[:50] + "..."
-            print(f"All data from {src_copy} already exists in the database.")
+            logger.info(f"All data from {src_copy} already exists in the database.")
             # Make sure to return a matching return type
             return [], [], [], 0
@@ -388,12 +392,12 @@ class EmbedChain(JSONSerializable):
                 if batch_docs:
                     self.db.add(documents=batch_docs, metadatas=batch_meta, ids=batch_ids, **kwargs)
             except Exception as e:
-                print(f"Failed to add batch due to a bad request: {e}")
+                logger.info(f"Failed to add batch due to a bad request: {e}")
                 # Handle the error, e.g., by logging, retrying, or skipping
                 pass

         count_new_chunks = self.db.count() - chunks_before_addition
-        print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
+        logger.info(f"Successfully saved {str(src)[:100]} ({chunker.data_type}). New chunks count: {count_new_chunks}")
         return list(documents), metadatas, ids, count_new_chunks

View File

@@ -26,7 +26,6 @@ class AnthropicLlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         chat = ChatAnthropic(
             anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model
         )
(NOTE: the hunk header indicates one line was removed here, but the removed line is not visible in this rendering of the diff.)

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.99"
+version = "0.1.100"
 description = "Simplest open source retrieval (RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",