feat: Update line length to 120 chars (#278)
This commit is contained in:
@@ -46,4 +46,4 @@ class BaseChunker:
|
||||
|
||||
Override in child class if custom logic.
|
||||
"""
|
||||
return self.text_splitter.split_text(content)
|
||||
return self.text_splitter.split_text(content)
|
||||
|
||||
@@ -19,4 +19,4 @@ class CodeDocsPageChunker(BaseChunker):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
super().__init__(text_splitter)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -40,13 +40,8 @@ class InitConfig(BaseConfig):
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query
|
||||
"""
|
||||
if (
|
||||
os.getenv("OPENAI_API_KEY") is None
|
||||
and os.getenv("OPENAI_ORGANIZATION") is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" # noqa:E501
|
||||
)
|
||||
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
|
||||
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
|
||||
self.ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||
@@ -74,8 +69,6 @@ class InitConfig(BaseConfig):
|
||||
if not isinstance(level, int):
|
||||
raise ValueError(f"Invalid log level: {debug_level}")
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
|
||||
)
|
||||
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
return
|
||||
|
||||
@@ -113,9 +113,7 @@ class QueryConfig(BaseConfig):
|
||||
if self.history is None:
|
||||
raise ValueError("`template` should have `query` and `context` keys")
|
||||
else:
|
||||
raise ValueError(
|
||||
"`template` should have `query`, `context` and `history` keys"
|
||||
)
|
||||
raise ValueError("`template` should have `query`, `context` and `history` keys")
|
||||
|
||||
if not isinstance(stream, bool):
|
||||
raise ValueError("`stream` should be bool")
|
||||
@@ -129,9 +127,7 @@ class QueryConfig(BaseConfig):
|
||||
:return: Boolean, valid (true) or invalid (false)
|
||||
"""
|
||||
if self.history is None:
|
||||
return re.search(query_re, template.template) and re.search(
|
||||
context_re, template.template
|
||||
)
|
||||
return re.search(query_re, template.template) and re.search(context_re, template.template)
|
||||
else:
|
||||
return (
|
||||
re.search(query_re, template.template)
|
||||
|
||||
@@ -66,7 +66,7 @@ class DataFormatter:
|
||||
"text": TextChunker(config),
|
||||
"docx": DocxFileChunker(config),
|
||||
"sitemap": WebPageChunker(config),
|
||||
"code_docs_page": CodeDocsPageChunker(config)
|
||||
"code_docs_page": CodeDocsPageChunker(config),
|
||||
}
|
||||
if data_type in chunkers:
|
||||
return chunkers[data_type]
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
|
||||
from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
||||
from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
|
||||
gpt4all_model = None
|
||||
@@ -54,10 +54,8 @@ class EmbedChain:
|
||||
|
||||
data_formatter = DataFormatter(data_type, config)
|
||||
self.user_asks.append([data_type, url, metadata])
|
||||
self.load_and_embed(
|
||||
data_formatter.loader, data_formatter.chunker, url, metadata
|
||||
)
|
||||
if data_type in ("code_docs_page", ):
|
||||
self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
|
||||
if data_type in ("code_docs_page",):
|
||||
self.is_code_docs_instance = True
|
||||
|
||||
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
|
||||
@@ -106,12 +104,8 @@ class EmbedChain:
|
||||
existing_ids = set(existing_docs["ids"])
|
||||
|
||||
if len(existing_ids):
|
||||
data_dict = {
|
||||
id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
|
||||
}
|
||||
data_dict = {
|
||||
id: value for id, value in data_dict.items() if id not in existing_ids
|
||||
}
|
||||
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
|
||||
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
|
||||
|
||||
if not data_dict:
|
||||
print(f"All data from {src} already exists in the database.")
|
||||
@@ -125,15 +119,8 @@ class EmbedChain:
|
||||
# Add metadata to each document
|
||||
metadatas_with_metadata = [meta or metadata for meta in metadatas]
|
||||
|
||||
self.collection.add(
|
||||
documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
|
||||
)
|
||||
print(
|
||||
(
|
||||
f"Successfully saved {src}. New chunks count: "
|
||||
f"{self.count() - chunks_before_addition}"
|
||||
)
|
||||
)
|
||||
self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
|
||||
print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))
|
||||
|
||||
def _format_result(self, results):
|
||||
return [
|
||||
@@ -180,13 +167,9 @@ class EmbedChain:
|
||||
"""
|
||||
context_string = (" | ").join(contexts)
|
||||
if not config.history:
|
||||
prompt = config.template.substitute(
|
||||
context=context_string, query=input_query
|
||||
)
|
||||
prompt = config.template.substitute(context=context_string, query=input_query)
|
||||
else:
|
||||
prompt = config.template.substitute(
|
||||
context=context_string, query=input_query, history=config.history
|
||||
)
|
||||
prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
|
||||
return prompt
|
||||
|
||||
def get_answer_from_llm(self, prompt, config: ChatConfig):
|
||||
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
|
||||
:param config: InitConfig instance to load as configuration. Optional.
|
||||
`ef` defaults to open source.
|
||||
"""
|
||||
print(
|
||||
"Loading open source embedding model. This may take some time..."
|
||||
) # noqa:E501
|
||||
print("Loading open source embedding model. This may take some time...") # noqa:E501
|
||||
if not config:
|
||||
config = InitConfig()
|
||||
|
||||
if not config.ef:
|
||||
config._set_embedding_function(
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name="all-MiniLM-L6-v2"
|
||||
)
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
|
||||
)
|
||||
|
||||
if not config.db:
|
||||
|
||||
@@ -3,6 +3,7 @@ from bs4 import BeautifulSoup
|
||||
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
class CodeDocsPageLoader:
|
||||
def load_data(self, url):
|
||||
"""Load data from a web page."""
|
||||
@@ -10,14 +11,14 @@ class CodeDocsPageLoader:
|
||||
data = response.content
|
||||
soup = BeautifulSoup(data, "html.parser")
|
||||
selectors = [
|
||||
'article.bd-article',
|
||||
"article.bd-article",
|
||||
'article[role="main"]',
|
||||
'div.md-content',
|
||||
"div.md-content",
|
||||
'div[role="main"]',
|
||||
'div.container',
|
||||
'div.section',
|
||||
'article',
|
||||
'main',
|
||||
"div.container",
|
||||
"div.section",
|
||||
"article",
|
||||
"main",
|
||||
]
|
||||
content = None
|
||||
for selector in selectors:
|
||||
@@ -43,11 +44,11 @@ class CodeDocsPageLoader:
|
||||
]
|
||||
):
|
||||
tag.string = " "
|
||||
for div in soup.find_all("div", {'class': 'cell_output'}):
|
||||
for div in soup.find_all("div", {"class": "cell_output"}):
|
||||
div.decompose()
|
||||
for div in soup.find_all("div", {'class': 'output_wrapper'}):
|
||||
for div in soup.find_all("div", {"class": "output_wrapper"}):
|
||||
div.decompose()
|
||||
for div in soup.find_all("div", {'class': 'output'}):
|
||||
for div in soup.find_all("div", {"class": "output"}):
|
||||
div.decompose()
|
||||
content = clean_string(soup.get_text())
|
||||
output = []
|
||||
|
||||
Reference in New Issue
Block a user