feat: Update line length to 120 chars (#278)

Author: Deshraj Yadav
Date: 2023-07-15 07:11:55 -07:00
Committed by: GitHub
Parent: 4f722621fd
Commit: fd97fb268a
9 changed files with 31 additions and 63 deletions

View File

@@ -40,13 +40,8 @@ class InitConfig(BaseConfig):
         :raises ValueError: If the template is not valid as template should contain
         $context and $query
         """
-        if (
-            os.getenv("OPENAI_API_KEY") is None
-            and os.getenv("OPENAI_ORGANIZATION") is None
-        ):
-            raise ValueError(
-                "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"  # noqa:E501
-            )
+        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
         self.ef = embedding_functions.OpenAIEmbeddingFunction(
             api_key=os.getenv("OPENAI_API_KEY"),
             organization_id=os.getenv("OPENAI_ORGANIZATION"),
@@ -74,8 +69,6 @@ class InitConfig(BaseConfig):
             if not isinstance(level, int):
                 raise ValueError(f"Invalid log level: {debug_level}")

-        logging.basicConfig(
-            format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
-        )
+        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
         self.logger = logging.getLogger(__name__)
         return

View File

@@ -113,9 +113,7 @@ class QueryConfig(BaseConfig):
             if self.history is None:
                 raise ValueError("`template` should have `query` and `context` keys")
             else:
-                raise ValueError(
-                    "`template` should have `query`, `context` and `history` keys"
-                )
+                raise ValueError("`template` should have `query`, `context` and `history` keys")

         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
@@ -129,9 +127,7 @@ class QueryConfig(BaseConfig):
         :return: Boolean, valid (true) or invalid (false)
         """
         if self.history is None:
-            return re.search(query_re, template.template) and re.search(
-                context_re, template.template
-            )
+            return re.search(query_re, template.template) and re.search(context_re, template.template)
         else:
             return (
                 re.search(query_re, template.template)

View File

@@ -66,7 +66,7 @@ class DataFormatter:
"text": TextChunker(config), "text": TextChunker(config),
"docx": DocxFileChunker(config), "docx": DocxFileChunker(config),
"sitemap": WebPageChunker(config), "sitemap": WebPageChunker(config),
"code_docs_page": CodeDocsPageChunker(config) "code_docs_page": CodeDocsPageChunker(config),
} }
if data_type in chunkers: if data_type in chunkers:
return chunkers[data_type] return chunkers[data_type]

View File

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory

 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter

 gpt4all_model = None
@@ -54,10 +54,8 @@ class EmbedChain:
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, url, metadata])
-        self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, url, metadata
-        )
-        if data_type in ("code_docs_page", ):
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+        if data_type in ("code_docs_page",):
             self.is_code_docs_instance = True

     def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
@@ -106,12 +104,8 @@ class EmbedChain:
existing_ids = set(existing_docs["ids"]) existing_ids = set(existing_docs["ids"])
if len(existing_ids): if len(existing_ids):
data_dict = { data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas) data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
}
data_dict = {
id: value for id, value in data_dict.items() if id not in existing_ids
}
if not data_dict: if not data_dict:
print(f"All data from {src} already exists in the database.") print(f"All data from {src} already exists in the database.")
@@ -125,15 +119,8 @@ class EmbedChain:
         # Add metadata to each document
         metadatas_with_metadata = [meta or metadata for meta in metadatas]

-        self.collection.add(
-            documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
-        )
-        print(
-            (
-                f"Successfully saved {src}. New chunks count: "
-                f"{self.count() - chunks_before_addition}"
-            )
-        )
+        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+        print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))

     def _format_result(self, results):
         return [
@@ -180,13 +167,9 @@ class EmbedChain:
""" """
context_string = (" | ").join(contexts) context_string = (" | ").join(contexts)
if not config.history: if not config.history:
prompt = config.template.substitute( prompt = config.template.substitute(context=context_string, query=input_query)
context=context_string, query=input_query
)
else: else:
prompt = config.template.substitute( prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
context=context_string, query=input_query, history=config.history
)
return prompt return prompt
def get_answer_from_llm(self, prompt, config: ChatConfig): def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
         :param config: InitConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
         """
-        print(
-            "Loading open source embedding model. This may take some time..."
-        )  # noqa:E501
+        print("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
             config = InitConfig()

         if not config.ef:
             config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name="all-MiniLM-L6-v2"
-                )
+                embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
             )
         if not config.db:

View File

@@ -3,6 +3,7 @@ from bs4 import BeautifulSoup
 from embedchain.utils import clean_string

+
 class CodeDocsPageLoader:
     def load_data(self, url):
         """Load data from a web page."""
@@ -10,14 +11,14 @@ class CodeDocsPageLoader:
         data = response.content
         soup = BeautifulSoup(data, "html.parser")
         selectors = [
-            'article.bd-article',
+            "article.bd-article",
             'article[role="main"]',
-            'div.md-content',
+            "div.md-content",
             'div[role="main"]',
-            'div.container',
-            'div.section',
-            'article',
-            'main',
+            "div.container",
+            "div.section",
+            "article",
+            "main",
         ]
         content = None
         for selector in selectors:
@@ -43,11 +44,11 @@ class CodeDocsPageLoader:
             ]
         ):
             tag.string = " "
-        for div in soup.find_all("div", {'class': 'cell_output'}):
+        for div in soup.find_all("div", {"class": "cell_output"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output_wrapper'}):
+        for div in soup.find_all("div", {"class": "output_wrapper"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output'}):
+        for div in soup.find_all("div", {"class": "output"}):
             div.decompose()
         content = clean_string(soup.get_text())
         output = []

View File

@@ -30,7 +30,7 @@ exclude = [
"node_modules", "node_modules",
"venv", "venv",
] ]
line-length = 88 line-length = 120
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
target-version = "py38" target-version = "py38"
@@ -38,7 +38,7 @@ target-version = "py38"
 max-complexity = 10

 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ["py38", "py39", "py310", "py311"]
 include = '\.pyi?$'
 exclude = '''

View File

@@ -1,6 +1,5 @@
 import setuptools

 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()