[Feature] Discourse Loader (#948)

Co-authored-by: Deven Patel <deven298@yahoo.com>

Author: Deven Patel
Committed: 2023-11-13 16:39:11 -08:00 (via GitHub)
Parent: 919cc74e94
Commit: 95c0d47236

12 changed files with 324 additions and 4 deletions

View File: embedchain/chunkers/discourse.py (new)

@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable


@register_deserializable
class DiscourseChunker(BaseChunker):
    """Chunker for discourse."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
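A minimal usage sketch of the new chunker (the values below are illustrative overrides, not the PR's defaults of chunk_size=1000, chunk_overlap=0):

from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.config.add_config import ChunkerConfig

# Hypothetical config values, shown only to demonstrate the override path.
chunker = DiscourseChunker(config=ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len))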

View File: embedchain/data_formatter/data_formatter.py

@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
                DataType.POSTGRES,
                DataType.MYSQL,
                DataType.SLACK,
                DataType.DISCOURSE,
            ]
        )
@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
            DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
            DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
            DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
        }
        if data_type in chunker_classes:
View File: embedchain/embedchain.py

@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB

View File: embedchain/loaders/discourse.py (new)

@@ -0,0 +1,72 @@
import concurrent.futures
import hashlib
import logging
from typing import Any, Dict, Optional

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string


class DiscourseLoader(BaseLoader):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

        self.domain = config.get("domain")
        if not self.domain:
            raise ValueError(
                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _check_query(self, query):
        if not query or not isinstance(query, str):
            raise ValueError(
                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _load_post(self, post_id):
        # Each post is fetched individually through the Discourse JSON API.
        post_url = f"{self.domain}/posts/{post_id}.json"
        response = requests.get(post_url)
        response.raise_for_status()
        response_data = response.json()
        post_contents = clean_string(response_data.get("raw"))
        meta_data = {
            "url": post_url,
            "created_at": response_data.get("created_at", ""),
            "username": response_data.get("username", ""),
            "topic_slug": response_data.get("topic_slug", ""),
            "score": response_data.get("score", ""),
        }
        data = {
            "content": post_contents,
            "meta_data": meta_data,
        }
        return data

    def load_data(self, query):
        self._check_query(query)
        data = []
        data_contents = []
        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
        search_url = f"{self.domain}/search.json?q={query}"
        response = requests.get(search_url)
        response.raise_for_status()
        response_data = response.json()
        post_ids = response_data.get("grouped_search_result").get("post_ids")
        # Fetch the matching posts concurrently; results arrive in completion order.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
            for future in concurrent.futures.as_completed(future_to_post_id):
                post_id = future_to_post_id[future]
                try:
                    post_data = future.result()
                    data.append(post_data)
                    # Collect the post body so it contributes to the doc_id hash below
                    # (absent in the original hunk, which left data_contents empty).
                    data_contents.append(post_data.get("content"))
                except Exception as e:
                    logging.error(f"Failed to load post {post_id}: {e}")
        # The doc_id fingerprints the query plus all loaded contents.
        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
        response_data = {"doc_id": doc_id, "data": data}
        return response_data
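A minimal standalone sketch of how the loader could be exercised (the domain and query below are placeholders, not values referenced in this PR):

from embedchain.loaders.discourse import DiscourseLoader

loader = DiscourseLoader(config={"domain": "https://dev.discourse.org"})
result = loader.load_data("python")  # returns {"doc_id": ..., "data": [...]}
for post in result["data"]:
    # Each entry carries the cleaned post body plus its source metadata.
    print(post["meta_data"]["url"], len(post["content"]))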

View File: embedchain/models/data_type.py

@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
    POSTGRES = "postgres"
    MYSQL = "mysql"
    SLACK = "slack"
    DISCOURSE = "discourse"


class SpecialDataType(Enum):
@@ -63,3 +64,4 @@ class DataType(Enum):
    POSTGRES = IndirectDataType.POSTGRES.value
    MYSQL = IndirectDataType.MYSQL.value
    SLACK = IndirectDataType.SLACK.value
    DISCOURSE = IndirectDataType.DISCOURSE.value
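With the enum registered, the loader can be wired through the public App API. A sketch under the assumption that App.add accepts data_type/loader keywords as it does for the other indirect sources (the query and domain are placeholders):

from embedchain import App
from embedchain.loaders.discourse import DiscourseLoader

app = App()
app.add(
    "openai",  # search query forwarded to DiscourseLoader.load_data
    data_type="discourse",
    loader=DiscourseLoader(config={"domain": "https://dev.discourse.org"}),
)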

View File: embedchain/utils.py

@@ -5,11 +5,66 @@ import re
import string
from typing import Any

from bs4 import BeautifulSoup
from schema import Optional, Or, Schema

from embedchain.models.data_type import DataType


def parse_content(content, type):
    implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
    if type not in implemented:
        raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")

    soup = BeautifulSoup(content, type)
    original_size = len(str(soup.get_text()))

    tags_to_exclude = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    for tag in soup(tags_to_exclude):
        tag.decompose()

    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
    for id in ids_to_exclude:
        tags = soup.find_all(id=id)
        for tag in tags:
            tag.decompose()

    classes_to_exclude = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]
    for class_name in classes_to_exclude:
        tags = soup.find_all(class_=class_name)
        for tag in tags:
            tag.decompose()

    content = soup.get_text()
    content = clean_string(content)

    cleaned_size = len(content)
    if original_size != 0:
        logging.info(
            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
        )

    return content

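# Usage sketch (illustrative, not part of this commit): parse_content strips
# navigation and boilerplate tags from raw HTML before chunking. The URL is
# a placeholder.
#
#   import requests
#   html = requests.get("https://example.com/article").text
#   text = parse_content(html, "html.parser")
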
def clean_string(text):
"""
This function takes in a string and performs a series of text cleaning operations.