chore: linting (#597)
@@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
-        dry_run = False
+        dry_run=False,
     ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
         """The loader to use to load the data.

@@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable):
         return list(documents), metadatas, ids, count_new_chunks

     def load_and_embed_v2(
         self,
         loader: BaseLoader,
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
-        dry_run = False
+        dry_run=False,
     ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.

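Both hunks above make the same two cosmetic fixes: the formatter (presumably black) drops the spaces around "=" for a bare keyword default and adds a trailing comma after the last parameter of a signature it keeps split across lines. A minimal, hypothetical illustration of the rule; add_sketch and its parameters are made up, not code from this repository:

# Before the lint pass: spaces around "=" in a keyword default, no trailing comma.
def add_sketch(src, dry_run = False):
    return src, dry_run

# After black: no spaces around "=" for a bare default, and a trailing comma on
# the last parameter whenever the signature stays split across multiple lines.
def add_sketch(
    src,
    dry_run=False,
):
    return src, dry_run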
@@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable):
             # this means that doc content has changed.
             if existing_doc_id and existing_doc_id != new_doc_id:
                 print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
-                self.db.delete({
-                    "doc_id": existing_doc_id
-                })
+                self.db.delete({"doc_id": existing_doc_id})

         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
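The only change here is the formatter collapsing a short call onto one line; the delete-on-content-change logic is untouched. For context, a hedged sketch of that flow: refresh_doc, its parameters, and the sha256 hashing are assumptions for illustration, while the id comparison and the delete call come from the diff itself.

import hashlib

def refresh_doc(db, content, existing_doc_id=None):
    # Hash the raw content to get a stable doc_id (assumption: the real code
    # derives doc_id roughly the way the loaders later in this diff do).
    new_doc_id = hashlib.sha256(content.encode()).hexdigest()
    if existing_doc_id and existing_doc_id != new_doc_id:
        # Content changed: drop the chunks keyed by the stale doc_id so the
        # caller can re-chunk and re-embed the new content.
        db.delete({"doc_id": existing_doc_id})
    return new_doc_id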
@@ -46,7 +46,4 @@ class CsvLoader(BaseLoader):
                 lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
         doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
-        return {
-            "doc_id": doc_id,
-            "data": result
-        }
+        return {"doc_id": doc_id, "data": result}
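Here the formatter folds the return dict onto one line (the SitemapLoader hunk further down gets the identical treatment). The shape being returned is the loader contract visible throughout this diff: a content-hash doc_id plus a list of content/meta_data rows. A hedged skeleton of that contract; the class and the load_data method name are illustrative stand-ins, not embedchain's actual CsvLoader:

import hashlib

class MinimalLoader:
    # Only the return shape ({"doc_id": ..., "data": [{"content", "meta_data"}, ...]})
    # is taken from the diff; everything else is assumed for illustration.
    def load_data(self, content):
        lines = content.splitlines()
        result = [
            {"content": line, "meta_data": {"url": content, "row": i + 1}}
            for i, line in enumerate(lines)
        ]
        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": result}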
@@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }
@@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }
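These two loader hunks, and the MockLoader hunk near the end, are the same one-character fix: black-style formatting puts a trailing comma after the last element of every multi-line literal. A made-up example of the resulting shape; the values are not from the repository:

# The "magic trailing comma": every element of a multi-line literal ends in a
# comma, including the "]" that closes the inner list, so adding another key
# later touches only the newly added line.
payload = {
    "doc_id": "abc123",
    "data": [
        {"content": "hello world", "meta_data": {"url": "local"}},
    ],
}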
@@ -39,9 +39,9 @@ class NotionLoader(BaseLoader):
         return {
             "doc_id": doc_id,
             "data": [
                 {
                     "content": text,
                     "meta_data": {"url": f"notion-{formatted_id}"},
                 }
             ],
         }
@@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return {
-            "doc_id": doc_id,
-            "data": [data[0] for data in output]
-        }
+        return {"doc_id": doc_id, "data": [data[0] for data in output]}
@@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader):
         }
         content = content
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
         return {
             "doc_id": doc_id,
             "data": [
                 {
@@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError

     def set_collection_name(self, name: str):
         raise NotImplementedError
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional

 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
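This hunk is import sorting (isort-style): the members of the typing import are alphabetised, and the test-import hunk at the end of the diff gets the same fix. A stdlib-only illustration; the invocation mentioned in the comment is an assumption, since the diff does not show the lint configuration:

# Old ordering, rejected by the sorter:
#     from typing import Dict, List, Optional, Any
# isort alphabetises the names within a statement and keeps stdlib imports
# grouped ahead of third-party ones. A pass like this is typically reproduced
# with something along the lines of `isort . && black .`; exact flags unknown.
import logging
from typing import Any, Dict, List, Optional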
@@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB):
             args["where"] = where
         if limit:
             args["limit"] = limit
-        return self.collection.get(
-            **args
-        )
+        return self.collection.get(**args)

     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)
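Again a pure formatting collapse: the **args unpacking fits on one line. The underlying pattern, building a kwargs dict only from the filters that were actually supplied and splatting it into the chromadb call, looks roughly like the sketch below. This is a hedged reconstruction around the lines shown, not the file's full method; the wrapper name and the ids handling are assumed.

def get_matching(collection, ids=None, where=None, limit=None):
    # Only include the optional query arguments that were provided, then pass
    # them through to chromadb's Collection.get via ** unpacking.
    args = {}
    if ids:
        args["ids"] = ids
    if where:
        args["where"] = where
    if limit:
        args["limit"] = limit
    return collection.get(**args)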
@@ -76,5 +76,5 @@ class MockLoader:
                     "content": src,
                     "meta_data": {"url": "none"},
                 }
-            ]
+            ],
         }
@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import MagicMock, patch

 from embedchain import App
-from embedchain.config import AppConfig, AddConfig, ChunkerConfig
+from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType
