diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index a2019b75..d6364b5b 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable): src: Any, metadata: Optional[Dict[str, Any]] = None, source_id: Optional[str] = None, - dry_run = False + dry_run=False, ) -> Tuple[List[str], Dict[str, Any], List[str], int]: """The loader to use to load the data. @@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable): return list(documents), metadatas, ids, count_new_chunks def load_and_embed_v2( - self, - loader: BaseLoader, - chunker: BaseChunker, - src: Any, - metadata: Optional[Dict[str, Any]] = None, - source_id: Optional[str] = None, - dry_run = False - ): + self, + loader: BaseLoader, + chunker: BaseChunker, + src: Any, + metadata: Optional[Dict[str, Any]] = None, + source_id: Optional[str] = None, + dry_run=False, + ): """ Loads the data from the given URL, chunks it, and adds it to database. @@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable): # this means that doc content has changed. if existing_doc_id and existing_doc_id != new_doc_id: print("Doc content has changed. Recomputing chunks and embeddings intelligently.") - self.db.delete({ - "doc_id": existing_doc_id - }) + self.db.delete({"doc_id": existing_doc_id}) # get existing ids, and discard doc if any common id exist. where = {"app_id": self.config.id} if self.config.id is not None else {} diff --git a/embedchain/loaders/csv.py b/embedchain/loaders/csv.py index 9de84de0..1730ae9c 100644 --- a/embedchain/loaders/csv.py +++ b/embedchain/loaders/csv.py @@ -46,7 +46,4 @@ class CsvLoader(BaseLoader): lines.append(line) result.append({"content": line, "meta_data": {"url": content, "row": i + 1}}) doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest() - return { - "doc_id": doc_id, - "data": result - } + return {"doc_id": doc_id, "data": result} diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py index 36da278e..61d9a576 100644 --- a/embedchain/loaders/local_qna_pair.py +++ b/embedchain/loaders/local_qna_pair.py @@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader): "content": content, "meta_data": meta_data, } - ] + ], } diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py index 7a519578..118cbd3a 100644 --- a/embedchain/loaders/local_text.py +++ b/embedchain/loaders/local_text.py @@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader): "content": content, "meta_data": meta_data, } - ] + ], } diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py index 065673e5..7ff84ed5 100644 --- a/embedchain/loaders/notion.py +++ b/embedchain/loaders/notion.py @@ -39,9 +39,9 @@ class NotionLoader(BaseLoader): return { "doc_id": doc_id, "data": [ - { - "content": text, - "meta_data": {"url": f"notion-{formatted_id}"}, - } - ], + { + "content": text, + "meta_data": {"url": f"notion-{formatted_id}"}, + } + ], } diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index c78542c4..d85e8829 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader): logging.warning(f"Page is not readable (too many invalid characters): {link}") except ParserRejectedMarkup as e: logging.error(f"Failed to parse {link}: {e}") - return { - "doc_id": doc_id, - "data": [data[0] for data in output] - } + return {"doc_id": doc_id, "data": [data[0] for data in output]} diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 9e62df38..53d41df0 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader): } content = content doc_id = hashlib.sha256((content + url).encode()).hexdigest() - return { + return { "doc_id": doc_id, "data": [ { diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index e7ef5c8e..aee18f1c 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable): raise NotImplementedError def set_collection_name(self, name: str): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 2717b378..3086c23f 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional from chromadb import Collection, QueryResult from langchain.docstore.document import Document @@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB): args["where"] = where if limit: args["limit"] = limit - return self.collection.get( - **args - ) + return self.collection.get(**args) def get_advanced(self, where): return self.collection.get(where=where, limit=1) diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py index 6ea56620..e5bc32ab 100644 --- a/tests/chunkers/test_text.py +++ b/tests/chunkers/test_text.py @@ -76,5 +76,5 @@ class MockLoader: "content": src, "meta_data": {"url": "none"}, } - ] + ], } diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index 446ab496..f8c20564 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import MagicMock, patch from embedchain import App -from embedchain.config import AppConfig, AddConfig, ChunkerConfig +from embedchain.config import AddConfig, AppConfig, ChunkerConfig from embedchain.models.data_type import DataType