feat: Update line length to 120 chars (#278)

This commit is contained in:
Deshraj Yadav
2023-07-15 07:11:55 -07:00
committed by GitHub
parent 4f722621fd
commit fd97fb268a
9 changed files with 31 additions and 63 deletions

View File

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
from embedchain.data_formatter import DataFormatter
gpt4all_model = None
@@ -54,10 +54,8 @@ class EmbedChain:
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, url, metadata])
-self.load_and_embed(
-data_formatter.loader, data_formatter.chunker, url, metadata
-)
-if data_type in ("code_docs_page", ):
+self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+if data_type in ("code_docs_page",):
self.is_code_docs_instance = True
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
@@ -106,12 +104,8 @@ class EmbedChain:
existing_ids = set(existing_docs["ids"])
if len(existing_ids):
-data_dict = {
-id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
-}
-data_dict = {
-id: value for id, value in data_dict.items() if id not in existing_ids
-}
+data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
if not data_dict:
print(f"All data from {src} already exists in the database.")
@@ -125,15 +119,8 @@ class EmbedChain:
# Add metadata to each document
metadatas_with_metadata = [meta or metadata for meta in metadatas]
-self.collection.add(
-documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
-)
-print(
-(
-f"Successfully saved {src}. New chunks count: "
-f"{self.count() - chunks_before_addition}"
-)
-)
+self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))
def _format_result(self, results):
return [
@@ -180,13 +167,9 @@ class EmbedChain:
"""
context_string = (" | ").join(contexts)
if not config.history:
-prompt = config.template.substitute(
-context=context_string, query=input_query
-)
+prompt = config.template.substitute(context=context_string, query=input_query)
else:
-prompt = config.template.substitute(
-context=context_string, query=input_query, history=config.history
-)
+prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
return prompt
def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
:param config: InitConfig instance to load as configuration. Optional.
`ef` defaults to open source.
"""
-print(
-"Loading open source embedding model. This may take some time..."
-) # noqa:E501
+print("Loading open source embedding model. This may take some time...") # noqa:E501
if not config:
config = InitConfig()
if not config.ef:
config._set_embedding_function(
-embedding_functions.SentenceTransformerEmbeddingFunction(
-model_name="all-MiniLM-L6-v2"
-)
+embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
)
if not config.db: