Adding fetching data functionality for reference links in the web page (#1806)

This commit is contained in:
Vatsal Rathod
2024-10-15 07:26:35 -04:00
committed by GitHub
parent 721d765921
commit 20c3aee636
9 changed files with 86 additions and 8 deletions

View File

@@ -15,6 +15,9 @@ title: '📊 add'
<ParamField path="metadata" type="dict" optional>
Any metadata that you want to store with the data source. Metadata is generally really useful for doing metadata filtering on top of semantic search to yield faster search and better results.
</ParamField>
<ParamField path="all_references" type="bool" optional>
This parameter instructs Embedchain to retrieve all the context and information from the specified link, as well as from any reference links on the page.
</ParamField>
## Usage

View File

@@ -1,6 +1,6 @@
import hashlib
import logging
from typing import Optional
from typing import Any, Optional
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ class BaseChunker(JSONSerializable):
self.text_splitter = text_splitter
self.data_type = None
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
def create_chunks(
self,
loader,
src,
app_id=None,
config: Optional[ChunkerConfig] = None,
**kwargs: Optional[dict[str, Any]],
):
"""
Loads data and chunks it.
@@ -30,7 +37,7 @@ class BaseChunker(JSONSerializable):
id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_result = loader.load_data(src, **kwargs)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
# Prefix app_id in the document id if app_id is not None to

View File

@@ -1,5 +1,5 @@
from importlib import import_module
from typing import Optional
from typing import Any, Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ class DataFormatter(JSONSerializable):
module = import_module(module_path)
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
def _get_loader(
self,
data_type: DataType,
config: LoaderConfig,
loader: Optional[BaseLoader],
**kwargs: Optional[dict[str, Any]],
) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.

View File

@@ -329,7 +329,7 @@ class EmbedChain(JSONSerializable):
app_id = self.config.id if self.config is not None else None
# Create chunks
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]

View File

@@ -1,3 +1,5 @@
from typing import Any, Optional
from embedchain.helpers.json_serializable import JSONSerializable
@@ -5,7 +7,7 @@ class BaseLoader(JSONSerializable):
def __init__(self):
pass
def load_data(self, url):
def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
"""
Implemented by child classes
"""

View File

@@ -1,5 +1,6 @@
import hashlib
import logging
from typing import Any, Optional
import requests
@@ -22,14 +23,29 @@ class WebPageLoader(BaseLoader):
# Shared session for all instances
_session = requests.Session()
def load_data(self, url):
def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
"""Load data from a web page using a shared requests' session."""
all_references = False
for key, value in kwargs.items():
if key == "all_references":
all_references = kwargs["all_references"]
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36", # noqa:E501
}
response = self._session.get(url, headers=headers, timeout=30)
response.raise_for_status()
data = response.content
reference_links = self.fetch_reference_links(response)
if all_references:
for i in reference_links:
try:
response = self._session.get(i, headers=headers, timeout=30)
response.raise_for_status()
data += response.content
except Exception as e:
logging.error(f"Failed to add URL {url}: {e}")
continue
content = self._get_clean_content(data, url)
metadata = {"url": url}
@@ -98,3 +114,13 @@ class WebPageLoader(BaseLoader):
@classmethod
def close_session(cls):
cls._session.close()
def fetch_reference_links(self, response):
if response.status_code == 200:
soup = BeautifulSoup(response.content, "html.parser")
a_tags = soup.find_all("a", href=True)
reference_links = [a["href"] for a in a_tags if a["href"].startswith("http")]
return reference_links
else:
print(f"Failed to retrieve the page. Status code: {response.status_code}")
return []

View File

@@ -136,6 +136,7 @@ class ChromaDB(BaseVectorDB):
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, Any]],
) -> Any:
"""
Add vectors to chroma database

View File

@@ -2644,6 +2644,7 @@ description = "Client library to connect to the LangSmith LLM Tracing and Evalua
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.126-py3-none-any.whl", hash = "sha256:16c38ba5dae37a3cc715b6bc5d87d9579228433c2f34d6fa328345ee2b2bcc2a"},
{file = "langsmith-0.1.126.tar.gz", hash = "sha256:40f72e2d1d975473dd69269996053122941c1252915bcea55787607e2a7f949a"},
]
@@ -6750,3 +6751,4 @@ weaviate = ["weaviate-client"]
lock-version = "2.0"
python-versions = ">=3.9,<=3.13"
content-hash = "ec8a87e5281b7fa0c2c28f24c2562e823f0c546a24da2bb285b2f239b7b1758d"

View File

@@ -2,6 +2,7 @@ import hashlib
from unittest.mock import Mock, patch
import pytest
import requests
from embedchain.loaders.web_page import WebPageLoader
@@ -115,3 +116,33 @@ def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
assert class_name not in content
assert len(content) > 0
def test_fetch_reference_links_success(web_page_loader):
# Mock a successful response
response = Mock(spec=requests.Response)
response.status_code = 200
response.content = b"""
<html>
<body>
<a href="http://example.com">Example</a>
<a href="https://another-example.com">Another Example</a>
<a href="/relative-link">Relative Link</a>
</body>
</html>
"""
expected_links = ["http://example.com", "https://another-example.com"]
result = web_page_loader.fetch_reference_links(response)
assert result == expected_links
def test_fetch_reference_links_failure(web_page_loader):
# Mock a failed response
response = Mock(spec=requests.Response)
response.status_code = 404
response.content = b""
expected_links = []
result = web_page_loader.fetch_reference_links(response)
assert result == expected_links