feat: Update line length to 120 chars (#278)

This commit is contained in:
Deshraj Yadav
2023-07-15 07:11:55 -07:00
committed by GitHub
parent 4f722621fd
commit fd97fb268a
9 changed files with 31 additions and 63 deletions

View File

@@ -46,4 +46,4 @@ class BaseChunker:
Override in child class if custom logic.
"""
return self.text_splitter.split_text(content)
return self.text_splitter.split_text(content)

View File

@@ -19,4 +19,4 @@ class CodeDocsPageChunker(BaseChunker):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)
super().__init__(text_splitter)

View File

@@ -40,13 +40,8 @@ class InitConfig(BaseConfig):
:raises ValueError: If the template is not valid as template should contain
$context and $query
"""
if (
os.getenv("OPENAI_API_KEY") is None
and os.getenv("OPENAI_ORGANIZATION") is None
):
raise ValueError(
"OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" # noqa:E501
)
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
self.ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
organization_id=os.getenv("OPENAI_ORGANIZATION"),
@@ -74,8 +69,6 @@ class InitConfig(BaseConfig):
if not isinstance(level, int):
raise ValueError(f"Invalid log level: {debug_level}")
logging.basicConfig(
format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
)
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
self.logger = logging.getLogger(__name__)
return

View File

@@ -113,9 +113,7 @@ class QueryConfig(BaseConfig):
if self.history is None:
raise ValueError("`template` should have `query` and `context` keys")
else:
raise ValueError(
"`template` should have `query`, `context` and `history` keys"
)
raise ValueError("`template` should have `query`, `context` and `history` keys")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
@@ -129,9 +127,7 @@ class QueryConfig(BaseConfig):
:return: Boolean, valid (true) or invalid (false)
"""
if self.history is None:
return re.search(query_re, template.template) and re.search(
context_re, template.template
)
return re.search(query_re, template.template) and re.search(context_re, template.template)
else:
return (
re.search(query_re, template.template)

View File

@@ -66,7 +66,7 @@ class DataFormatter:
"text": TextChunker(config),
"docx": DocxFileChunker(config),
"sitemap": WebPageChunker(config),
"code_docs_page": CodeDocsPageChunker(config)
"code_docs_page": CodeDocsPageChunker(config),
}
if data_type in chunkers:
return chunkers[data_type]

View File

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
from embedchain.data_formatter import DataFormatter
gpt4all_model = None
@@ -54,10 +54,8 @@ class EmbedChain:
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, url, metadata])
self.load_and_embed(
data_formatter.loader, data_formatter.chunker, url, metadata
)
if data_type in ("code_docs_page", ):
self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
if data_type in ("code_docs_page",):
self.is_code_docs_instance = True
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
@@ -106,12 +104,8 @@ class EmbedChain:
existing_ids = set(existing_docs["ids"])
if len(existing_ids):
data_dict = {
id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
}
data_dict = {
id: value for id, value in data_dict.items() if id not in existing_ids
}
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
if not data_dict:
print(f"All data from {src} already exists in the database.")
@@ -125,15 +119,8 @@ class EmbedChain:
# Add metadata to each document
metadatas_with_metadata = [meta or metadata for meta in metadatas]
self.collection.add(
documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
)
print(
(
f"Successfully saved {src}. New chunks count: "
f"{self.count() - chunks_before_addition}"
)
)
self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))
def _format_result(self, results):
return [
@@ -180,13 +167,9 @@ class EmbedChain:
"""
context_string = (" | ").join(contexts)
if not config.history:
prompt = config.template.substitute(
context=context_string, query=input_query
)
prompt = config.template.substitute(context=context_string, query=input_query)
else:
prompt = config.template.substitute(
context=context_string, query=input_query, history=config.history
)
prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
return prompt
def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
:param config: InitConfig instance to load as configuration. Optional.
`ef` defaults to open source.
"""
print(
"Loading open source embedding model. This may take some time..."
) # noqa:E501
print("Loading open source embedding model. This may take some time...") # noqa:E501
if not config:
config = InitConfig()
if not config.ef:
config._set_embedding_function(
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="all-MiniLM-L6-v2"
)
embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
)
if not config.db:

View File

@@ -3,6 +3,7 @@ from bs4 import BeautifulSoup
from embedchain.utils import clean_string
class CodeDocsPageLoader:
def load_data(self, url):
"""Load data from a web page."""
@@ -10,14 +11,14 @@ class CodeDocsPageLoader:
data = response.content
soup = BeautifulSoup(data, "html.parser")
selectors = [
'article.bd-article',
"article.bd-article",
'article[role="main"]',
'div.md-content',
"div.md-content",
'div[role="main"]',
'div.container',
'div.section',
'article',
'main',
"div.container",
"div.section",
"article",
"main",
]
content = None
for selector in selectors:
@@ -43,11 +44,11 @@ class CodeDocsPageLoader:
]
):
tag.string = " "
for div in soup.find_all("div", {'class': 'cell_output'}):
for div in soup.find_all("div", {"class": "cell_output"}):
div.decompose()
for div in soup.find_all("div", {'class': 'output_wrapper'}):
for div in soup.find_all("div", {"class": "output_wrapper"}):
div.decompose()
for div in soup.find_all("div", {'class': 'output'}):
for div in soup.find_all("div", {"class": "output"}):
div.decompose()
content = clean_string(soup.get_text())
output = []