Support for Audio Files (#1416)

This commit is contained in:
Dev Khant
2024-06-12 22:55:58 +05:30
committed by GitHub
parent 1bddd46ed2
commit 08b67b4a78
10 changed files with 211 additions and 1 deletions

View File

@@ -11,7 +11,7 @@ install:
install_all: install_all:
poetry install --all-extras poetry install --all-extras
poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7
install_es: install_es:
poetry install --extras elasticsearch poetry install --extras elasticsearch

View File

@@ -0,0 +1,25 @@
---
title: "🎤 Audio"
---
To use an audio file as a data source, just set `data_type` to `audio` and pass in the path of the audio file (local or hosted).
We use [Deepgram](https://developers.deepgram.com/docs/introduction) to transcribe the audio to text, and then use the generated text as the data source.
You will need a Deepgram API key, which is available [here](https://console.deepgram.com/signup?jump=keys), to use this feature.
### Without customization
```python
import os
from embedchain import App
os.environ["DEEPGRAM_API_KEY"] = "153xxx"
app = App()
app.add("introduction.wav", data_type="audio")
response = app.query("What is my name and how old am I?")
print(response)
# Answer: Your name is Dave and you are 21 years old.
```

View File

@@ -9,6 +9,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
<Card title="CSV file" href="/components/data-sources/csv"></Card> <Card title="CSV file" href="/components/data-sources/csv"></Card>
<Card title="JSON file" href="/components/data-sources/json"></Card> <Card title="JSON file" href="/components/data-sources/json"></Card>
<Card title="Text" href="/components/data-sources/text"></Card> <Card title="Text" href="/components/data-sources/text"></Card>
<Card title="Text File" href="/components/data-sources/text-file"></Card>
<Card title="Directory" href="/components/data-sources/directory"></Card> <Card title="Directory" href="/components/data-sources/directory"></Card>
<Card title="Web page" href="/components/data-sources/web-page"></Card> <Card title="Web page" href="/components/data-sources/web-page"></Card>
<Card title="Youtube Channel" href="/components/data-sources/youtube-channel"></Card> <Card title="Youtube Channel" href="/components/data-sources/youtube-channel"></Card>
@@ -33,6 +34,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
<Card title="Beehiiv" href="/components/data-sources/beehiiv"></Card> <Card title="Beehiiv" href="/components/data-sources/beehiiv"></Card>
<Card title="Dropbox" href="/components/data-sources/dropbox"></Card> <Card title="Dropbox" href="/components/data-sources/dropbox"></Card>
<Card title="Image" href="/components/data-sources/image"></Card> <Card title="Image" href="/components/data-sources/image"></Card>
<Card title="Audio" href="/components/data-sources/audio"></Card>
<Card title="Custom" href="/components/data-sources/custom"></Card> <Card title="Custom" href="/components/data-sources/custom"></Card>
</CardGroup> </CardGroup>

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class AudioChunker(BaseChunker):
    """Chunker registered for the ``audio`` data type.

    Wraps a ``RecursiveCharacterTextSplitter`` configured from a
    ``ChunkerConfig`` and hands it to ``BaseChunker``.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """Build the text splitter for this chunker.

        Args:
            config: Chunking parameters. When ``None``, a default of
                ``chunk_size=1000``, ``chunk_overlap=0`` and ``len`` as the
                length function is used.
        """
        # Only construct the default config when the caller supplied none.
        effective_config = (
            config
            if config is not None
            else ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        )
        splitter_kwargs = {
            "chunk_size": effective_config.chunk_size,
            "chunk_overlap": effective_config.chunk_overlap,
            "length_function": effective_config.length_function,
        }
        super().__init__(RecursiveCharacterTextSplitter(**splitter_kwargs))

View File

@@ -81,6 +81,7 @@ class DataFormatter(JSONSerializable):
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader", DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader", DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader", DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
} }
if data_type == DataType.CUSTOM or loader is not None: if data_type == DataType.CUSTOM or loader is not None:
@@ -129,6 +130,7 @@ class DataFormatter(JSONSerializable):
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker", DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker", DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker", DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
} }
if chunker is not None: if chunker is not None:

View File

@@ -0,0 +1,51 @@
import os
import hashlib
import validators
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
try:
from deepgram import DeepgramClient, PrerecordedOptions
except ImportError:
raise ImportError(
"Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
) from None
@register_deserializable
class AudioLoader(BaseLoader):
    """Loader that transcribes audio through Deepgram and returns the transcript."""

    def __init__(self):
        """Create the Deepgram client from the ``DEEPGRAM_API_KEY`` environment variable.

        Raises:
            ValueError: If ``DEEPGRAM_API_KEY`` is unset or empty.
        """
        api_key = os.environ.get("DEEPGRAM_API_KEY")
        if not api_key:
            raise ValueError("DEEPGRAM_API_KEY is not set")
        self.client = DeepgramClient(api_key)

    def load_data(self, url: str):
        """Load data from an audio file or URL.

        Args:
            url: Either a URL (as judged by ``validators.url``) or a local
                file path that is opened in binary mode.

        Returns:
            A dict with a ``doc_id`` (SHA-256 hex digest of the transcript
            concatenated with ``url``) and a single-element ``data`` list
            holding the transcript and ``{"url": url}`` metadata.
        """
        options = PrerecordedOptions(model="nova-2", smart_format=True)

        if validators.url(url):
            response = self.client.listen.prerecorded.v("1").transcribe_url({"url": url}, options)
        else:
            # The file handle must stay open while the transcription call runs.
            with open(url, "rb") as audio_file:
                response = self.client.listen.prerecorded.v("1").transcribe_file(
                    {"buffer": audio_file}, options
                )

        # Use the first alternative of the first channel as the transcript.
        first_alternative = response["results"]["channels"][0]["alternatives"][0]
        transcript = first_alternative["transcript"]

        digest = hashlib.sha256((transcript + url).encode()).hexdigest()
        entry = {"content": transcript, "meta_data": {"url": url}}
        return {"doc_id": digest, "data": [entry]}

View File

@@ -41,6 +41,7 @@ class IndirectDataType(Enum):
DROPBOX = "dropbox" DROPBOX = "dropbox"
TEXT_FILE = "text_file" TEXT_FILE = "text_file"
EXCEL_FILE = "excel_file" EXCEL_FILE = "excel_file"
AUDIO = "audio"
class SpecialDataType(Enum): class SpecialDataType(Enum):
@@ -81,3 +82,4 @@ class DataType(Enum):
DROPBOX = IndirectDataType.DROPBOX.value DROPBOX = IndirectDataType.DROPBOX.value
TEXT_FILE = IndirectDataType.TEXT_FILE.value TEXT_FILE = IndirectDataType.TEXT_FILE.value
EXCEL_FILE = IndirectDataType.EXCEL_FILE.value EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
AUDIO = IndirectDataType.AUDIO.value

View File

@@ -237,6 +237,12 @@ def detect_datatype(source: Any) -> DataType:
logger.debug(f"Source of `{formatted_source}` detected as `docx`.") logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX return DataType.DOCX
if url.path.endswith(
(".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
):
logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
return DataType.AUDIO
if url.path.endswith(".yaml"): if url.path.endswith(".yaml"):
try: try:
response = requests.get(source) response = requests.get(source)

View File

@@ -19,6 +19,7 @@ from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.audio import AudioChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len) chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
@@ -45,6 +46,7 @@ chunker_common_config = {
CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len}, CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
} }

View File

@@ -0,0 +1,98 @@
import os
import sys
import hashlib
import pytest
from unittest.mock import mock_open, patch
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
from deepgram import PrerecordedOptions
from embedchain.loaders.audio import AudioLoader
@pytest.fixture
def setup_audio_loader(mocker):
    """Yield ``(loader, mock_client)``: an ``AudioLoader`` backed by a mocked Deepgram client.

    The Deepgram client class is replaced with a mock and ``DEEPGRAM_API_KEY``
    is set for the duration of the test. Both patches are undone by ``mocker``
    on teardown, so any pre-existing value of the variable is restored.
    """
    # Patch the name where AudioLoader looks it up: the loader module binds
    # DeepgramClient via `from deepgram import ...` at import time, so patching
    # `deepgram.DeepgramClient` would leave that reference untouched.
    mock_client_cls = mocker.patch("embedchain.loaders.audio.DeepgramClient")
    mock_client = mocker.MagicMock()
    mock_client_cls.return_value = mock_client

    # patch.dict restores the previous environment instead of deleting the key
    # unconditionally, which would clobber a value set outside the test.
    mocker.patch.dict(os.environ, {"DEEPGRAM_API_KEY": "test_key"})

    loader = AudioLoader()
    # Set the client explicitly so the tests do not depend on the class patch alone.
    loader.client = mock_client

    yield loader, mock_client
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_initialization(setup_audio_loader):
    """The fixture must provide a constructed AudioLoader."""
    audio_loader = setup_audio_loader[0]
    assert audio_loader is not None
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_url(setup_audio_loader):
    """A URL source is sent to ``transcribe_url`` and its transcript is returned."""
    audio_loader, client = setup_audio_loader
    url = "https://example.com/audio.mp3"
    transcript = "This is a test audio transcript."

    # Minimal response shape: first channel, first alternative, transcript.
    alternative = {"transcript": transcript}
    channel = {"alternatives": [alternative]}
    versioned_api = client.listen.prerecorded.v.return_value
    versioned_api.transcribe_url.return_value = {"results": {"channels": [channel]}}

    result = audio_loader.load_data(url)

    expected_doc_id = hashlib.sha256((transcript + url).encode()).hexdigest()
    expected_entry = {"content": transcript, "meta_data": {"url": url}}
    assert result == {"doc_id": expected_doc_id, "data": [expected_entry]}

    client.listen.prerecorded.v.assert_called_once_with("1")
    versioned_api.transcribe_url.assert_called_once_with(
        {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
    )
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_file(setup_audio_loader):
    """A local path is opened and its handle is sent to ``transcribe_file``."""
    audio_loader, client = setup_audio_loader
    file_path = "local_audio.mp3"
    transcript = "This is a test audio transcript."

    # Minimal response shape: first channel, first alternative, transcript.
    alternative = {"transcript": transcript}
    channel = {"alternatives": [alternative]}
    versioned_api = client.listen.prerecorded.v.return_value
    versioned_api.transcribe_file.return_value = {"results": {"channels": [channel]}}

    # Mock the file reading functionality
    opener = mock_open(read_data=b"some data")
    with patch("builtins.open", opener) as mock_file:
        result = audio_loader.load_data(file_path)

    expected_doc_id = hashlib.sha256((transcript + file_path).encode()).hexdigest()
    expected_entry = {"content": transcript, "meta_data": {"url": file_path}}
    assert result == {"doc_id": expected_doc_id, "data": [expected_entry]}

    client.listen.prerecorded.v.assert_called_once_with("1")
    versioned_api.transcribe_file.assert_called_once_with(
        {"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
    )