chore: linting (#597)

Author: cachho
Date: 2023-09-12 18:04:38 +02:00
Committed by: GitHub
Parent: 0f9a10c598
Commit: 03146946fa
11 changed files with 25 additions and 35 deletions


@@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
-        dry_run = False
+        dry_run=False,
     ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
         """The loader to use to load the data.
@@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable):
         return list(documents), metadatas, ids, count_new_chunks

     def load_and_embed_v2(
-        self,
-        loader: BaseLoader,
-        chunker: BaseChunker,
-        src: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
-        dry_run = False
-    ):
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False,
+    ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
@@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable):
         # this means that doc content has changed.
         if existing_doc_id and existing_doc_id != new_doc_id:
             print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
-            self.db.delete({
-                "doc_id": existing_doc_id
-            })
+            self.db.delete({"doc_id": existing_doc_id})

         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
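The hunk above collapses the stale-chunk deletion in `load_and_embed_v2` onto one line. For context, a minimal standalone sketch of the pattern it touches: a content-hash doc_id whose mismatch with the stored id triggers a delete before re-embedding. The helper name and the generic `db` object are illustrative; only the `db.delete({"doc_id": ...})` call shape and the log message come from the diff.

```python
import hashlib
from typing import Optional


def refresh_doc_id(db, content: str, existing_doc_id: Optional[str]) -> str:
    """Drop stale chunks when the source content (and therefore its doc_id) has changed."""
    new_doc_id = hashlib.sha256(content.encode()).hexdigest()  # hash inputs are illustrative
    if existing_doc_id and existing_doc_id != new_doc_id:
        print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
        # Same single-dict call shape as the reformatted diff line.
        db.delete({"doc_id": existing_doc_id})
    return new_doc_id
```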


@@ -46,7 +46,4 @@ class CsvLoader(BaseLoader):
                 lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
         doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
-        return {
-            "doc_id": doc_id,
-            "data": result
-        }
+        return {"doc_id": doc_id, "data": result}
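The CsvLoader hunk collapses its return statement onto one line. A self-contained sketch of what that payload holds, assuming a simple in-memory CSV source; the meta_data keys and the doc_id expression mirror the diff, while the parsing details are illustrative rather than the library's exact implementation:

```python
import csv
import hashlib
import io


def load_csv_rows(content: str, raw_csv: str) -> dict:
    """Build the {"doc_id": ..., "data": [...]} payload returned by the loader."""
    lines, result = [], []
    for i, row in enumerate(csv.reader(io.StringIO(raw_csv))):
        line = ", ".join(row)
        lines.append(line)
        result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
    # The doc_id hashes the source identifier plus the joined rows, as in the diff above.
    doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
    return {"doc_id": doc_id, "data": result}
```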


@@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }


@@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }


@@ -39,9 +39,9 @@ class NotionLoader(BaseLoader):
         return {
             "doc_id": doc_id,
             "data": [
-                {
-                    "content": text,
-                    "meta_data": {"url": f"notion-{formatted_id}"},
-                }
-            ],
+                {
+                    "content": text,
+                    "meta_data": {"url": f"notion-{formatted_id}"},
+                }
+            ],
         }
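The LocalQnaPairLoader, LocalTextLoader, and NotionLoader hunks (and the MockLoader one further down) only adjust trailing commas and whitespace around the same return contract: a doc_id plus a list of content/meta_data records. A minimal sketch of that shape with a hypothetical helper; the hash inputs differ per loader, so the `text + url` expression here is an assumption.

```python
import hashlib
from typing import Any, Dict


def build_loader_payload(text: str, url: str) -> Dict[str, Any]:
    """Assemble the return shape shared by the loaders touched in this commit."""
    doc_id = hashlib.sha256((text + url).encode()).hexdigest()  # hash inputs are an assumption
    return {
        "doc_id": doc_id,
        "data": [
            {
                "content": text,
                "meta_data": {"url": url},
            }
        ],
    }
```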


@@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return {
-            "doc_id": doc_id,
-            "data": [data[0] for data in output]
-        }
+        return {"doc_id": doc_id, "data": [data[0] for data in output]}


@@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader):
         }
         content = content
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
-        return {
+        return {
             "doc_id": doc_id,
             "data": [
                 {


@@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError

     def set_collection_name(self, name: str):
-        raise NotImplementedError
+        raise NotImplementedError


@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional

 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB):
             args["where"] = where
         if limit:
             args["limit"] = limit
-        return self.collection.get(
-            **args
-        )
+        return self.collection.get(**args)

     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
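The reformatted `self.collection.get(**args)` call is the tail of a build-the-kwargs-conditionally pattern: only the filters that were actually supplied are forwarded to chromadb. A standalone sketch under that reading; the `ids` parameter and the helper name are illustrative, while `where` and `limit` appear in the hunk:

```python
from typing import Any, Dict, List, Optional


def get_existing(collection, ids: Optional[List[str]] = None,
                 where: Optional[Dict[str, Any]] = None,
                 limit: Optional[int] = None):
    """Pass only the filters that were supplied, then issue a single get() call."""
    args: Dict[str, Any] = {}
    if ids:
        args["ids"] = ids
    if where:
        args["where"] = where
    if limit:
        args["limit"] = limit
    return collection.get(**args)
```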


@@ -76,5 +76,5 @@ class MockLoader:
                     "content": src,
                     "meta_data": {"url": "none"},
                 }
-            ]
+            ],
         }


@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import MagicMock, patch

 from embedchain import App
-from embedchain.config import AppConfig, AddConfig, ChunkerConfig
+from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType