[Feature] Add Slack Loader (#932)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-13 13:06:01 -08:00
committed by GitHub
parent 23522b7b55
commit 539286aafd
10 changed files with 248 additions and 12 deletions

View File

@@ -21,6 +21,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
<Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card> <Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
<Card title="📬 Gmail" href="/data-sources/gmail"></Card> <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
<Card title="🐘 Postgres" href="/data-sources/postgres"></Card> <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
<Card title="🤖 Slack" href="/data-sources/slack"></Card>
</CardGroup> </CardGroup>
<br/ > <br/ >

View File

@@ -0,0 +1,54 @@
---
title: '🤖 Slack'
---
## Pre-requisite
- Download required packages by running `pip install --upgrade "embedchain[slack]"`.
- Configure your Slack user token as the environment variable `SLACK_USER_TOKEN`.
- Find your user token on your [Slack Account](https://api.slack.com/authentication/token-types)
- Make sure your slack user token includes [search](https://api.slack.com/scopes/search:read) scope.
## Example
1. Set up the Slack loader by configuring the Slack WebClient.
```Python
from embedchain.loaders.slack import SlackLoader
os.environ["SLACK_USER_TOKEN"] = "xoxp-*"
loader = SlackLoader()
"""
config = {
'base_url': slack_app_url,
'headers': web_headers,
'team_id': slack_team_id,
}
loader = SlackLoader(config)
"""
```
NOTE: you can also pass the `config` with `base_url`, `headers`, `team_id` to set up your SlackLoader.
2. Once you have set up the loader, you can create an app and load data using the Slack loader created above.
```Python
import os
from embedchain.pipeline import Pipeline as App
app = App()
app.add("in:random", data_type="slack", loader=loader)
question = "Which bots are available in the slack workspace's random channel?"
# Answer: The available bot in the slack workspace's random channel is the Embedchain bot.
```
3. We automatically create a chunker to chunk your Slack data. However, if you wish to provide your own chunker class, here is how you can do that:
```Python
from embedchain.chunkers.slack import SlackChunker
from embedchain.config.add_config import ChunkerConfig
slack_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
slack_chunker = SlackChunker(config=slack_chunker_config)
app.add("in:random", data_type="slack", loader=loader, chunker=slack_chunker)
```

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class SlackChunker(BaseChunker):
    """Chunker for slack data.

    Splits text loaded from Slack into chunks with a
    RecursiveCharacterTextSplitter built from the given ChunkerConfig.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        # Default to 1000-char chunks with no overlap when no config is given,
        # matching the defaults used for the other loaders' chunkers.
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)

View File

@@ -68,6 +68,7 @@ class DataFormatter(JSONSerializable):
custom_loaders = set( custom_loaders = set(
[ [
DataType.POSTGRES, DataType.POSTGRES,
DataType.SLACK,
] ]
) )
@@ -106,6 +107,7 @@ class DataFormatter(JSONSerializable):
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker", DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker", DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker", DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
} }
if data_type in chunker_classes: if data_type in chunker_classes:

View File

@@ -40,9 +40,7 @@ class PostgresLoader(BaseLoader):
def _check_query(self, query): def _check_query(self, query):
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError( raise ValueError(
f"Invalid postgres query: {query}", f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501
"Provide the valid source to add from postgres, \
make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
) )
def load_data(self, query): def load_data(self, query):

108
embedchain/loaders/slack.py Normal file
View File

@@ -0,0 +1,108 @@
import hashlib
import logging
import os
import ssl
from typing import Any, Dict, Optional
import certifi
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
SLACK_API_BASE_URL = "https://www.slack.com/api/"


class SlackLoader(BaseLoader):
    """Loader that fetches messages from Slack via the search API.

    Requires the ``SLACK_USER_TOKEN`` environment variable and the
    ``slack_sdk`` extra (``pip install --upgrade "embedchain[slack]"``).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create the loader and build the underlying Slack WebClient.

        :param config: optional dict supporting ``base_url``, ``headers``
            and ``team_id``; defaults to the public Slack API base URL.
        """
        super().__init__()
        if config is not None:
            self.config = config
        else:
            self.config = {"base_url": SLACK_API_BASE_URL}
        self.client = None
        self._setup_loader(self.config)

    def _setup_loader(self, config: Dict[str, Any]):
        """Validate prerequisites and instantiate the Slack WebClient.

        :raises ImportError: if ``slack_sdk`` is not installed.
        :raises ValueError: if ``SLACK_USER_TOKEN`` is not set.
        """
        try:
            from slack_sdk import WebClient
        except ImportError as e:
            raise ImportError(
                "Slack loader requires extra dependencies. \
                Install with `pip install --upgrade embedchain[slack]`"
            ) from e

        if os.getenv("SLACK_USER_TOKEN") is None:
            raise ValueError(
                "SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            )

        logging.info(f"Creating Slack Loader with config: {config}")
        # get slack client config params
        slack_user_token = os.getenv("SLACK_USER_TOKEN")
        ssl_cert = ssl.create_default_context(cafile=certifi.where())
        base_url = config.get("base_url", SLACK_API_BASE_URL)
        headers = config.get("headers")
        # `team_id` is only required for Org-Wide Apps.
        team_id = config.get("team_id")

        self.client = WebClient(
            token=slack_user_token,
            base_url=base_url,
            ssl=ssl_cert,
            headers=headers,
            team_id=team_id,
        )
        logging.info("Slack Loader setup successful!")

    def _check_query(self, query):
        """Raise ValueError unless the search query is a string."""
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            )

    def load_data(self, query):
        """Search Slack for messages matching ``query``.

        :param query: a Slack search query string (e.g. ``"in:random"``).
        :return: dict with a ``doc_id`` (md5 over query + contents) and
            ``data``, a list of ``{"content": ..., "meta_data": ...}`` items.
        :raises ValueError: if the query is invalid or the search fails.
        """
        self._check_query(query)
        try:
            data = []
            data_content = []
            logging.info(f"Searching slack conversations for query: {query}")
            results = self.client.search_messages(
                query=query,
                sort="timestamp",
                sort_dir="desc",
                count=1000,
            )
            # search.messages nests results under "messages"; guard against a
            # missing key so an empty response yields no matches, not a crash.
            messages = results.get("messages") or {}
            # NOTE(review): "total" lives inside "messages" in the Slack
            # search.messages response, not on the top-level payload.
            num_message = messages.get("total")
            logging.info(f"Found {num_message} messages for query: {query}")

            matches = messages.get("matches", [])
            message_meta_data_keys = ["channel", "iid", "team", "ts", "type", "user", "username"]
            for message in matches:
                url = message.get("permalink")
                text = message.get("text")
                content = clean_string(text)
                # Extract the metadata fields from the message itself. The
                # previous `message.fromkeys(...)` call built a dict with every
                # value set to "" (dict.fromkeys ignores the instance's values),
                # silently dropping channel/user/ts metadata.
                meta_data = {key: message.get(key, "") for key in message_meta_data_keys}
                meta_data.update({"url": url})
                data.append(
                    {
                        "content": content,
                        "meta_data": meta_data,
                    }
                )
                data_content.append(content)
            doc_id = hashlib.md5((query + ", ".join(data_content)).encode()).hexdigest()
            return {
                "doc_id": doc_id,
                "data": data,
            }
        except Exception as e:
            logging.warning(f"Error in loading slack data: {e}")
            raise ValueError(
                f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            ) from e

View File

@@ -30,6 +30,7 @@ class IndirectDataType(Enum):
OPENAPI = "openapi" OPENAPI = "openapi"
GMAIL = "gmail" GMAIL = "gmail"
POSTGRES = "postgres" POSTGRES = "postgres"
SLACK = "slack"
class SpecialDataType(Enum): class SpecialDataType(Enum):
@@ -59,3 +60,4 @@ class DataType(Enum):
OPENAPI = IndirectDataType.OPENAPI.value OPENAPI = IndirectDataType.OPENAPI.value
GMAIL = IndirectDataType.GMAIL.value GMAIL = IndirectDataType.GMAIL.value
POSTGRES = IndirectDataType.POSTGRES.value POSTGRES = IndirectDataType.POSTGRES.value
SLACK = IndirectDataType.SLACK.value

View File

@@ -83,7 +83,7 @@ async def create_app_using_default_config(app_id: str, config: UploadFile = None
return DefaultResponse(response=f"App created successfully. App ID: {app_id}") return DefaultResponse(response=f"App created successfully. App ID: {app_id}")
except Exception as e: except Exception as e:
logging.warn(str(e)) logging.warning(str(e))
raise HTTPException(detail=f"Error creating app: {str(e)}", status_code=400) raise HTTPException(detail=f"Error creating app: {str(e)}", status_code=400)
@@ -113,13 +113,13 @@ async def get_datasources_associated_with_app_id(app_id: str, db: Session = Depe
response = app.get_data_sources() response = app.get_data_sources()
return {"results": response} return {"results": response}
except ValueError as ve: except ValueError as ve:
logging.warn(str(ve)) logging.warning(str(ve))
raise HTTPException( raise HTTPException(
detail=generate_error_message_for_api_keys(ve), detail=generate_error_message_for_api_keys(ve),
status_code=400, status_code=400,
) )
except Exception as e: except Exception as e:
logging.warn(str(e)) logging.warning(str(e))
raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400) raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -152,13 +152,13 @@ async def add_datasource_to_an_app(body: SourceApp, app_id: str, db: Session = D
response = app.add(source=body.source, data_type=body.data_type) response = app.add(source=body.source, data_type=body.data_type)
return DefaultResponse(response=response) return DefaultResponse(response=response)
except ValueError as ve: except ValueError as ve:
logging.warn(str(ve)) logging.warning(str(ve))
raise HTTPException( raise HTTPException(
detail=generate_error_message_for_api_keys(ve), detail=generate_error_message_for_api_keys(ve),
status_code=400, status_code=400,
) )
except Exception as e: except Exception as e:
logging.warn(str(e)) logging.warning(str(e))
raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400) raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -190,13 +190,13 @@ async def query_an_app(body: QueryApp, app_id: str, db: Session = Depends(get_db
response = app.query(body.query) response = app.query(body.query)
return DefaultResponse(response=response) return DefaultResponse(response=response)
except ValueError as ve: except ValueError as ve:
logging.warn(str(ve)) logging.warning(str(ve))
raise HTTPException( raise HTTPException(
detail=generate_error_message_for_api_keys(ve), detail=generate_error_message_for_api_keys(ve),
status_code=400, status_code=400,
) )
except Exception as e: except Exception as e:
logging.warn(str(e)) logging.warning(str(e))
raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400) raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
@@ -273,13 +273,13 @@ async def deploy_app(body: DeployAppRequest, app_id: str, db: Session = Depends(
app.deploy() app.deploy()
return DefaultResponse(response="App deployed successfully.") return DefaultResponse(response="App deployed successfully.")
except ValueError as ve: except ValueError as ve:
logging.warn(str(ve)) logging.warning(str(ve))
raise HTTPException( raise HTTPException(
detail=generate_error_message_for_api_keys(ve), detail=generate_error_message_for_api_keys(ve),
status_code=400, status_code=400,
) )
except Exception as e: except Exception as e:
logging.warn(str(e)) logging.warning(str(e))
raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400) raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)

View File

@@ -9,6 +9,7 @@ from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.postgres import PostgresChunker from embedchain.chunkers.postgres import PostgresChunker
from embedchain.chunkers.qna_pair import QnaPairChunker from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.slack import SlackChunker
from embedchain.chunkers.table import TableChunker from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.web_page import WebPageChunker
@@ -35,6 +36,7 @@ chunker_common_config = {
OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
} }

View File

@@ -0,0 +1,47 @@
import pytest
from embedchain.loaders.slack import SlackLoader
@pytest.fixture
def slack_loader(mocker, monkeypatch):
    """Provide a SlackLoader whose external dependencies are mocked out."""
    # Stub everything that would touch the network or the filesystem so that
    # the loader can be constructed without a real Slack client.
    for target in ("slack_sdk.WebClient", "ssl.create_default_context", "certifi.where"):
        mocker.patch(target)
    monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")
    return SlackLoader()
def test_slack_loader_initialization(slack_loader):
    """A default-constructed loader has a client and the default base-url config."""
    expected_config = {"base_url": "https://www.slack.com/api/"}
    assert slack_loader.client is not None
    assert slack_loader.config == expected_config
def test_slack_loader_setup_loader(slack_loader):
    """Re-running setup with a custom base_url still produces a client."""
    custom_config = {"base_url": "https://custom.slack.api/"}
    slack_loader._setup_loader(custom_config)
    assert slack_loader.client is not None
def test_slack_loader_check_query(slack_loader):
    """_check_query accepts string queries and rejects everything else."""
    # A valid string query must pass without raising.
    slack_loader._check_query("test_query")
    # A non-string query must be rejected.
    with pytest.raises(ValueError):
        slack_loader._check_query(123)
def test_slack_loader_load_data(slack_loader, mocker):
    """load_data returns a dict carrying both a doc_id and the data list."""
    mocker.patch.object(slack_loader.client, "search_messages", return_value={"messages": {}})
    result = slack_loader.load_data("in:random")
    for expected_key in ("doc_id", "data"):
        assert expected_key in result