diff --git a/docs/data-sources/discourse.mdx b/docs/data-sources/discourse.mdx
new file mode 100644
index 00000000..51daa040
--- /dev/null
+++ b/docs/data-sources/discourse.mdx
@@ -0,0 +1,44 @@
+---
+title: '🗨️ Discourse'
+---
+
+You can now easily load data from your community built with [Discourse](https://discourse.org/).
+
+## Example
+
+1. Set up the Discourse loader with your community URL.
+```Python
+from embedchain.loaders.discourse import DiscourseLoader
+
+discourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
+```
+
+2. Once the loader is set up, you can create an app and load data using it.
+```Python
+import os
+from embedchain.pipeline import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+
+app.add("openai", data_type="discourse", loader=discourse_loader)
+
+question = "Where can I find the OpenAI API status page?"
+app.query(question)
+# Answer: You can find the OpenAI API status page at https://status.openai.com/.
+```
+
+NOTE: The `add` function of the app accepts any valid Discourse search query as its source string (see the example sketch below). Refer to the [Discourse API docs](https://docs.discourse.org/#tag/Search) to learn more about search queries.
+
+3. We automatically create a chunker for your Discourse data. However, if you wish to provide your own chunker class, here is how you can do that:
+```Python
+from embedchain.chunkers.discourse import DiscourseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+discourse_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+discourse_chunker = DiscourseChunker(config=discourse_chunker_config)
+
+app.add("openai", data_type='discourse', loader=discourse_loader, chunker=discourse_chunker)
+```
\ No newline at end of file
diff --git a/docs/data-sources/overview.mdx b/docs/data-sources/overview.mdx
index 7ea2a96b..08b2a572 100644
--- a/docs/data-sources/overview.mdx
+++ b/docs/data-sources/overview.mdx
@@ -18,11 +18,12 @@ Embedchain comes with built-in support for various data sources. We handle the c
-
+
+
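Returning to the `add` call documented in `discourse.mdx` above: since the source string is passed through as a Discourse search query, you can use Discourse's search filter syntax to narrow what gets loaded. A hypothetical sketch (it reuses `app` and `discourse_loader` from the docs example; which filters are available depends on your forum):

```python
# Hypothetical query: restrict loading to the first post of each matching
# topic, newest first, using Discourse search filters in the query string.
app.add("api status in:first order:latest", data_type="discourse", loader=discourse_loader)
```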
diff --git a/docs/data-sources/youtube-video.mdx b/docs/data-sources/youtube-video.mdx
index 5baf2f9a..aed31e63 100644
--- a/docs/data-sources/youtube-video.mdx
+++ b/docs/data-sources/youtube-video.mdx
@@ -1,5 +1,5 @@
 ---
-title: '🎥📺 Youtube video'
+title: '📺 Youtube video'
 ---
 
diff --git a/embedchain/chunkers/discourse.py b/embedchain/chunkers/discourse.py
new file mode 100644
index 00000000..f78c616e
--- /dev/null
+++ b/embedchain/chunkers/discourse.py
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class DiscourseChunker(BaseChunker):
+    """Chunker for Discourse."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)
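For context on the defaults above: with `chunk_size=1000` and `chunk_overlap=0`, the underlying splitter produces non-overlapping chunks of at most 1000 characters. A standalone sketch using only LangChain's splitter (the sample text is made up):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same settings as DiscourseChunker's default ChunkerConfig.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0, length_function=len)

sample_post = "Some forum post text. " * 120  # ~2600 characters of fake content
chunks = splitter.split_text(sample_post)
print([len(chunk) for chunk in chunks])  # every chunk length is <= 1000
```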
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index e0b16b61..49bbd650 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
             DataType.POSTGRES,
             DataType.MYSQL,
             DataType.SLACK,
+            DataType.DISCOURSE,
         ]
     )
 
@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
             DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
             DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
             DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
+            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
         }
 
         if data_type in chunker_classes:
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 67b4770c..88c098a6 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
diff --git a/embedchain/loaders/discourse.py b/embedchain/loaders/discourse.py
new file mode 100644
index 00000000..4111b847
--- /dev/null
+++ b/embedchain/loaders/discourse.py
@@ -0,0 +1,72 @@
+import concurrent.futures
+import hashlib
+import logging
+from typing import Any, Dict, Optional
+
+import requests
+
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+
+class DiscourseLoader(BaseLoader):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        if not config:
+            raise ValueError(
+                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+        self.domain = config.get("domain")
+        if not self.domain:
+            raise ValueError(
+                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+    def _check_query(self, query):
+        if not query or not isinstance(query, str):
+            raise ValueError(
+                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+    def _load_post(self, post_id):
+        post_url = f"{self.domain}/posts/{post_id}.json"
+        response = requests.get(post_url)
+        response.raise_for_status()
+        response_data = response.json()
+        post_contents = clean_string(response_data.get("raw"))
+        meta_data = {
+            "url": post_url,
+            "created_at": response_data.get("created_at", ""),
+            "username": response_data.get("username", ""),
+            "topic_slug": response_data.get("topic_slug", ""),
+            "score": response_data.get("score", ""),
+        }
+        data = {
+            "content": post_contents,
+            "meta_data": meta_data,
+        }
+        return data
+
+    def load_data(self, query):
+        self._check_query(query)
+        data = []
+        data_contents = []
+        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
+        search_url = f"{self.domain}/search.json?q={query}"
+        response = requests.get(search_url)
+        response.raise_for_status()
+        response_data = response.json()
+        post_ids = response_data.get("grouped_search_result").get("post_ids")
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
+            for future in concurrent.futures.as_completed(future_to_post_id):
+                post_id = future_to_post_id[future]
+                try:
+                    post_data = future.result()
+                    data.append(post_data)
+                    # Collect post contents so doc_id reflects the loaded data, not just the query.
+                    data_contents.append(post_data.get("content"))
+                except Exception as e:
+                    logging.error(f"Failed to load post {post_id}: {e}")
+        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
+        response_data = {"doc_id": doc_id, "data": data}
+        return response_data
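The `doc_id` computed at the end of `load_data` is a SHA-256 digest of the query plus the `", "`-joined post contents, so re-running the same search over unchanged posts yields a stable id. A small sketch of that scheme (the inputs are made up):

```python
import hashlib

query = "openai"
post_contents = ["first post body", "second post body"]

# Mirrors DiscourseLoader.load_data: hash the query plus the joined contents.
doc_id = hashlib.sha256((query + ", ".join(post_contents)).encode()).hexdigest()
print(doc_id)  # stable for an identical query and identical post contents
```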
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index eb5fa8d2..d973d766 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
     POSTGRES = "postgres"
     MYSQL = "mysql"
     SLACK = "slack"
+    DISCOURSE = "discourse"
 
 
 class SpecialDataType(Enum):
@@ -63,3 +64,4 @@ class DataType(Enum):
     POSTGRES = IndirectDataType.POSTGRES.value
     MYSQL = IndirectDataType.MYSQL.value
     SLACK = IndirectDataType.SLACK.value
+    DISCOURSE = IndirectDataType.DISCOURSE.value
diff --git a/embedchain/utils.py b/embedchain/utils.py
index 811375f7..2abd2ad7 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -5,11 +5,66 @@
 import re
 import string
 from typing import Any
 
+from bs4 import BeautifulSoup
 from schema import Optional, Or, Schema
 
 from embedchain.models.data_type import DataType
 
 
+def parse_content(content, type):
+    implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
+    if type not in implemented:
+        raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")
+
+    soup = BeautifulSoup(content, type)
+    original_size = len(str(soup.get_text()))
+
+    tags_to_exclude = [
+        "nav",
+        "aside",
+        "form",
+        "header",
+        "noscript",
+        "svg",
+        "canvas",
+        "footer",
+        "script",
+        "style",
+    ]
+    for tag in soup(tags_to_exclude):
+        tag.decompose()
+
+    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
+    for id in ids_to_exclude:
+        tags = soup.find_all(id=id)
+        for tag in tags:
+            tag.decompose()
+
+    classes_to_exclude = [
+        "elementor-location-header",
+        "navbar-header",
+        "nav",
+        "header-sidebar-wrapper",
+        "blog-sidebar-wrapper",
+        "related-posts",
+    ]
+    for class_name in classes_to_exclude:
+        tags = soup.find_all(class_=class_name)
+        for tag in tags:
+            tag.decompose()
+
+    content = soup.get_text()
+    content = clean_string(content)
+
+    cleaned_size = len(content)
+    if original_size != 0:
+        logging.info(
+            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
+        )
+
+    return content
+
+
 def clean_string(text):
     """
     This function takes in a string and performs a series of text cleaning operations.
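The new `parse_content` helper strips navigation, sidebar, and script/style noise from HTML before cleaning the text. A minimal usage sketch (the HTML snippet is made up; `html.parser` is one of the accepted parser types):

```python
from embedchain.utils import parse_content

html = """
<html><body>
  <nav>Site navigation</nav>
  <div id="sidebar">Sidebar links</div>
  <p>The actual article text we want to keep.</p>
  <footer>Copyright</footer>
</body></html>
"""

# <nav>, id="sidebar", and <footer> are decomposed, so roughly only the
# paragraph text should survive the cleaning pass.
print(parse_content(html, "html.parser"))
```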
diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py
index 28787716..b50c3013 100644
--- a/tests/chunkers/test_chunkers.py
+++ b/tests/chunkers/test_chunkers.py
@@ -1,3 +1,4 @@
+from embedchain.chunkers.discourse import DiscourseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.gmail import GmailChunker
@@ -37,6 +38,7 @@ chunker_common_config = {
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py
index 9d2768b0..ba06a005 100644
--- a/tests/helper_classes/test_json_serializable.py
+++ b/tests/helper_classes/test_json_serializable.py
@@ -4,7 +4,8 @@
 from string import Template
 
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+                                                 register_deserializable)
 
 
 class TestJsonSerializable(unittest.TestCase):
diff --git a/tests/loaders/test_discourse.py b/tests/loaders/test_discourse.py
new file mode 100644
index 00000000..2949f0c9
--- /dev/null
+++ b/tests/loaders/test_discourse.py
@@ -0,0 +1,118 @@
+import pytest
+import requests
+
+from embedchain.loaders.discourse import DiscourseLoader
+
+
+@pytest.fixture
+def discourse_loader_config():
+    return {
+        "domain": "https://example.com",
+    }
+
+
+@pytest.fixture
+def discourse_loader(discourse_loader_config):
+    return DiscourseLoader(config=discourse_loader_config)
+
+
+def test_discourse_loader_init_with_valid_config():
+    config = {"domain": "https://example.com"}
+    loader = DiscourseLoader(config=config)
+    assert loader.domain == "https://example.com"
+
+
+def test_discourse_loader_init_with_missing_config():
+    with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
+        DiscourseLoader()
+
+
+def test_discourse_loader_init_with_missing_domain():
+    config = {"another_key": "value"}
+    with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
+        DiscourseLoader(config=config)
+
+
+def test_discourse_loader_check_query_with_valid_query(discourse_loader):
+    discourse_loader._check_query("sample query")
+
+
+def test_discourse_loader_check_query_with_empty_query(discourse_loader):
+    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+        discourse_loader._check_query("")
+
+
+def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
+    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+        discourse_loader._check_query(123)
+
+
+def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def json(self):
+                return {"raw": "Sample post content"}
+
+            def raise_for_status(self):
+                pass
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    post_data = discourse_loader._load_post(123)
+
+    assert post_data["content"] == "Sample post content"
+    assert "meta_data" in post_data
+
+
+def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def raise_for_status(self):
+                raise requests.exceptions.RequestException("Test error")
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    with pytest.raises(Exception, match="Test error"):
+        discourse_loader._load_post(123)
+
+
+def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def json(self):
+                return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
+
+            def raise_for_status(self):
+                pass
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    def mock_load_post(*args, **kwargs):
+        return {
+            "content": "Sample post content",
+            "meta_data": {
+                "url": "https://example.com/posts/123.json",
+                "created_at": "2021-01-01",
+                "username": "test_user",
+                "topic_slug": "test_topic",
+                "score": 10,
+            },
+        }
+
+    monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
+
+    data = discourse_loader.load_data("sample query")
+
+    assert len(data["data"]) == 3
+    assert data["data"][0]["content"] == "Sample post content"
+    assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
+    assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
+    assert data["data"][0]["meta_data"]["username"] == "test_user"
+    assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
+    assert data["data"][0]["meta_data"]["score"] == 10
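One caveat worth flagging on the loader itself: `load_data` interpolates the query into `search.json?q={query}` as-is, so queries containing spaces or reserved characters such as `#` or `&` should be URL-encoded by the caller. A standard-library sketch of one way to do that (the query string here is made up):

```python
from urllib.parse import quote_plus

from embedchain.loaders.discourse import DiscourseLoader

loader = DiscourseLoader(config={"domain": "https://community.openai.com"})

# Encode reserved characters before handing the query to the loader,
# since the search URL is built with a plain f-string.
raw_query = "api status #announcements"
data = loader.load_data(quote_plus(raw_query))
```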