chore: linting (#597)

This commit is contained in:
cachho
2023-09-12 18:04:38 +02:00
committed by GitHub
parent 0f9a10c598
commit 03146946fa
11 changed files with 25 additions and 35 deletions

View File

@@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable):
src: Any, src: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None, source_id: Optional[str] = None,
dry_run = False dry_run=False,
) -> Tuple[List[str], Dict[str, Any], List[str], int]: ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
"""The loader to use to load the data. """The loader to use to load the data.
@@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable):
return list(documents), metadatas, ids, count_new_chunks return list(documents), metadatas, ids, count_new_chunks
def load_and_embed_v2( def load_and_embed_v2(
self, self,
loader: BaseLoader, loader: BaseLoader,
chunker: BaseChunker, chunker: BaseChunker,
src: Any, src: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None, source_id: Optional[str] = None,
dry_run = False dry_run=False,
): ):
""" """
Loads the data from the given URL, chunks it, and adds it to database. Loads the data from the given URL, chunks it, and adds it to database.
@@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable):
# this means that doc content has changed. # this means that doc content has changed.
if existing_doc_id and existing_doc_id != new_doc_id: if existing_doc_id and existing_doc_id != new_doc_id:
print("Doc content has changed. Recomputing chunks and embeddings intelligently.") print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
self.db.delete({ self.db.delete({"doc_id": existing_doc_id})
"doc_id": existing_doc_id
})
# get existing ids, and discard doc if any common id exist. # get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {} where = {"app_id": self.config.id} if self.config.id is not None else {}

View File

@@ -46,7 +46,4 @@ class CsvLoader(BaseLoader):
lines.append(line) lines.append(line)
result.append({"content": line, "meta_data": {"url": content, "row": i + 1}}) result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest() doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
return { return {"doc_id": doc_id, "data": result}
"doc_id": doc_id,
"data": result
}

View File

@@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader):
"content": content, "content": content,
"meta_data": meta_data, "meta_data": meta_data,
} }
] ],
} }

View File

@@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader):
"content": content, "content": content,
"meta_data": meta_data, "meta_data": meta_data,
} }
] ],
} }

View File

@@ -39,9 +39,9 @@ class NotionLoader(BaseLoader):
return { return {
"doc_id": doc_id, "doc_id": doc_id,
"data": [ "data": [
{ {
"content": text, "content": text,
"meta_data": {"url": f"notion-{formatted_id}"}, "meta_data": {"url": f"notion-{formatted_id}"},
} }
], ],
} }

View File

@@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader):
logging.warning(f"Page is not readable (too many invalid characters): {link}") logging.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e: except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}") logging.error(f"Failed to parse {link}: {e}")
return { return {"doc_id": doc_id, "data": [data[0] for data in output]}
"doc_id": doc_id,
"data": [data[0] for data in output]
}

View File

@@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader):
} }
content = content content = content
doc_id = hashlib.sha256((content + url).encode()).hexdigest() doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return { return {
"doc_id": doc_id, "doc_id": doc_id,
"data": [ "data": [
{ {

View File

@@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable):
raise NotImplementedError raise NotImplementedError
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Any from typing import Any, Dict, List, Optional
from chromadb import Collection, QueryResult from chromadb import Collection, QueryResult
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB):
args["where"] = where args["where"] = where
if limit: if limit:
args["limit"] = limit args["limit"] = limit
return self.collection.get( return self.collection.get(**args)
**args
)
def get_advanced(self, where): def get_advanced(self, where):
return self.collection.get(where=where, limit=1) return self.collection.get(where=where, limit=1)

View File

@@ -76,5 +76,5 @@ class MockLoader:
"content": src, "content": src,
"meta_data": {"url": "none"}, "meta_data": {"url": "none"},
} }
] ],
} }

View File

@@ -3,7 +3,7 @@ import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, AddConfig, ChunkerConfig from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType