[Feature] RSS Feed loader (#942)
This commit is contained in:
22
embedchain/chunkers/rss_feed.py
Normal file
22
embedchain/chunkers/rss_feed.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
class RSSFeedChunker(BaseChunker):
    """Chunker for RSS feed content.

    Wraps a ``RecursiveCharacterTextSplitter`` configured from a
    ``ChunkerConfig`` and hands it to the ``BaseChunker`` constructor.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """Build the text splitter and initialise the base chunker.

        Args:
            config: Chunking settings. When ``None``, a default config with
                ``chunk_size=2000``, ``chunk_overlap=0`` and ``len`` as the
                length function is used.
        """
        # Fall back to the RSS-specific defaults only when the caller gave nothing.
        effective_config = (
            config
            if config is not None
            else ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
        )
        splitter_options = {
            "chunk_size": effective_config.chunk_size,
            "chunk_overlap": effective_config.chunk_overlap,
            "length_function": effective_config.length_function,
        }
        super().__init__(RecursiveCharacterTextSplitter(**splitter_options))
|
||||
@@ -72,6 +72,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
|
||||
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
|
||||
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
|
||||
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
|
||||
}
|
||||
|
||||
@@ -113,6 +114,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
|
||||
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
|
||||
}
|
||||
|
||||
|
||||
52
embedchain/loaders/rss_feed.py
Normal file
52
embedchain/loaders/rss_feed.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
class RSSFeedLoader(BaseLoader):
    """Loader for RSS feeds.

    Fetches the entries of a feed through langchain's RSS loader and returns
    them together with a document id derived from the content and the URL.
    """

    def load_data(self, url):
        """Load the entries of the RSS feed at ``url``.

        Returns:
            A dict with ``"doc_id"`` (SHA-256 hex digest of the stringified
            entries concatenated with ``url``) and ``"data"`` (the list
            produced by ``get_rss_content``).
        """
        entries = self.get_rss_content(url)
        hasher = hashlib.sha256()
        hasher.update((str(entries) + url).encode())
        return {"doc_id": hasher.hexdigest(), "data": entries}

    @staticmethod
    def serialize_metadata(metadata):
        """Stringify every value that is not a ``str``, ``int``, ``float`` or ``bool``.

        The mapping is modified in place and the same object is returned.
        """
        primitive_types = (str, int, float, bool)
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for name in metadata:
            current = metadata[name]
            if isinstance(current, primitive_types):
                continue
            metadata[name] = str(current)
        return metadata

    @staticmethod
    def get_rss_content(url: str):
        """Return the feed entries as a list of ``{"content", "meta_data"}`` dicts.

        Each entry's metadata is passed through ``serialize_metadata`` and then
        gets a ``"url"`` key set to the feed URL.

        Raises:
            ImportError: If langchain's ``RSSFeedLoader`` cannot be imported.
        """
        # Imported lazily so the extra dependency is only needed when RSS
        # feeds are actually loaded.
        try:
            from langchain.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
        except ImportError:
            # `from None` hides the original traceback; the message carries the fix.
            raise ImportError(
                "RSSFeedLoader file requires extra dependencies.\n"
                'Install with `pip install --upgrade "embedchain[rss_feed]"`'
            ) from None

        documents = LangchainRSSFeedLoader(urls=[url]).load()

        results = []
        for document in documents:
            entry_meta = RSSFeedLoader.serialize_metadata(document.metadata)
            entry_meta["url"] = url
            results.append({"content": document.page_content, "meta_data": entry_meta})
        return results
|
||||
@@ -33,6 +33,7 @@ class IndirectDataType(Enum):
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
DISCORD = "discord"
|
||||
CUSTOM = "custom"
|
||||
RSSFEED = "rss_feed"
|
||||
BEEHIIV = "beehiiv"
|
||||
|
||||
|
||||
@@ -66,4 +67,5 @@ class DataType(Enum):
|
||||
YOUTUBE_CHANNEL = IndirectDataType.YOUTUBE_CHANNEL.value
|
||||
DISCORD = IndirectDataType.DISCORD.value
|
||||
CUSTOM = IndirectDataType.CUSTOM.value
|
||||
RSSFEED = IndirectDataType.RSSFEED.value
|
||||
BEEHIIV = IndirectDataType.BEEHIIV.value
|
||||
|
||||
@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
|
||||
formatted_source = format_source(str(source), 30)
|
||||
|
||||
if url:
|
||||
from langchain.document_loaders.youtube import \
|
||||
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||
|
||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||
|
||||
Reference in New Issue
Block a user