Support for Audio Files (#1416)
This commit is contained in:
2
Makefile
2
Makefile
@@ -11,7 +11,7 @@ install:
|
|||||||
|
|
||||||
install_all:
|
install_all:
|
||||||
poetry install --all-extras
|
poetry install --all-extras
|
||||||
poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama
|
poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7
|
||||||
|
|
||||||
install_es:
|
install_es:
|
||||||
poetry install --extras elasticsearch
|
poetry install --extras elasticsearch
|
||||||
|
|||||||
25
docs/components/data-sources/audio.mdx
Normal file
25
docs/components/data-sources/audio.mdx
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
---
|
||||||
|
title: "🎤 Audio"
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
To use an audio as data source, just add `data_type` as `audio` and pass in the path of the audio (local or hosted).
|
||||||
|
|
||||||
|
We use [Deepgram](https://developers.deepgram.com/docs/introduction) to transcribe the audiot to text, and then use the generated text as the data source.
|
||||||
|
|
||||||
|
You would require an Deepgram API key which is available [here](https://console.deepgram.com/signup?jump=keys) to use this feature.
|
||||||
|
|
||||||
|
### Without customization
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
os.environ["DEEPGRAM_API_KEY"] = "153xxx"
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.add("introduction.wav", data_type="audio")
|
||||||
|
response = app.query("What is my name and how old am I?")
|
||||||
|
print(response)
|
||||||
|
# Answer: Your name is Dave and you are 21 years old.
|
||||||
|
```
|
||||||
@@ -9,6 +9,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
|
|||||||
<Card title="CSV file" href="/components/data-sources/csv"></Card>
|
<Card title="CSV file" href="/components/data-sources/csv"></Card>
|
||||||
<Card title="JSON file" href="/components/data-sources/json"></Card>
|
<Card title="JSON file" href="/components/data-sources/json"></Card>
|
||||||
<Card title="Text" href="/components/data-sources/text"></Card>
|
<Card title="Text" href="/components/data-sources/text"></Card>
|
||||||
|
<Card title="Text File" href="/components/data-sources/text-file"></Card>
|
||||||
<Card title="Directory" href="/components/data-sources/directory"></Card>
|
<Card title="Directory" href="/components/data-sources/directory"></Card>
|
||||||
<Card title="Web page" href="/components/data-sources/web-page"></Card>
|
<Card title="Web page" href="/components/data-sources/web-page"></Card>
|
||||||
<Card title="Youtube Channel" href="/components/data-sources/youtube-channel"></Card>
|
<Card title="Youtube Channel" href="/components/data-sources/youtube-channel"></Card>
|
||||||
@@ -33,6 +34,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
|
|||||||
<Card title="Beehiiv" href="/components/data-sources/beehiiv"></Card>
|
<Card title="Beehiiv" href="/components/data-sources/beehiiv"></Card>
|
||||||
<Card title="Dropbox" href="/components/data-sources/dropbox"></Card>
|
<Card title="Dropbox" href="/components/data-sources/dropbox"></Card>
|
||||||
<Card title="Image" href="/components/data-sources/image"></Card>
|
<Card title="Image" href="/components/data-sources/image"></Card>
|
||||||
|
<Card title="Audio" href="/components/data-sources/audio"></Card>
|
||||||
<Card title="Custom" href="/components/data-sources/custom"></Card>
|
<Card title="Custom" href="/components/data-sources/custom"></Card>
|
||||||
</CardGroup>
|
</CardGroup>
|
||||||
|
|
||||||
|
|||||||
22
embedchain/chunkers/audio.py
Normal file
22
embedchain/chunkers/audio.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
|
from embedchain.config.add_config import ChunkerConfig
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class AudioChunker(BaseChunker):
|
||||||
|
"""Chunker for audio."""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
|
if config is None:
|
||||||
|
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
|
super().__init__(text_splitter)
|
||||||
@@ -81,6 +81,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
|
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
|
||||||
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
|
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
|
||||||
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
|
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
|
||||||
|
DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
|
||||||
}
|
}
|
||||||
|
|
||||||
if data_type == DataType.CUSTOM or loader is not None:
|
if data_type == DataType.CUSTOM or loader is not None:
|
||||||
@@ -129,6 +130,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
|
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||||
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
|
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||||
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
|
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
|
||||||
|
DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
|
||||||
}
|
}
|
||||||
|
|
||||||
if chunker is not None:
|
if chunker is not None:
|
||||||
|
|||||||
51
embedchain/loaders/audio.py
Normal file
51
embedchain/loaders/audio.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import validators
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deepgram import DeepgramClient, PrerecordedOptions
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class AudioLoader(BaseLoader):
|
||||||
|
def __init__(self):
|
||||||
|
if not os.environ.get("DEEPGRAM_API_KEY"):
|
||||||
|
raise ValueError("DEEPGRAM_API_KEY is not set")
|
||||||
|
|
||||||
|
DG_KEY = os.environ.get("DEEPGRAM_API_KEY")
|
||||||
|
self.client = DeepgramClient(DG_KEY)
|
||||||
|
|
||||||
|
def load_data(self, url: str):
|
||||||
|
"""Load data from a audio file or URL."""
|
||||||
|
|
||||||
|
options = PrerecordedOptions(
|
||||||
|
model="nova-2",
|
||||||
|
smart_format=True,
|
||||||
|
)
|
||||||
|
if validators.url(url):
|
||||||
|
source = {"url": url}
|
||||||
|
response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
|
||||||
|
else:
|
||||||
|
with open(url, "rb") as audio:
|
||||||
|
source = {"buffer": audio}
|
||||||
|
response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
|
||||||
|
content = response["results"]["channels"][0]["alternatives"][0]["transcript"]
|
||||||
|
|
||||||
|
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
|
||||||
|
metadata = {"url": url}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"content": content,
|
||||||
|
"meta_data": metadata,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -41,6 +41,7 @@ class IndirectDataType(Enum):
|
|||||||
DROPBOX = "dropbox"
|
DROPBOX = "dropbox"
|
||||||
TEXT_FILE = "text_file"
|
TEXT_FILE = "text_file"
|
||||||
EXCEL_FILE = "excel_file"
|
EXCEL_FILE = "excel_file"
|
||||||
|
AUDIO = "audio"
|
||||||
|
|
||||||
|
|
||||||
class SpecialDataType(Enum):
|
class SpecialDataType(Enum):
|
||||||
@@ -81,3 +82,4 @@ class DataType(Enum):
|
|||||||
DROPBOX = IndirectDataType.DROPBOX.value
|
DROPBOX = IndirectDataType.DROPBOX.value
|
||||||
TEXT_FILE = IndirectDataType.TEXT_FILE.value
|
TEXT_FILE = IndirectDataType.TEXT_FILE.value
|
||||||
EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
|
EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
|
||||||
|
AUDIO = IndirectDataType.AUDIO.value
|
||||||
|
|||||||
@@ -237,6 +237,12 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
|
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
|
||||||
return DataType.DOCX
|
return DataType.DOCX
|
||||||
|
|
||||||
|
if url.path.endswith(
|
||||||
|
(".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
|
||||||
|
):
|
||||||
|
logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
|
||||||
|
return DataType.AUDIO
|
||||||
|
|
||||||
if url.path.endswith(".yaml"):
|
if url.path.endswith(".yaml"):
|
||||||
try:
|
try:
|
||||||
response = requests.get(source)
|
response = requests.get(source)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from embedchain.chunkers.text import TextChunker
|
|||||||
from embedchain.chunkers.web_page import WebPageChunker
|
from embedchain.chunkers.web_page import WebPageChunker
|
||||||
from embedchain.chunkers.xml import XmlChunker
|
from embedchain.chunkers.xml import XmlChunker
|
||||||
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
||||||
|
from embedchain.chunkers.audio import AudioChunker
|
||||||
from embedchain.config.add_config import ChunkerConfig
|
from embedchain.config.add_config import ChunkerConfig
|
||||||
|
|
||||||
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
||||||
@@ -45,6 +46,7 @@ chunker_common_config = {
|
|||||||
CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
|
CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
|
||||||
GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
|
GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
|
||||||
ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
|
ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
|
||||||
|
AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
98
tests/loaders/test_audio.py
Normal file
98
tests/loaders/test_audio.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import hashlib
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import mock_open, patch
|
||||||
|
|
||||||
|
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
|
||||||
|
from deepgram import PrerecordedOptions
|
||||||
|
from embedchain.loaders.audio import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def setup_audio_loader(mocker):
|
||||||
|
mock_dropbox = mocker.patch("deepgram.DeepgramClient")
|
||||||
|
mock_dbx = mocker.MagicMock()
|
||||||
|
mock_dropbox.return_value = mock_dbx
|
||||||
|
|
||||||
|
os.environ["DEEPGRAM_API_KEY"] = "test_key"
|
||||||
|
loader = AudioLoader()
|
||||||
|
loader.client = mock_dbx
|
||||||
|
|
||||||
|
yield loader, mock_dbx
|
||||||
|
|
||||||
|
if "DEEPGRAM_API_KEY" in os.environ:
|
||||||
|
del os.environ["DEEPGRAM_API_KEY"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||||
|
) # as `match` statement was introduced in python 3.10
|
||||||
|
def test_initialization(setup_audio_loader):
|
||||||
|
"""Test initialization of AudioLoader."""
|
||||||
|
loader, _ = setup_audio_loader
|
||||||
|
assert loader is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||||
|
) # as `match` statement was introduced in python 3.10
|
||||||
|
def test_load_data_from_url(setup_audio_loader):
|
||||||
|
loader, mock_dbx = setup_audio_loader
|
||||||
|
url = "https://example.com/audio.mp3"
|
||||||
|
expected_content = "This is a test audio transcript."
|
||||||
|
|
||||||
|
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
|
||||||
|
mock_dbx.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
|
||||||
|
|
||||||
|
result = loader.load_data(url)
|
||||||
|
|
||||||
|
doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
|
||||||
|
expected_result = {
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"content": expected_content,
|
||||||
|
"meta_data": {"url": url},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert result == expected_result
|
||||||
|
mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
|
||||||
|
mock_dbx.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
|
||||||
|
{"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||||
|
) # as `match` statement was introduced in python 3.10
|
||||||
|
def test_load_data_from_file(setup_audio_loader):
|
||||||
|
loader, mock_dbx = setup_audio_loader
|
||||||
|
file_path = "local_audio.mp3"
|
||||||
|
expected_content = "This is a test audio transcript."
|
||||||
|
|
||||||
|
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
|
||||||
|
mock_dbx.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
|
||||||
|
|
||||||
|
# Mock the file reading functionality
|
||||||
|
with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
|
||||||
|
result = loader.load_data(file_path)
|
||||||
|
|
||||||
|
doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
|
||||||
|
expected_result = {
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"content": expected_content,
|
||||||
|
"meta_data": {"url": file_path},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert result == expected_result
|
||||||
|
mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
|
||||||
|
mock_dbx.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
|
||||||
|
{"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user