feat: Update line length to 120 chars (#278)
@@ -46,4 +46,4 @@ class BaseChunker:

         Override in child class if custom logic.
         """
         return self.text_splitter.split_text(content)

@@ -19,4 +19,4 @@ class CodeDocsPageChunker(BaseChunker):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS
         text_splitter = RecursiveCharacterTextSplitter(**config)
         super().__init__(text_splitter)

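For reference, a rough runnable sketch of the chunker pattern touched by the two hunks above. It assumes langchain is installed; the chunk-size values are illustrative placeholders, not the TEXT_SPLITTER_CHUNK_PARAMS defined in the repository.

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Placeholder values; the real defaults live in the repository's TEXT_SPLITTER_CHUNK_PARAMS.
TEXT_SPLITTER_CHUNK_PARAMS = {"chunk_size": 500, "chunk_overlap": 0, "length_function": len}


class BaseChunker:
    def __init__(self, text_splitter):
        self.text_splitter = text_splitter

    def get_chunks(self, content):
        """Return chunks using the text splitter instance.

        Override in child class if custom logic.
        """
        return self.text_splitter.split_text(content)


class CodeDocsPageChunker(BaseChunker):
    def __init__(self, config=None):
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        text_splitter = RecursiveCharacterTextSplitter(**config)
        super().__init__(text_splitter)


print(CodeDocsPageChunker().get_chunks("some long documentation text " * 50)[:2])
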
@@ -40,13 +40,8 @@ class InitConfig(BaseConfig):
         :raises ValueError: If the template is not valid as template should contain
         $context and $query
         """
-        if (
-            os.getenv("OPENAI_API_KEY") is None
-            and os.getenv("OPENAI_ORGANIZATION") is None
-        ):
-            raise ValueError(
-                "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"  # noqa:E501
-            )
+        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
         self.ef = embedding_functions.OpenAIEmbeddingFunction(
             api_key=os.getenv("OPENAI_API_KEY"),
             organization_id=os.getenv("OPENAI_ORGANIZATION"),

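A standalone sketch of the check in the hunk above, assuming chromadb is installed; the helper name _openai_embedding_function is made up for illustration.

import os

from chromadb.utils import embedding_functions


def _openai_embedding_function():
    # Mirrors the validation above: at least one of the two environment
    # variables must be set before the embedding function is built.
    if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
        raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")
    return embedding_functions.OpenAIEmbeddingFunction(
        api_key=os.getenv("OPENAI_API_KEY"),
        organization_id=os.getenv("OPENAI_ORGANIZATION"),
    )
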
@@ -74,8 +69,6 @@ class InitConfig(BaseConfig):
         if not isinstance(level, int):
             raise ValueError(f"Invalid log level: {debug_level}")

-        logging.basicConfig(
-            format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
-        )
+        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
         self.logger = logging.getLogger(__name__)
         return

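In isolation, the single-line logging setup above behaves like this minimal sketch; the level value here is a placeholder for the one derived from debug_level.

import logging

level = logging.INFO  # placeholder; in the diff, `level` comes from the debug_level argument
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
logging.getLogger(__name__).info("logging configured")
# prints something like: 2023-07-01 12:00:00,000 [__main__] [INFO] logging configured
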
@@ -113,9 +113,7 @@ class QueryConfig(BaseConfig):
         if self.history is None:
             raise ValueError("`template` should have `query` and `context` keys")
         else:
-            raise ValueError(
-                "`template` should have `query`, `context` and `history` keys"
-            )
+            raise ValueError("`template` should have `query`, `context` and `history` keys")

         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")

@@ -129,9 +127,7 @@ class QueryConfig(BaseConfig):
         :return: Boolean, valid (true) or invalid (false)
         """
         if self.history is None:
-            return re.search(query_re, template.template) and re.search(
-                context_re, template.template
-            )
+            return re.search(query_re, template.template) and re.search(context_re, template.template)
         else:
             return (
                 re.search(query_re, template.template)

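A self-contained sketch of the template validation in the two QueryConfig hunks above. The regex patterns are assumptions standing in for query_re, context_re and history_re, which are defined elsewhere in QueryConfig; only the standard library is needed.

import re
from string import Template

query_re = re.compile(r"\$\{*query\}*")      # assumption: matches $query or ${query}
context_re = re.compile(r"\$\{*context\}*")  # assumption: matches $context or ${context}
history_re = re.compile(r"\$\{*history\}*")  # assumption: matches $history or ${history}


def validate_template(template: Template, history=None) -> bool:
    if history is None:
        return bool(re.search(query_re, template.template) and re.search(context_re, template.template))
    return bool(
        re.search(query_re, template.template)
        and re.search(context_re, template.template)
        and re.search(history_re, template.template)
    )


print(validate_template(Template("Answer $query using $context")))                # True
print(validate_template(Template("Answer $query using $context"), history=[]))   # False: no $history
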
@@ -66,7 +66,7 @@ class DataFormatter:
             "text": TextChunker(config),
             "docx": DocxFileChunker(config),
             "sitemap": WebPageChunker(config),
-            "code_docs_page": CodeDocsPageChunker(config)
+            "code_docs_page": CodeDocsPageChunker(config),
         }
         if data_type in chunkers:
             return chunkers[data_type]

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory

 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter

 gpt4all_model = None

@@ -54,10 +54,8 @@ class EmbedChain:

         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, url, metadata])
-        self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, url, metadata
-        )
-        if data_type in ("code_docs_page", ):
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+        if data_type in ("code_docs_page",):
             self.is_code_docs_instance = True

     def add_local(self, data_type, content, metadata=None, config: AddConfig = None):

@@ -106,12 +104,8 @@ class EmbedChain:
         existing_ids = set(existing_docs["ids"])

         if len(existing_ids):
-            data_dict = {
-                id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
-            }
-            data_dict = {
-                id: value for id, value in data_dict.items() if id not in existing_ids
-            }
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}

             if not data_dict:
                 print(f"All data from {src} already exists in the database.")

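A runnable sketch of the de-duplication step above: pair every id with its document and metadata, then keep only ids not already present in the collection. The function name and sample values are made up for illustration.

def _drop_existing(ids, documents, metadatas, existing_ids):
    # Keep only (id, doc, meta) triples whose ids are not yet in the collection.
    data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
    data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
    if not data_dict:
        return [], [], []
    ids = list(data_dict.keys())
    documents, metadatas = zip(*data_dict.values())
    return ids, list(documents), list(metadatas)


print(_drop_existing(["a", "b"], ["doc a", "doc b"], [{}, {}], existing_ids={"a"}))
# (['b'], ['doc b'], [{}])
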
@@ -125,15 +119,8 @@ class EmbedChain:
         # Add metadata to each document
         metadatas_with_metadata = [meta or metadata for meta in metadatas]

-        self.collection.add(
-            documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
-        )
-        print(
-            (
-                f"Successfully saved {src}. New chunks count: "
-                f"{self.count() - chunks_before_addition}"
-            )
-        )
+        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+        print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))

     def _format_result(self, results):
         return [

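A sketch of the collection.add call above against an in-memory Chroma client; the collection name and documents are placeholders, and Chroma's default embedding function will embed the documents on insert.

import chromadb

client = chromadb.Client()
collection = client.create_collection("embedchain_demo")  # placeholder name

documents = ["first chunk", "second chunk"]
metadatas_with_metadata = [{"url": "local"}, {"url": "local"}]
ids = ["id-1", "id-2"]

chunks_before_addition = collection.count()
collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
print((f"Successfully saved local source. New chunks count: " f"{collection.count() - chunks_before_addition}"))
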
@@ -180,13 +167,9 @@ class EmbedChain:
         """
         context_string = (" | ").join(contexts)
         if not config.history:
-            prompt = config.template.substitute(
-                context=context_string, query=input_query
-            )
+            prompt = config.template.substitute(context=context_string, query=input_query)
         else:
-            prompt = config.template.substitute(
-                context=context_string, query=input_query, history=config.history
-            )
+            prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
         return prompt

     def get_answer_from_llm(self, prompt, config: ChatConfig):

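The prompt assembly above, sketched with string.Template from the standard library; the template text is illustrative rather than the DEFAULT_PROMPT defined in QueryConfig.

from string import Template

template = Template("Use the following context to answer the query.\n  $context\nQuery: $query\nAnswer:")
contexts = ["chunk one", "chunk two"]
input_query = "What does embedchain do?"

context_string = (" | ").join(contexts)
prompt = template.substitute(context=context_string, query=input_query)
print(prompt)
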
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
         :param config: InitConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
         """
-        print(
-            "Loading open source embedding model. This may take some time..."
-        )  # noqa:E501
+        print("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
             config = InitConfig()

         if not config.ef:
             config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name="all-MiniLM-L6-v2"
-                )
+                embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
             )

         if not config.db:

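A sketch of the open-source embedding function used above; it assumes chromadb and sentence-transformers are installed and downloads the model on first use.

from chromadb.utils import embedding_functions

ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
embeddings = ef(["hello world"])
print(len(embeddings[0]))  # all-MiniLM-L6-v2 produces 384-dimensional vectors
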
@@ -3,6 +3,7 @@ from bs4 import BeautifulSoup

 from embedchain.utils import clean_string

+
 class CodeDocsPageLoader:
     def load_data(self, url):
         """Load data from a web page."""

@@ -10,14 +11,14 @@ class CodeDocsPageLoader:
         data = response.content
         soup = BeautifulSoup(data, "html.parser")
         selectors = [
-            'article.bd-article',
+            "article.bd-article",
             'article[role="main"]',
-            'div.md-content',
+            "div.md-content",
             'div[role="main"]',
-            'div.container',
-            'div.section',
-            'article',
-            'main',
+            "div.container",
+            "div.section",
+            "article",
+            "main",
         ]
         content = None
         for selector in selectors:

@@ -43,11 +44,11 @@ class CodeDocsPageLoader:
             ]
         ):
             tag.string = " "
-        for div in soup.find_all("div", {'class': 'cell_output'}):
+        for div in soup.find_all("div", {"class": "cell_output"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output_wrapper'}):
+        for div in soup.find_all("div", {"class": "output_wrapper"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output'}):
+        for div in soup.find_all("div", {"class": "output"}):
             div.decompose()
         content = clean_string(soup.get_text())
         output = []

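A self-contained sketch of the scraping flow in the loader hunks above, run on an in-memory HTML string so it needs no network access. The selector handling approximates the loader's logic, and whitespace normalization stands in for clean_string.

from bs4 import BeautifulSoup

html = """
<html><body>
  <div role="main">
    <p>Useful documentation text.</p>
    <div class="cell_output">noisy notebook output</div>
  </div>
</body></html>
"""
soup = BeautifulSoup(html, "html.parser")
selectors = [
    "article.bd-article",
    'article[role="main"]',
    "div.md-content",
    'div[role="main"]',
    "div.container",
    "div.section",
    "article",
    "main",
]
content = None
for selector in selectors:
    element = soup.select_one(selector)
    if element:
        content = element
        break
for div in content.find_all("div", {"class": "cell_output"}):
    div.decompose()
print(" ".join(content.get_text().split()))  # -> Useful documentation text.
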
@@ -30,7 +30,7 @@ exclude = [
     "node_modules",
     "venv",
 ]
-line-length = 88
+line-length = 120
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
 target-version = "py38"

@@ -38,7 +38,7 @@ target-version = "py38"
 max-complexity = 10

 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ["py38", "py39", "py310", "py311"]
 include = '\.pyi?$'
 exclude = '''