[Feature] Discourse Loader (#948)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Authored by Deven Patel on 2023-11-13 16:39:11 -08:00; committed by GitHub
parent 919cc74e94
commit 95c0d47236
12 changed files with 324 additions and 4 deletions


@@ -0,0 +1,44 @@
---
title: '🗨️ Discourse'
---
You can now easily load data from your community built with [Discourse](https://discourse.org/).
## Example
1. Set up the Discourse loader with your community URL.
```python
from embedchain.loaders.discourse import DiscourseLoader

discourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
```
2. Once the loader is set up, create an app and load data with it:
```python
import os

from embedchain.pipeline import Pipeline as App

os.environ["OPENAI_API_KEY"] = "sk-xxx"

app = App()
app.add("openai", data_type="discourse", loader=discourse_loader)

question = "Where can I find the OpenAI API status page?"
app.query(question)
# Answer: You can find the OpenAI API status page at https://status.openai.com/.
```
NOTE: The `add` function accepts any valid Discourse search query as its source and loads the matching posts. Refer to the [Discourse API docs](https://docs.discourse.org/#tag/Search) to learn more about search queries.
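For instance, the query can use Discourse search operators to narrow what gets loaded. A minimal sketch (the search term and the operator values below are illustrative):
```python
# `after:` and `order:` are Discourse search operators; the rest of the
# query is an ordinary search term. Only the matching posts are loaded.
app.add("status page after:2023-01-01 order:latest", data_type="discourse", loader=discourse_loader)
```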
3. We automatically create a chunker for your Discourse data. However, if you wish to provide your own chunker class, here is how you can do that:
```python
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.config.add_config import ChunkerConfig

discourse_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
discourse_chunker = DiscourseChunker(config=discourse_chunker_config)

app.add("openai", data_type="discourse", loader=discourse_loader, chunker=discourse_chunker)
```


@@ -18,11 +18,12 @@ Embedchain comes with built-in support for various data sources. We handle the c
<Card title="🌐📄 web page" href="/data-sources/web-page"></Card> <Card title="🌐📄 web page" href="/data-sources/web-page"></Card>
<Card title="🧾 xml" href="/data-sources/xml"></Card> <Card title="🧾 xml" href="/data-sources/xml"></Card>
<Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card> <Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card>
<Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card> <Card title="📺 youtube video" href="/data-sources/youtube-video"></Card>
<Card title="📬 Gmail" href="/data-sources/gmail"></Card> <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
<Card title="🐘 Postgres" href="/data-sources/postgres"></Card> <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
<Card title="🐬 MySQL" href="/data-sources/mysql"></Card> <Card title="🐬 MySQL" href="/data-sources/mysql"></Card>
<Card title="🤖 Slack" href="/data-sources/slack"></Card> <Card title="🤖 Slack" href="/data-sources/slack"></Card>
<Card title="🗨️ Discourse" href="/data-sources/discourse"></Card>
</CardGroup> </CardGroup>
<br/ > <br/ >


@@ -1,5 +1,5 @@
 ---
-title: '🎥📺 Youtube video'
+title: '📺 Youtube video'
 ---


@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable


@register_deserializable
class DiscourseChunker(BaseChunker):
    """Chunker for discourse."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)


@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
                 DataType.POSTGRES,
                 DataType.MYSQL,
                 DataType.SLACK,
+                DataType.DISCOURSE,
             ]
         )
@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
             DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
             DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
             DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
+            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
         }
         if data_type in chunker_classes:


@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB


@@ -0,0 +1,72 @@
import concurrent.futures
import hashlib
import logging
from typing import Any, Dict, Optional

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string


class DiscourseLoader(BaseLoader):
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

        self.domain = config.get("domain")
        if not self.domain:
            raise ValueError(
                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _check_query(self, query):
        if not query or not isinstance(query, str):
            raise ValueError(
                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
            )

    def _load_post(self, post_id):
        post_url = f"{self.domain}/posts/{post_id}.json"
        response = requests.get(post_url)
        response.raise_for_status()
        response_data = response.json()
        post_contents = clean_string(response_data.get("raw"))
        meta_data = {
            "url": post_url,
            "created_at": response_data.get("created_at", ""),
            "username": response_data.get("username", ""),
            "topic_slug": response_data.get("topic_slug", ""),
            "score": response_data.get("score", ""),
        }
        data = {
            "content": post_contents,
            "meta_data": meta_data,
        }
        return data

    def load_data(self, query):
        self._check_query(query)
        data = []
        data_contents = []
        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
        search_url = f"{self.domain}/search.json?q={query}"
        response = requests.get(search_url)
        response.raise_for_status()
        response_data = response.json()
        post_ids = response_data.get("grouped_search_result").get("post_ids")
        # Fetch all matching posts concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
            for future in concurrent.futures.as_completed(future_to_post_id):
                post_id = future_to_post_id[future]
                try:
                    post_data = future.result()
                    data.append(post_data)
                    data_contents.append(post_data.get("content"))
                except Exception as e:
                    logging.error(f"Failed to load post {post_id}: {e}")
        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
        response_data = {"doc_id": doc_id, "data": data}
        return response_data
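A minimal sketch of exercising the loader directly, assuming a reachable Discourse instance (the domain and query are illustrative):
```python
from embedchain.loaders.discourse import DiscourseLoader

loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
result = loader.load_data("api status")

# load_data returns a dict with a content-derived doc_id and one entry per
# matching post, each carrying `content` plus `meta_data`.
print(result["doc_id"])
for post in result["data"]:
    print(post["meta_data"]["url"], post["meta_data"]["username"])
```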


@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
POSTGRES = "postgres" POSTGRES = "postgres"
MYSQL = "mysql" MYSQL = "mysql"
SLACK = "slack" SLACK = "slack"
DISCOURSE = "discourse"
class SpecialDataType(Enum): class SpecialDataType(Enum):
@@ -63,3 +64,4 @@ class DataType(Enum):
     POSTGRES = IndirectDataType.POSTGRES.value
     MYSQL = IndirectDataType.MYSQL.value
     SLACK = IndirectDataType.SLACK.value
+    DISCOURSE = IndirectDataType.DISCOURSE.value


@@ -5,11 +5,66 @@ import re
 import string
 from typing import Any

+from bs4 import BeautifulSoup
 from schema import Optional, Or, Schema

 from embedchain.models.data_type import DataType
+def parse_content(content, type):
+    implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
+    if type not in implemented:
+        raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")
+
+    soup = BeautifulSoup(content, type)
+    original_size = len(str(soup.get_text()))
+
+    tags_to_exclude = [
+        "nav",
+        "aside",
+        "form",
+        "header",
+        "noscript",
+        "svg",
+        "canvas",
+        "footer",
+        "script",
+        "style",
+    ]
+    for tag in soup(tags_to_exclude):
+        tag.decompose()
+
+    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
+    for id in ids_to_exclude:
+        tags = soup.find_all(id=id)
+        for tag in tags:
+            tag.decompose()
+
+    classes_to_exclude = [
+        "elementor-location-header",
+        "navbar-header",
+        "nav",
+        "header-sidebar-wrapper",
+        "blog-sidebar-wrapper",
+        "related-posts",
+    ]
+    for class_name in classes_to_exclude:
+        tags = soup.find_all(class_=class_name)
+        for tag in tags:
+            tag.decompose()
+
+    content = soup.get_text()
+    content = clean_string(content)
+
+    cleaned_size = len(content)
+    if original_size != 0:
+        logging.info(
+            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
+        )
+
+    return content
+
+
 def clean_string(text):
     """
     This function takes in a string and performs a series of text cleaning operations.
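A small sketch of what `parse_content` does to a noisy page (the HTML input is illustrative):
```python
from embedchain.utils import parse_content

html = (
    "<html><body>"
    "<nav>Home | About</nav>"
    "<p>Actual article text worth keeping.</p>"
    "<footer>(c) 2023</footer>"
    "</body></html>"
)

# Nav and footer chrome is decomposed; only the readable text survives,
# roughly: "Actual article text worth keeping."
print(parse_content(html, "html.parser"))
```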


@@ -1,3 +1,4 @@
+from embedchain.chunkers.discourse import DiscourseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.gmail import GmailChunker
@@ -37,6 +38,7 @@ chunker_common_config = {
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }


@@ -4,7 +4,8 @@ from string import Template
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+                                                 register_deserializable)


 class TestJsonSerializable(unittest.TestCase):


@@ -0,0 +1,118 @@
import pytest
import requests

from embedchain.loaders.discourse import DiscourseLoader


@pytest.fixture
def discourse_loader_config():
    return {
        "domain": "https://example.com",
    }


@pytest.fixture
def discourse_loader(discourse_loader_config):
    return DiscourseLoader(config=discourse_loader_config)


def test_discourse_loader_init_with_valid_config():
    config = {"domain": "https://example.com"}
    loader = DiscourseLoader(config=config)
    assert loader.domain == "https://example.com"


def test_discourse_loader_init_with_missing_config():
    with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
        DiscourseLoader()


def test_discourse_loader_init_with_missing_domain():
    config = {"another_key": "value"}
    with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
        DiscourseLoader(config=config)


def test_discourse_loader_check_query_with_valid_query(discourse_loader):
    discourse_loader._check_query("sample query")


def test_discourse_loader_check_query_with_empty_query(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query("")


def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query(123)


def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"raw": "Sample post content"}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    post_data = discourse_loader._load_post(123)

    assert post_data["content"] == "Sample post content"
    assert "meta_data" in post_data


def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def raise_for_status(self):
                raise requests.exceptions.RequestException("Test error")

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    with pytest.raises(Exception, match="Test error"):
        discourse_loader._load_post(123)


def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"grouped_search_result": {"post_ids": [123, 456, 789]}}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    def mock_load_post(*args, **kwargs):
        return {
            "content": "Sample post content",
            "meta_data": {
                "url": "https://example.com/posts/123.json",
                "created_at": "2021-01-01",
                "username": "test_user",
                "topic_slug": "test_topic",
                "score": 10,
            },
        }

    monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)

    data = discourse_loader.load_data("sample query")

    assert len(data["data"]) == 3
    assert data["data"][0]["content"] == "Sample post content"
    assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
    assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
    assert data["data"][0]["meta_data"]["username"] == "test_user"
    assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
    assert data["data"][0]["meta_data"]["score"] == 10