From fd97fb268ac047557dedfec79b507dc072bd4cf3 Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Sat, 15 Jul 2023 07:11:55 -0700 Subject: [PATCH] feat: Update line length to 120 chars (#278) --- embedchain/chunkers/base_chunker.py | 2 +- embedchain/chunkers/code_docs_page.py | 2 +- embedchain/config/InitConfig.py | 13 ++----- embedchain/config/QueryConfig.py | 8 +--- embedchain/data_formatter/data_formatter.py | 2 +- embedchain/embedchain.py | 43 ++++++--------------- embedchain/loaders/code_docs_page.py | 19 ++++----- pyproject.toml | 4 +- setup.py | 1 - 9 files changed, 31 insertions(+), 63 deletions(-) diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index f14b2d2c..e2334488 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -46,4 +46,4 @@ class BaseChunker: Override in child class if custom logic. """ - return self.text_splitter.split_text(content) \ No newline at end of file + return self.text_splitter.split_text(content) diff --git a/embedchain/chunkers/code_docs_page.py b/embedchain/chunkers/code_docs_page.py index ed50fb65..a3470cfa 100644 --- a/embedchain/chunkers/code_docs_page.py +++ b/embedchain/chunkers/code_docs_page.py @@ -19,4 +19,4 @@ class CodeDocsPageChunker(BaseChunker): if config is None: config = TEXT_SPLITTER_CHUNK_PARAMS text_splitter = RecursiveCharacterTextSplitter(**config) - super().__init__(text_splitter) \ No newline at end of file + super().__init__(text_splitter) diff --git a/embedchain/config/InitConfig.py b/embedchain/config/InitConfig.py index 35b8d9bb..47fe7e64 100644 --- a/embedchain/config/InitConfig.py +++ b/embedchain/config/InitConfig.py @@ -40,13 +40,8 @@ class InitConfig(BaseConfig): :raises ValueError: If the template is not valid as template should contain $context and $query """ - if ( - os.getenv("OPENAI_API_KEY") is None - and os.getenv("OPENAI_ORGANIZATION") is None - ): - raise ValueError( - "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" # noqa:E501 - ) + if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None: + raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501 self.ef = embedding_functions.OpenAIEmbeddingFunction( api_key=os.getenv("OPENAI_API_KEY"), organization_id=os.getenv("OPENAI_ORGANIZATION"), @@ -74,8 +69,6 @@ class InitConfig(BaseConfig): if not isinstance(level, int): raise ValueError(f"Invalid log level: {debug_level}") - logging.basicConfig( - format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level - ) + logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level) self.logger = logging.getLogger(__name__) return diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/QueryConfig.py index 88637338..39285a21 100644 --- a/embedchain/config/QueryConfig.py +++ b/embedchain/config/QueryConfig.py @@ -113,9 +113,7 @@ class QueryConfig(BaseConfig): if self.history is None: raise ValueError("`template` should have `query` and `context` keys") else: - raise ValueError( - "`template` should have `query`, `context` and `history` keys" - ) + raise ValueError("`template` should have `query`, `context` and `history` keys") if not isinstance(stream, bool): raise ValueError("`stream` should be bool") @@ -129,9 +127,7 @@ class QueryConfig(BaseConfig): :return: Boolean, valid (true) or invalid (false) """ if self.history is None: - return re.search(query_re, template.template) and re.search( - context_re, template.template - ) + return re.search(query_re, template.template) and re.search(context_re, template.template) else: return ( re.search(query_re, template.template) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 232cbd86..10c8c68b 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -66,7 +66,7 @@ class DataFormatter: "text": TextChunker(config), "docx": DocxFileChunker(config), "sitemap": WebPageChunker(config), - "code_docs_page": CodeDocsPageChunker(config) + "code_docs_page": CodeDocsPageChunker(config), } if data_type in chunkers: return chunkers[data_type] diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 91913758..0327675a 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -9,7 +9,7 @@ from langchain.docstore.document import Document from langchain.memory import ConversationBufferMemory from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig -from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE +from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT from embedchain.data_formatter import DataFormatter gpt4all_model = None @@ -54,10 +54,8 @@ class EmbedChain: data_formatter = DataFormatter(data_type, config) self.user_asks.append([data_type, url, metadata]) - self.load_and_embed( - data_formatter.loader, data_formatter.chunker, url, metadata - ) - if data_type in ("code_docs_page", ): + self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata) + if data_type in ("code_docs_page",): self.is_code_docs_instance = True def add_local(self, data_type, content, metadata=None, config: AddConfig = None): @@ -106,12 +104,8 @@ class EmbedChain: existing_ids = set(existing_docs["ids"]) if len(existing_ids): - data_dict = { - id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas) - } - data_dict = { - id: value for id, value in data_dict.items() if id not in existing_ids - } + data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)} + data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids} if not data_dict: print(f"All data from {src} already exists in the database.") @@ -125,15 +119,8 @@ class EmbedChain: # Add metadata to each document metadatas_with_metadata = [meta or metadata for meta in metadatas] - self.collection.add( - documents=documents, metadatas=list(metadatas_with_metadata), ids=ids - ) - print( - ( - f"Successfully saved {src}. New chunks count: " - f"{self.count() - chunks_before_addition}" - ) - ) + self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) + print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}")) def _format_result(self, results): return [ @@ -180,13 +167,9 @@ class EmbedChain: """ context_string = (" | ").join(contexts) if not config.history: - prompt = config.template.substitute( - context=context_string, query=input_query - ) + prompt = config.template.substitute(context=context_string, query=input_query) else: - prompt = config.template.substitute( - context=context_string, query=input_query, history=config.history - ) + prompt = config.template.substitute(context=context_string, query=input_query, history=config.history) return prompt def get_answer_from_llm(self, prompt, config: ChatConfig): @@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain): :param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source. """ - print( - "Loading open source embedding model. This may take some time..." - ) # noqa:E501 + print("Loading open source embedding model. This may take some time...") # noqa:E501 if not config: config = InitConfig() if not config.ef: config._set_embedding_function( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name="all-MiniLM-L6-v2" - ) + embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2") ) if not config.db: diff --git a/embedchain/loaders/code_docs_page.py b/embedchain/loaders/code_docs_page.py index 055bbbaf..5d4e1720 100644 --- a/embedchain/loaders/code_docs_page.py +++ b/embedchain/loaders/code_docs_page.py @@ -3,6 +3,7 @@ from bs4 import BeautifulSoup from embedchain.utils import clean_string + class CodeDocsPageLoader: def load_data(self, url): """Load data from a web page.""" @@ -10,14 +11,14 @@ class CodeDocsPageLoader: data = response.content soup = BeautifulSoup(data, "html.parser") selectors = [ - 'article.bd-article', + "article.bd-article", 'article[role="main"]', - 'div.md-content', + "div.md-content", 'div[role="main"]', - 'div.container', - 'div.section', - 'article', - 'main', + "div.container", + "div.section", + "article", + "main", ] content = None for selector in selectors: @@ -43,11 +44,11 @@ class CodeDocsPageLoader: ] ): tag.string = " " - for div in soup.find_all("div", {'class': 'cell_output'}): + for div in soup.find_all("div", {"class": "cell_output"}): div.decompose() - for div in soup.find_all("div", {'class': 'output_wrapper'}): + for div in soup.find_all("div", {"class": "output_wrapper"}): div.decompose() - for div in soup.find_all("div", {'class': 'output'}): + for div in soup.find_all("div", {"class": "output"}): div.decompose() content = clean_string(soup.get_text()) output = [] diff --git a/pyproject.toml b/pyproject.toml index 79842bbe..9f1da4f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ exclude = [ "node_modules", "venv", ] -line-length = 88 +line-length = 120 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py38" @@ -38,7 +38,7 @@ target-version = "py38" max-complexity = 10 [tool.black] -line-length = 88 +line-length = 120 target-version = ["py38", "py39", "py310", "py311"] include = '\.pyi?$' exclude = ''' diff --git a/setup.py b/setup.py index ebd8bd3b..961a02ac 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ import setuptools - with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read()