Add functionality to fetch data from reference links in a web page (#1806)
--- a/embedchain/chunkers/base_chunker.py
+++ b/embedchain/chunkers/base_chunker.py
@@ -1,6 +1,6 @@
 import hashlib
 import logging
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
+    def create_chunks(
+        self,
+        loader,
+        src,
+        app_id=None,
+        config: Optional[ChunkerConfig] = None,
+        **kwargs: Optional[dict[str, Any]],
+    ):
         """
         Loads data and chunks it.
 
@@ -30,7 +37,7 @@ class BaseChunker(JSONSerializable):
         id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
         logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
-        data_result = loader.load_data(src)
+        data_result = loader.load_data(src, **kwargs)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         # Prefix app_id in the document id if app_id is not None to
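Any extra keyword arguments handed to create_chunks now ride along to the loader untouched. A minimal sketch of the new call path; the LangChain splitter, the chunk sizes, and the URL are illustrative assumptions, not part of this commit:

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.loaders.web_page import WebPageLoader
from embedchain.models.data_type import DataType

chunker = BaseChunker(RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0))
chunker.set_data_type(DataType.WEB_PAGE)  # chunk metadata records the data type
# all_references is not inspected here; create_chunks forwards it to
# WebPageLoader.load_data(src, **kwargs), which consumes it below.
embeddings_data = chunker.create_chunks(WebPageLoader(), "https://example.com", all_references=True)
print(len(embeddings_data["documents"]))

As an aside, an annotation on **kwargs describes each keyword value, so Optional[dict[str, Any]] reads as "every extra value is a dict or None"; plain Any is the conventional spelling for a pass-through like this.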
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -1,5 +1,5 @@
 from importlib import import_module
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ class DataFormatter(JSONSerializable):
         module = import_module(module_path)
         return getattr(module, class_name)
 
-    def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
+    def _get_loader(
+        self,
+        data_type: DataType,
+        config: LoaderConfig,
+        loader: Optional[BaseLoader],
+        **kwargs: Optional[dict[str, Any]],
+    ) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
 
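The context lines show the lazy-import resolution that _get_loader relies on: a dotted path is split, the module is imported at call time, and the class is pulled off it with getattr. The same pattern in isolation, with a hypothetical helper name:

from importlib import import_module

def load_class(dotted_path: str):
    # Resolve "package.module.ClassName" at call time so heavy or optional
    # dependencies are only imported when that data type is actually used.
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = import_module(module_path)
    return getattr(module, class_name)

WebPageLoader = load_class("embedchain.loaders.web_page.WebPageLoader")

Keeping loader imports lazy means pulling in, say, BeautifulSoup only when a web page is actually added.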
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -329,7 +329,7 @@ class EmbedChain(JSONSerializable):
         app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
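The add pipeline inside EmbedChain now relays whatever keyword arguments it received down to the chunker. Assuming App.add forwards its own extra kwargs into this call, which is what gives the feature its user-facing switch, opting in looks like this; the URL is a placeholder:

from embedchain import App

app = App()
# all_references=True asks the web page loader to also fetch every absolute
# link found on the page and append that content before chunking.
app.add("https://example.com", data_type="web_page", all_references=True)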
--- a/embedchain/loaders/base_loader.py
+++ b/embedchain/loaders/base_loader.py
@@ -1,3 +1,5 @@
+from typing import Any, Optional
+
 from embedchain.helpers.json_serializable import JSONSerializable
 
 
@@ -5,7 +7,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass
 
-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """
         Implemented by child classes
         """
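Every concrete loader inherits this widened contract and may pick out the kwargs it understands while ignoring the rest. A minimal made-up loader against the new signature; the doc_id/data result shape mirrors what create_chunks reads in the first file, and the meta_data key name follows the existing loaders:

import hashlib

from embedchain.loaders.base_loader import BaseLoader

class StaticTextLoader(BaseLoader):  # hypothetical example class
    def load_data(self, url, **kwargs):
        # This loader understands only a "content" kwarg.
        content = kwargs.get("content", "placeholder text")
        return {
            "doc_id": hashlib.sha256((content + url).encode()).hexdigest(),
            "data": [{"content": content, "meta_data": {"url": url}}],
        }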
--- a/embedchain/loaders/web_page.py
+++ b/embedchain/loaders/web_page.py
@@ -1,5 +1,6 @@
 import hashlib
 import logging
+from typing import Any, Optional
 
 import requests
 
@@ -22,14 +23,29 @@ class WebPageLoader(BaseLoader):
     # Shared session for all instances
     _session = requests.Session()
 
-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """Load data from a web page using a shared requests' session."""
+        all_references = False
+        for key, value in kwargs.items():
+            if key == "all_references":
+                all_references = kwargs["all_references"]
         headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
         }
         response = self._session.get(url, headers=headers, timeout=30)
         response.raise_for_status()
         data = response.content
+        reference_links = self.fetch_reference_links(response)
+        if all_references:
+            for i in reference_links:
+                try:
+                    response = self._session.get(i, headers=headers, timeout=30)
+                    response.raise_for_status()
+                    data += response.content
+                except Exception as e:
+                    logging.error(f"Failed to add URL {i}: {e}")
+                    continue
+
         content = self._get_clean_content(data, url)
 
         metadata = {"url": url}
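Two reading notes on the block above: the kwargs scan is a plain flag lookup, equivalent to the one-liner below; and when the flag is set, the raw HTML of every absolute link on the page is appended to the page's own bytes, so the whole bundle is cleaned and chunked as a single document.

# Behavior-preserving alternative to the for-loop flag extraction above:
all_references = kwargs.get("all_references", False)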
@@ -98,3 +114,13 @@ class WebPageLoader(BaseLoader):
     @classmethod
     def close_session(cls):
         cls._session.close()
+
+    def fetch_reference_links(self, response):
+        if response.status_code == 200:
+            soup = BeautifulSoup(response.content, "html.parser")
+            a_tags = soup.find_all("a", href=True)
+            reference_links = [a["href"] for a in a_tags if a["href"].startswith("http")]
+            return reference_links
+        else:
+            print(f"Failed to retrieve the page. Status code: {response.status_code}")
+            return []
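Because of the startswith("http") filter, only absolute links are followed; relative hrefs such as /docs/intro are silently dropped. A variant that also resolves relative hrefs is sketched below; this is an illustrative extension, not part of the commit:

from urllib.parse import urljoin

from bs4 import BeautifulSoup

def fetch_all_links(response):
    # urljoin resolves relative hrefs against the final response URL;
    # mailto: and javascript: hrefs keep their scheme and are filtered out
    # below, while fragment-only hrefs resolve back to the page URL itself.
    soup = BeautifulSoup(response.content, "html.parser")
    links = (urljoin(response.url, a["href"]) for a in soup.find_all("a", href=True))
    return [link for link in links if link.startswith("http")]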
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -136,6 +136,7 @@ class ChromaDB(BaseVectorDB):
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
+        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
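ChromaDB.add accepts the same pass-through **kwargs so its signature stays compatible with the call chain above. Given a ChromaDB instance that an App has already initialized (direct construction normally happens inside embedchain), a minimal keyword-argument call might look like this; any parameters above the hunk's visible window are omitted:

# db is assumed to be an initialized ChromaDB instance, e.g. app.db.
db.add(
    documents=["chunk one", "chunk two"],
    metadatas=[{"url": "https://example.com"}, {"url": "https://example.com"}],
    ids=["id-1", "id-2"],
)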