Add support for fetching data from reference links on a web page (#1806)
@@ -15,6 +15,9 @@ title: '📊 add'
 <ParamField path="metadata" type="dict" optional>
 Any metadata that you want to store with the data source. Metadata is generally useful for metadata filtering on top of semantic search, yielding faster searches and better results.
 </ParamField>
+<ParamField path="all_references" type="bool" optional>
+This parameter instructs Embedchain to retrieve the context and information from the specified link, as well as from any reference links on the page.
+</ParamField>

 ## Usage

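An illustrative usage sketch for the new parameter (not part of the diff; assumes the standard `App` API and an illustrative URL):

```python
from embedchain import App

app = App()

# all_references=True is forwarded from add() through the chunker
# down to WebPageLoader.load_data(), which then also ingests every
# absolute link found on the page.
app.add(
    "https://example.com/article",
    data_type="web_page",
    all_references=True,
)
```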
@@ -1,6 +1,6 @@
 import hashlib
 import logging
-from typing import Optional
+from typing import Any, Optional

 from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None

-    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
+    def create_chunks(
+        self,
+        loader,
+        src,
+        app_id=None,
+        config: Optional[ChunkerConfig] = None,
+        **kwargs: Optional[dict[str, Any]],
+    ):
         """
         Loads data and chunks it.
@@ -30,7 +37,7 @@ class BaseChunker(JSONSerializable):
         id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
         logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
-        data_result = loader.load_data(src)
+        data_result = loader.load_data(src, **kwargs)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         # Prefix app_id in the document id if app_id is not None to
@@ -1,5 +1,5 @@
 from importlib import import_module
-from typing import Optional
+from typing import Any, Optional

 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ class DataFormatter(JSONSerializable):
         module = import_module(module_path)
         return getattr(module, class_name)

-    def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
+    def _get_loader(
+        self,
+        data_type: DataType,
+        config: LoaderConfig,
+        loader: Optional[BaseLoader],
+        **kwargs: Optional[dict[str, Any]],
+    ) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
@@ -329,7 +329,7 @@ class EmbedChain(JSONSerializable):
         app_id = self.config.id if self.config is not None else None

         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
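Taken together, the chunker, data formatter, and EmbedChain changes simply thread keyword arguments from the public entry point down to the loader. A minimal runnable sketch of that plumbing, using simplified stand-ins rather than the real embedchain classes:

```python
from typing import Any


class Loader:
    def load_data(self, src: str, **kwargs: Any) -> dict:
        # The loader is the only layer that actually inspects the option.
        all_references = kwargs.get("all_references", False)
        return {"doc_id": "doc-1", "data": [f"{src} (all_references={all_references})"]}


class Chunker:
    def create_chunks(self, loader: Loader, src: str, **kwargs: Any) -> dict:
        # Intermediate layers forward **kwargs untouched.
        return loader.load_data(src, **kwargs)


print(Chunker().create_chunks(Loader(), "https://example.com", all_references=True))
```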
@@ -1,3 +1,5 @@
+from typing import Any, Optional
+
 from embedchain.helpers.json_serializable import JSONSerializable
@@ -5,7 +7,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass

-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """
         Implemented by child classes
         """
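Because the base signature now accepts `**kwargs`, custom loaders can take extra options without breaking the interface. A hypothetical subclass sketch (`EchoLoader` and its `verbose` option are illustrative, not part of embedchain; the return shape follows the `doc_id`/`data` contract seen in the chunker hunk above):

```python
from typing import Any, Optional

from embedchain.loaders.base_loader import BaseLoader


class EchoLoader(BaseLoader):
    """Hypothetical loader demonstrating the new **kwargs signature."""

    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
        verbose = kwargs.get("verbose", False)  # illustrative option
        if verbose:
            print(f"loading {url}")
        return {
            "doc_id": url,
            "data": [{"content": url, "meta_data": {"url": url}}],
        }
```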
@@ -1,5 +1,6 @@
 import hashlib
 import logging
+from typing import Any, Optional

 import requests
@@ -22,14 +23,26 @@ class WebPageLoader(BaseLoader):
     # Shared session for all instances
     _session = requests.Session()

-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """Load data from a web page using a shared requests session."""
+        all_references = kwargs.get("all_references", False)
         headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
         }
         response = self._session.get(url, headers=headers, timeout=30)
         response.raise_for_status()
         data = response.content
+        reference_links = self.fetch_reference_links(response)
+        if all_references:
+            for link in reference_links:
+                try:
+                    response = self._session.get(link, headers=headers, timeout=30)
+                    response.raise_for_status()
+                    data += response.content
+                except Exception as e:
+                    logging.error(f"Failed to add URL {link}: {e}")
+                    continue

         content = self._get_clean_content(data, url)

         metadata = {"url": url}
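For quick manual testing, the loader can also be exercised directly; a small sketch (URL illustrative; the `doc_id`/`data` keys come from the loader contract used by the chunker above):

```python
from embedchain.loaders.web_page import WebPageLoader

loader = WebPageLoader()

# Fetch the page plus the content of every absolute link it references.
result = loader.load_data("https://example.com/article", all_references=True)
print(result["doc_id"], len(result["data"]))

WebPageLoader.close_session()  # close the shared session when done
```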
@@ -98,3 +114,13 @@ class WebPageLoader(BaseLoader):
     @classmethod
     def close_session(cls):
         cls._session.close()
+
+    def fetch_reference_links(self, response):
+        if response.status_code == 200:
+            soup = BeautifulSoup(response.content, "html.parser")
+            a_tags = soup.find_all("a", href=True)
+            reference_links = [a["href"] for a in a_tags if a["href"].startswith("http")]
+            return reference_links
+        else:
+            logging.error(f"Failed to retrieve the page. Status code: {response.status_code}")
+            return []
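Note that the `startswith("http")` filter keeps only absolute URLs, so relative links such as `/relative-link` are dropped (the tests below pin this behavior down). If relative links were ever wanted, they could be resolved against the page URL; a hypothetical variant, not part of this commit:

```python
from urllib.parse import urljoin

from bs4 import BeautifulSoup


def fetch_all_links(page_url: str, html: bytes) -> list[str]:
    """Resolve relative hrefs against the page URL instead of discarding them."""
    soup = BeautifulSoup(html, "html.parser")
    return [urljoin(page_url, a["href"]) for a in soup.find_all("a", href=True)]
```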
@@ -136,6 +136,7 @@ class ChromaDB(BaseVectorDB):
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
+        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
embedchain/poetry.lock (generated, 2 changes)
@@ -2644,6 +2644,7 @@ description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
+    {file = "langsmith-0.1.126-py3-none-any.whl", hash = "sha256:16c38ba5dae37a3cc715b6bc5d87d9579228433c2f34d6fa328345ee2b2bcc2a"},
     {file = "langsmith-0.1.126.tar.gz", hash = "sha256:40f72e2d1d975473dd69269996053122941c1252915bcea55787607e2a7f949a"},
 ]

@@ -6750,3 +6751,4 @@ weaviate = ["weaviate-client"]
 lock-version = "2.0"
 python-versions = ">=3.9,<=3.13"
 content-hash = "ec8a87e5281b7fa0c2c28f24c2562e823f0c546a24da2bb285b2f239b7b1758d"
@@ -2,6 +2,7 @@ import hashlib
 from unittest.mock import Mock, patch

 import pytest
+import requests

 from embedchain.loaders.web_page import WebPageLoader
@@ -115,3 +116,33 @@ def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
         assert class_name not in content

     assert len(content) > 0
+
+
+def test_fetch_reference_links_success(web_page_loader):
+    # Mock a successful response
+    response = Mock(spec=requests.Response)
+    response.status_code = 200
+    response.content = b"""
+    <html>
+        <body>
+            <a href="http://example.com">Example</a>
+            <a href="https://another-example.com">Another Example</a>
+            <a href="/relative-link">Relative Link</a>
+        </body>
+    </html>
+    """
+
+    expected_links = ["http://example.com", "https://another-example.com"]
+    result = web_page_loader.fetch_reference_links(response)
+    assert result == expected_links
+
+
+def test_fetch_reference_links_failure(web_page_loader):
+    # Mock a failed response
+    response = Mock(spec=requests.Response)
+    response.status_code = 404
+    response.content = b""
+
+    expected_links = []
+    result = web_page_loader.fetch_reference_links(response)
+    assert result == expected_links