diff --git a/docs/data-sources/discourse.mdx b/docs/data-sources/discourse.mdx
new file mode 100644
index 00000000..51daa040
--- /dev/null
+++ b/docs/data-sources/discourse.mdx
@@ -0,0 +1,44 @@
+---
+title: '🗨️ Discourse'
+---
+
+You can now easily load data from your community forum built with [Discourse](https://discourse.org/).
+
+## Example
+
+1. Set up the Discourse loader with your community URL.
+```Python
+from embedchain.loaders.discourse import DiscourseLoader
+
+discourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
+```
+
+2. Once you have set up the loader, you can create an app and load data using the Discourse loader:
+```Python
+import os
+from embedchain.pipeline import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+
+app.add("openai", data_type="discourse", loader=discourse_loader)
+
+question = "Where can I find the OpenAI API status page?"
+app.query(question)
+# Answer: You can find the OpenAI API status page at https://status.openai.com/.
+```
+
+NOTE: The `add` function treats its first argument as a Discourse search query and loads the posts that the search returns. Refer to the [Discourse API Docs](https://docs.discourse.org/#tag/Search) to learn more about search queries.
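+
+For example, you can use search operators to narrow what gets loaded. An illustrative query, assuming the standard Discourse `in:title` and `order:latest` search operators are available on your forum:
+```Python
+# Load only posts from topics whose title mentions "api status", newest first
+app.add("api status in:title order:latest", data_type="discourse", loader=discourse_loader)
+```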
+
+3. We automatically create a chunker for your Discourse data. However, if you wish to provide your own chunker class, here is how you can do that:
+```Python
+from embedchain.chunkers.discourse import DiscourseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+discourse_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+discourse_chunker = DiscourseChunker(config=discourse_chunker_config)
+
+app.add("openai", data_type='discourse', loader=discourse_loader, chunker=discourse_chunker)
+```
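+
+A larger `chunk_size` keeps more of each post together in a single chunk, while a non-zero `chunk_overlap` preserves context across chunk boundaries; tune both to your retrieval needs.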
\ No newline at end of file
diff --git a/docs/data-sources/overview.mdx b/docs/data-sources/overview.mdx
index 7ea2a96b..08b2a572 100644
--- a/docs/data-sources/overview.mdx
+++ b/docs/data-sources/overview.mdx
@@ -18,11 +18,12 @@ Embedchain comes with built-in support for various data sources. We handle the c
-
+
+  <Card title="🗨️ Discourse" href="/data-sources/discourse"></Card>
diff --git a/docs/data-sources/youtube-video.mdx b/docs/data-sources/youtube-video.mdx
index 5baf2f9a..aed31e63 100644
--- a/docs/data-sources/youtube-video.mdx
+++ b/docs/data-sources/youtube-video.mdx
@@ -1,5 +1,5 @@
---
-title: '🎥📺 Youtube video'
+title: '📺 Youtube video'
---
diff --git a/embedchain/chunkers/discourse.py b/embedchain/chunkers/discourse.py
new file mode 100644
index 00000000..f78c616e
--- /dev/null
+++ b/embedchain/chunkers/discourse.py
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class DiscourseChunker(BaseChunker):
+    """Chunker for Discourse."""
+
+ def __init__(self, config: Optional[ChunkerConfig] = None):
+ if config is None:
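+            # Defaults match embedchain's other chunkers: 1000-character chunks, no overlap.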
+ config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=config.chunk_size,
+ chunk_overlap=config.chunk_overlap,
+ length_function=config.length_function,
+ )
+ super().__init__(text_splitter)
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index e0b16b61..49bbd650 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
DataType.POSTGRES,
DataType.MYSQL,
DataType.SLACK,
+ DataType.DISCOURSE,
]
)
@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
+ DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
}
if data_type in chunker_classes:
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 67b4770c..88c098a6 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+ IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
diff --git a/embedchain/loaders/discourse.py b/embedchain/loaders/discourse.py
new file mode 100644
index 00000000..4111b847
--- /dev/null
+++ b/embedchain/loaders/discourse.py
@@ -0,0 +1,72 @@
+import concurrent.futures
+import hashlib
+import logging
+from typing import Any, Dict, Optional
+
+import requests
+
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+
+class DiscourseLoader(BaseLoader):
+ def __init__(self, config: Optional[Dict[str, Any]] = None):
+ super().__init__()
+ if not config:
+ raise ValueError(
+ "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
+ )
+
+ self.domain = config.get("domain")
+ if not self.domain:
+ raise ValueError(
+ "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
+ )
+
+ def _check_query(self, query):
+ if not query or not isinstance(query, str):
+ raise ValueError(
+ "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
+ )
+
+ def _load_post(self, post_id):
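+        # Fetch a single post from the Discourse JSON API; the "raw" field holds
+        # the post's original, unrendered Markdown.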
+ post_url = f"{self.domain}/posts/{post_id}.json"
+ response = requests.get(post_url)
+ response.raise_for_status()
+ response_data = response.json()
+ post_contents = clean_string(response_data.get("raw"))
+ meta_data = {
+ "url": post_url,
+ "created_at": response_data.get("created_at", ""),
+ "username": response_data.get("username", ""),
+ "topic_slug": response_data.get("topic_slug", ""),
+ "score": response_data.get("score", ""),
+ }
+ data = {
+ "content": post_contents,
+ "meta_data": meta_data,
+ }
+ return data
+
+ def load_data(self, query):
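+        """Run `query` against the forum's /search.json endpoint and return
+        {"doc_id": ..., "data": [...]}, one entry per post found."""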
+ self._check_query(query)
+ data = []
+ data_contents = []
+ logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
+ search_url = f"{self.domain}/search.json?q={query}"
+ response = requests.get(search_url)
+ response.raise_for_status()
+ response_data = response.json()
+ post_ids = response_data.get("grouped_search_result").get("post_ids")
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
+ for future in concurrent.futures.as_completed(future_to_post_id):
+ post_id = future_to_post_id[future]
+                try:
+                    post_data = future.result()
+                    data.append(post_data)
+                    # Record the content so doc_id reflects the loaded posts, not just the query.
+                    data_contents.append(post_data["content"])
+                except Exception as e:
+                    logging.error(f"Failed to load post {post_id}: {e}")
+ doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
+ response_data = {"doc_id": doc_id, "data": data}
+ return response_data
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index eb5fa8d2..d973d766 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
SLACK = "slack"
+ DISCOURSE = "discourse"
class SpecialDataType(Enum):
@@ -63,3 +64,4 @@ class DataType(Enum):
POSTGRES = IndirectDataType.POSTGRES.value
MYSQL = IndirectDataType.MYSQL.value
SLACK = IndirectDataType.SLACK.value
+ DISCOURSE = IndirectDataType.DISCOURSE.value
diff --git a/embedchain/utils.py b/embedchain/utils.py
index 811375f7..2abd2ad7 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -5,11 +5,66 @@ import re
import string
from typing import Any
+from bs4 import BeautifulSoup
from schema import Optional, Or, Schema
from embedchain.models.data_type import DataType
+def parse_content(content, type):
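+    """Strip boilerplate tags (nav, header, scripts, styles, sidebars, ...) from
+    HTML/XML content with BeautifulSoup and return the cleaned text.
+
+    Example: parse_content("<html><nav>menu</nav><p>Hello</p></html>", "html.parser") -> "Hello"
+    """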
+ implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
+ if type not in implemented:
+ raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")
+
+ soup = BeautifulSoup(content, type)
+ original_size = len(str(soup.get_text()))
+
+ tags_to_exclude = [
+ "nav",
+ "aside",
+ "form",
+ "header",
+ "noscript",
+ "svg",
+ "canvas",
+ "footer",
+ "script",
+ "style",
+ ]
+ for tag in soup(tags_to_exclude):
+ tag.decompose()
+
+ ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
+ for id in ids_to_exclude:
+ tags = soup.find_all(id=id)
+ for tag in tags:
+ tag.decompose()
+
+ classes_to_exclude = [
+ "elementor-location-header",
+ "navbar-header",
+ "nav",
+ "header-sidebar-wrapper",
+ "blog-sidebar-wrapper",
+ "related-posts",
+ ]
+ for class_name in classes_to_exclude:
+ tags = soup.find_all(class_=class_name)
+ for tag in tags:
+ tag.decompose()
+
+ content = soup.get_text()
+ content = clean_string(content)
+
+ cleaned_size = len(content)
+ if original_size != 0:
+ logging.info(
+ f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
+ )
+
+ return content
+
+
def clean_string(text):
"""
This function takes in a string and performs a series of text cleaning operations.
diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py
index 28787716..b50c3013 100644
--- a/tests/chunkers/test_chunkers.py
+++ b/tests/chunkers/test_chunkers.py
@@ -1,3 +1,4 @@
+from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.gmail import GmailChunker
@@ -37,6 +38,7 @@ chunker_common_config = {
GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+ DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}
diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py
index 9d2768b0..ba06a005 100644
--- a/tests/helper_classes/test_json_serializable.py
+++ b/tests/helper_classes/test_json_serializable.py
@@ -4,7 +4,8 @@ from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+ register_deserializable)
class TestJsonSerializable(unittest.TestCase):
diff --git a/tests/loaders/test_discourse.py b/tests/loaders/test_discourse.py
new file mode 100644
index 00000000..2949f0c9
--- /dev/null
+++ b/tests/loaders/test_discourse.py
@@ -0,0 +1,118 @@
+import pytest
+import requests
+
+from embedchain.loaders.discourse import DiscourseLoader
+
+
+@pytest.fixture
+def discourse_loader_config():
+ return {
+ "domain": "https://example.com",
+ }
+
+
+@pytest.fixture
+def discourse_loader(discourse_loader_config):
+ return DiscourseLoader(config=discourse_loader_config)
+
+
+def test_discourse_loader_init_with_valid_config():
+ config = {"domain": "https://example.com"}
+ loader = DiscourseLoader(config=config)
+ assert loader.domain == "https://example.com"
+
+
+def test_discourse_loader_init_with_missing_config():
+ with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
+ DiscourseLoader()
+
+
+def test_discourse_loader_init_with_missing_domain():
+ config = {"another_key": "value"}
+ with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
+ DiscourseLoader(config=config)
+
+
+def test_discourse_loader_check_query_with_valid_query(discourse_loader):
+ discourse_loader._check_query("sample query")
+
+
+def test_discourse_loader_check_query_with_empty_query(discourse_loader):
+ with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+ discourse_loader._check_query("")
+
+
+def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
+ with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+ discourse_loader._check_query(123)
+
+
+def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
+ def mock_get(*args, **kwargs):
+ class MockResponse:
+ def json(self):
+ return {"raw": "Sample post content"}
+
+ def raise_for_status(self):
+ pass
+
+ return MockResponse()
+
+ monkeypatch.setattr(requests, "get", mock_get)
+
+ post_data = discourse_loader._load_post(123)
+
+ assert post_data["content"] == "Sample post content"
+ assert "meta_data" in post_data
+
+
+def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
+ def mock_get(*args, **kwargs):
+ class MockResponse:
+ def raise_for_status(self):
+ raise requests.exceptions.RequestException("Test error")
+
+ return MockResponse()
+
+ monkeypatch.setattr(requests, "get", mock_get)
+
+ with pytest.raises(Exception, match="Test error"):
+ discourse_loader._load_post(123)
+
+
+def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
+ def mock_get(*args, **kwargs):
+ class MockResponse:
+ def json(self):
+ return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
+
+ def raise_for_status(self):
+ pass
+
+ return MockResponse()
+
+ monkeypatch.setattr(requests, "get", mock_get)
+
+ def mock_load_post(*args, **kwargs):
+ return {
+ "content": "Sample post content",
+ "meta_data": {
+ "url": "https://example.com/posts/123.json",
+ "created_at": "2021-01-01",
+ "username": "test_user",
+ "topic_slug": "test_topic",
+ "score": 10,
+ },
+ }
+
+ monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
+
+ data = discourse_loader.load_data("sample query")
+
+ assert len(data["data"]) == 3
+ assert data["data"][0]["content"] == "Sample post content"
+ assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
+ assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
+ assert data["data"][0]["meta_data"]["username"] == "test_user"
+ assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
+ assert data["data"][0]["meta_data"]["score"] == 10