chore: linting (#597)
@@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
-        dry_run = False
+        dry_run=False,
     ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
         """The loader to use to load the data.

@@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable):
         return list(documents), metadatas, ids, count_new_chunks

     def load_and_embed_v2(
-        self,
-        loader: BaseLoader,
-        chunker: BaseChunker,
-        src: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
-        dry_run = False
-    ):
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False,
+    ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.

@@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable):
             # this means that doc content has changed.
             if existing_doc_id and existing_doc_id != new_doc_id:
                 print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
-                self.db.delete({
-                    "doc_id": existing_doc_id
-                })
+                self.db.delete({"doc_id": existing_doc_id})

         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
@@ -46,7 +46,4 @@ class CsvLoader(BaseLoader):
                 lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
             doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
-            return {
-                "doc_id": doc_id,
-                "data": result
-            }
+            return {"doc_id": doc_id, "data": result}
@@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }
@@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }
@@ -39,9 +39,9 @@ class NotionLoader(BaseLoader):
         return {
             "doc_id": doc_id,
             "data": [
-                {
-                    "content": text,
-                    "meta_data": {"url": f"notion-{formatted_id}"},
-                }
-            ],
+                {
+                    "content": text,
+                    "meta_data": {"url": f"notion-{formatted_id}"},
+                }
+            ],
         }
@@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return {
-            "doc_id": doc_id,
-            "data": [data[0] for data in output]
-        }
+        return {"doc_id": doc_id, "data": [data[0] for data in output]}
@@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader):
         }
         content = content
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
-        return {
+        return {
             "doc_id": doc_id,
             "data": [
                 {
@@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError

     def set_collection_name(self, name: str):
-        raise NotImplementedError
+        raise NotImplementedError
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional

 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB):
             args["where"] = where
         if limit:
             args["limit"] = limit
-        return self.collection.get(
-            **args
-        )
+        return self.collection.get(**args)

     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
@@ -76,5 +76,5 @@ class MockLoader:
                     "content": src,
                     "meta_data": {"url": "none"},
                 }
-            ]
+            ],
         }
@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import MagicMock, patch

 from embedchain import App
-from embedchain.config import AppConfig, AddConfig, ChunkerConfig
+from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType

