[Feature] Discourse Loader (#948)
Co-authored-by: Deven Patel <deven298@yahoo.com>
docs/data-sources/discourse.mdx (new file, 44 lines)
@@ -0,0 +1,44 @@
---
title: '🗨️ Discourse'
---

You can now easily load data from your community built with [Discourse](https://discourse.org/).

## Example

1. Set up the Discourse loader with your community URL.

```Python
from embedchain.loaders.discourse import DiscourseLoader

discourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
```

2. Once you have set up the loader, you can create an app and load data using the loader created above.

```Python
import os
from embedchain.pipeline import Pipeline as App

os.environ["OPENAI_API_KEY"] = "sk-xxx"

app = App()

app.add("openai", data_type="discourse", loader=discourse_loader)

question = "Where can I find the OpenAI API status page?"
app.query(question)
# Answer: You can find the OpenAI API status page at https://status.openai.com/.
```

NOTE: The `add` function of the app accepts any executable search query to load data. Refer to the [Discourse API docs](https://docs.discourse.org/#tag/Search) to learn more about search queries.
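
For instance, a query can carry Discourse search qualifiers. The qualifiers below (`after:`, `order:`) follow Discourse's search syntax; the exact filters available depend on your forum, so treat this as an illustrative sketch:

```Python
# Load only recent, latest-ordered posts matching the term "status".
# The qualifier syntax is Discourse's; adjust it to your community's setup.
app.add("status after:2023-01-01 order:latest", data_type="discourse", loader=discourse_loader)
```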

3. We automatically create a chunker for your Discourse data. However, if you wish to provide your own chunker class, here is how you can do that:

```Python
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.config.add_config import ChunkerConfig

discourse_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
discourse_chunker = DiscourseChunker(config=discourse_chunker_config)

app.add("openai", data_type='discourse', loader=discourse_loader, chunker=discourse_chunker)
```

@@ -18,11 +18,12 @@ Embedchain comes with built-in support for various data sources. We handle the c
 <Card title="🌐📄 web page" href="/data-sources/web-page"></Card>
 <Card title="🧾 xml" href="/data-sources/xml"></Card>
 <Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card>
-<Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
+<Card title="📺 youtube video" href="/data-sources/youtube-video"></Card>
 <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
 <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
 <Card title="🐬 MySQL" href="/data-sources/mysql"></Card>
 <Card title="🤖 Slack" href="/data-sources/slack"></Card>
+<Card title="🗨️ Discourse" href="/data-sources/discourse"></Card>
 </CardGroup>

 <br/ >

@@ -1,5 +1,5 @@
 ---
-title: '🎥📺 Youtube video'
+title: '📺 Youtube video'
 ---

embedchain/chunkers/discourse.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable


@register_deserializable
class DiscourseChunker(BaseChunker):
    """Chunker for discourse."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)

@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
                 DataType.POSTGRES,
                 DataType.MYSQL,
                 DataType.SLACK,
+                DataType.DISCOURSE,
             ]
         )

@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
             DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
             DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
             DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
+            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
         }

         if data_type in chunker_classes:

@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB

embedchain/loaders/discourse.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import concurrent.futures
import hashlib
import logging
from typing import Any, Dict, Optional

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string


class DiscourseLoader(BaseLoader):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

        self.domain = config.get("domain")
        if not self.domain:
            raise ValueError(
                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _check_query(self, query):
        if not query or not isinstance(query, str):
            raise ValueError(
                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _load_post(self, post_id):
        post_url = f"{self.domain}/posts/{post_id}.json"
        response = requests.get(post_url)
        response.raise_for_status()
        response_data = response.json()
        post_contents = clean_string(response_data.get("raw"))
        meta_data = {
            "url": post_url,
            "created_at": response_data.get("created_at", ""),
            "username": response_data.get("username", ""),
            "topic_slug": response_data.get("topic_slug", ""),
            "score": response_data.get("score", ""),
        }
        data = {
            "content": post_contents,
            "meta_data": meta_data,
        }
        return data

    def load_data(self, query):
        self._check_query(query)
        data = []
        data_contents = []
        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
        search_url = f"{self.domain}/search.json?q={query}"
        response = requests.get(search_url)
        response.raise_for_status()
        response_data = response.json()
        post_ids = response_data.get("grouped_search_result").get("post_ids")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
            for future in concurrent.futures.as_completed(future_to_post_id):
                post_id = future_to_post_id[future]
                try:
                    post_data = future.result()
                    data.append(post_data)
                    # Collect post contents so the doc_id hash reflects the loaded data,
                    # not just the query (data_contents was otherwise never populated).
                    data_contents.append(post_data["content"])
                except Exception as e:
                    logging.error(f"Failed to load post {post_id}: {e}")
        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
        response_data = {"doc_id": doc_id, "data": data}
        return response_data
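
For reference, a minimal standalone sketch of how this loader is called and what `load_data` returns; the domain and query values here are illustrative, not part of the commit:

```Python
from embedchain.loaders.discourse import DiscourseLoader

# Any public Discourse forum works; this domain is just an example.
loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
result = loader.load_data("api status")

# load_data returns {"doc_id": <sha256 hex digest>, "data": [{"content": ..., "meta_data": {...}}, ...]}
print(result["doc_id"])
for post in result["data"]:
    print(post["meta_data"]["url"], post["content"][:80])
```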

@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
     POSTGRES = "postgres"
     MYSQL = "mysql"
     SLACK = "slack"
+    DISCOURSE = "discourse"


 class SpecialDataType(Enum):

@@ -63,3 +64,4 @@ class DataType(Enum):
     POSTGRES = IndirectDataType.POSTGRES.value
     MYSQL = IndirectDataType.MYSQL.value
     SLACK = IndirectDataType.SLACK.value
+    DISCOURSE = IndirectDataType.DISCOURSE.value

@@ -5,11 +5,66 @@ import re
import string
from typing import Any

from bs4 import BeautifulSoup
from schema import Optional, Or, Schema

from embedchain.models.data_type import DataType


def parse_content(content, type):
    implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
    if type not in implemented:
        raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")

    soup = BeautifulSoup(content, type)
    original_size = len(str(soup.get_text()))

    tags_to_exclude = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    for tag in soup(tags_to_exclude):
        tag.decompose()

    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
    for id in ids_to_exclude:
        tags = soup.find_all(id=id)
        for tag in tags:
            tag.decompose()

    classes_to_exclude = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]
    for class_name in classes_to_exclude:
        tags = soup.find_all(class_=class_name)
        for tag in tags:
            tag.decompose()

    content = soup.get_text()
    content = clean_string(content)

    cleaned_size = len(content)
    if original_size != 0:
        logging.info(
            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
        )

    return content


def clean_string(text):
    """
    This function takes in a string and performs a series of text cleaning operations.
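
For context, a quick sketch of calling `parse_content` directly; the HTML snippet is illustrative, and the exact output depends on the normalization `clean_string` applies:

```Python
from embedchain.utils import parse_content

# Navigation and footer elements are decomposed; only body text should survive.
html = "<html><nav>menu</nav><p>Hello, Discourse!</p><footer>legal</footer></html>"
text = parse_content(html, "html.parser")
print(text)  # expected to print roughly: Hello, Discourse!
```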

@@ -1,3 +1,4 @@
+from embedchain.chunkers.discourse import DiscourseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.gmail import GmailChunker

@@ -37,6 +38,7 @@ chunker_common_config = {
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }

@@ -4,7 +4,8 @@ from string import Template

 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+                                                 register_deserializable)


 class TestJsonSerializable(unittest.TestCase):

tests/loaders/test_discourse.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import pytest
import requests

from embedchain.loaders.discourse import DiscourseLoader


@pytest.fixture
def discourse_loader_config():
    return {
        "domain": "https://example.com",
    }


@pytest.fixture
def discourse_loader(discourse_loader_config):
    return DiscourseLoader(config=discourse_loader_config)


def test_discourse_loader_init_with_valid_config():
    config = {"domain": "https://example.com"}
    loader = DiscourseLoader(config=config)
    assert loader.domain == "https://example.com"


def test_discourse_loader_init_with_missing_config():
    with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
        DiscourseLoader()


def test_discourse_loader_init_with_missing_domain():
    config = {"another_key": "value"}
    with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
        DiscourseLoader(config=config)


def test_discourse_loader_check_query_with_valid_query(discourse_loader):
    discourse_loader._check_query("sample query")


def test_discourse_loader_check_query_with_empty_query(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query("")


def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query(123)


def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"raw": "Sample post content"}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    post_data = discourse_loader._load_post(123)

    assert post_data["content"] == "Sample post content"
    assert "meta_data" in post_data


def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def raise_for_status(self):
                raise requests.exceptions.RequestException("Test error")

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    with pytest.raises(Exception, match="Test error"):
        discourse_loader._load_post(123)


def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"grouped_search_result": {"post_ids": [123, 456, 789]}}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    def mock_load_post(*args, **kwargs):
        return {
            "content": "Sample post content",
            "meta_data": {
                "url": "https://example.com/posts/123.json",
                "created_at": "2021-01-01",
                "username": "test_user",
                "topic_slug": "test_topic",
                "score": 10,
            },
        }

    monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)

    data = discourse_loader.load_data("sample query")

    assert len(data["data"]) == 3
    assert data["data"][0]["content"] == "Sample post content"
    assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
    assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
    assert data["data"][0]["meta_data"]["username"] == "test_user"
    assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
    assert data["data"][0]["meta_data"]["score"] == 10