From 08b67b4a7869ff9f0ee827d6b81c0122aeade2ff Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Wed, 12 Jun 2024 22:55:58 +0530
Subject: [PATCH] Support for Audio Files (#1416)

---
 Makefile                                    |  2 +-
 docs/components/data-sources/audio.mdx      | 25 ++++++
 docs/components/data-sources/overview.mdx   |  2 +
 embedchain/chunkers/audio.py                | 22 +++++
 embedchain/data_formatter/data_formatter.py |  2 +
 embedchain/loaders/audio.py                 | 51 +++++++++++
 embedchain/models/data_type.py              |  2 +
 embedchain/utils/misc.py                    |  6 ++
 tests/chunkers/test_chunkers.py             |  2 +
 tests/loaders/test_audio.py                 | 98 +++++++++++++++++++++
 10 files changed, 211 insertions(+), 1 deletion(-)
 create mode 100644 docs/components/data-sources/audio.mdx
 create mode 100644 embedchain/chunkers/audio.py
 create mode 100644 embedchain/loaders/audio.py
 create mode 100644 tests/loaders/test_audio.py

diff --git a/Makefile b/Makefile
index 8ddb58e7..ad2ddd5f 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7
 
 install_es:
 	poetry install --extras elasticsearch
diff --git a/docs/components/data-sources/audio.mdx b/docs/components/data-sources/audio.mdx
new file mode 100644
index 00000000..5f2772a7
--- /dev/null
+++ b/docs/components/data-sources/audio.mdx
@@ -0,0 +1,25 @@
+---
+title: "🎤 Audio"
+---
+
+
+To use an audio file as a data source, just add `data_type` as `audio` and pass in the path of the audio file (local or hosted).
+
+We use [Deepgram](https://developers.deepgram.com/docs/introduction) to transcribe the audio to text, and then use the generated text as the data source.
+
+You will need a Deepgram API key, which is available [here](https://console.deepgram.com/signup?jump=keys), to use this feature.
+
+### Without customization
+
+```python
+import os
+from embedchain import App
+
+os.environ["DEEPGRAM_API_KEY"] = "153xxx"
+
+app = App()
+app.add("introduction.wav", data_type="audio")
+response = app.query("What is my name and how old am I?")
+print(response)
+# Answer: Your name is Dave and you are 21 years old.
+```
diff --git a/docs/components/data-sources/overview.mdx b/docs/components/data-sources/overview.mdx
index 16114f52..66f5948a 100644
--- a/docs/components/data-sources/overview.mdx
+++ b/docs/components/data-sources/overview.mdx
@@ -9,6 +9,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
+
@@ -33,6 +34,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
+
diff --git a/embedchain/chunkers/audio.py b/embedchain/chunkers/audio.py
new file mode 100644
index 00000000..0aebda32
--- /dev/null
+++ b/embedchain/chunkers/audio.py
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class AudioChunker(BaseChunker):
+    """Chunker for audio."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index 75dd76a2..398d16a3 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -81,6 +81,7 @@ class DataFormatter(JSONSerializable):
             DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
             DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
             DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
+            DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
         }
 
         if data_type == DataType.CUSTOM or loader is not None:
@@ -129,6 +130,7 @@ class DataFormatter(JSONSerializable):
             DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
+            DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
         }
 
         if chunker is not None:
diff --git a/embedchain/loaders/audio.py b/embedchain/loaders/audio.py
new file mode 100644
index 00000000..44d7e9bd
--- /dev/null
+++ b/embedchain/loaders/audio.py
@@ -0,0 +1,51 @@
+import os
+import hashlib
+import validators
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+try:
+    from deepgram import DeepgramClient, PrerecordedOptions
+except ImportError:
+    raise ImportError(
+        "Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
+    ) from None
+
+
+@register_deserializable
+class AudioLoader(BaseLoader):
+    def __init__(self):
+        api_key = os.environ.get("DEEPGRAM_API_KEY")
+        if not api_key:
+            raise ValueError("DEEPGRAM_API_KEY is not set")
+
+        self.client = DeepgramClient(api_key)
+
+    def load_data(self, url: str):
+        """Transcribe an audio file or URL with Deepgram and return it as a single document."""
+
+        options = PrerecordedOptions(
+            model="nova-2",
+            smart_format=True,
+        )
+        if validators.url(url):
+            source = {"url": url}
+            response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
+        else:
+            with open(url, "rb") as audio:
+                source = {"buffer": audio}
+                response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
+        content = response["results"]["channels"][0]["alternatives"][0]["transcript"]
+
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        metadata = {"url": url}
+
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": metadata,
+                }
+            ],
+        }
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index 5f3ebfac..6370bf06 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -41,6 +41,7 @@ class IndirectDataType(Enum):
     DROPBOX = "dropbox"
     TEXT_FILE = "text_file"
     EXCEL_FILE = "excel_file"
+    AUDIO = "audio"
 
 
 class SpecialDataType(Enum):
@@ -81,3 +82,4 @@ class DataType(Enum):
     DROPBOX = IndirectDataType.DROPBOX.value
     TEXT_FILE = IndirectDataType.TEXT_FILE.value
     EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
+    AUDIO = IndirectDataType.AUDIO.value
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 52872207..31fd12ac 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -237,6 +237,12 @@ def detect_datatype(source: Any) -> DataType:
             logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if url.path.endswith(
+            (".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
+        ):
+            logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
+            return DataType.AUDIO
+
         if url.path.endswith(".yaml"):
             try:
                 response = requests.get(source)
diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py
index 55c61e56..067fa869 100644
--- a/tests/chunkers/test_chunkers.py
+++ b/tests/chunkers/test_chunkers.py
@@ -19,6 +19,7 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.xml import XmlChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.chunkers.audio import AudioChunker
 from embedchain.config.add_config import ChunkerConfig
 
 chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
@@ -45,6 +46,7 @@ chunker_common_config = {
     CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
     GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 
diff --git a/tests/loaders/test_audio.py b/tests/loaders/test_audio.py
new file mode 100644
index 00000000..c365893b
--- /dev/null
+++ b/tests/loaders/test_audio.py
@@ -0,0 +1,98 @@
+import os
+import sys
+import hashlib
+import pytest
+from unittest.mock import mock_open, patch
+
+if sys.version_info >= (3, 10):  # as `match` statement was introduced in python 3.10
+    from deepgram import PrerecordedOptions
+    from embedchain.loaders.audio import AudioLoader
+
+
+@pytest.fixture
+def setup_audio_loader(mocker):
+    mock_client_cls = mocker.patch("embedchain.loaders.audio.DeepgramClient")
+    mock_client = mocker.MagicMock()
+    mock_client_cls.return_value = mock_client
+
+    os.environ["DEEPGRAM_API_KEY"] = "test_key"
+    loader = AudioLoader()
+    mock_client_cls.assert_called_once_with("test_key")
+
+    yield loader, mock_client
+
+    if "DEEPGRAM_API_KEY" in os.environ:
+        del os.environ["DEEPGRAM_API_KEY"]
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_initialization(setup_audio_loader):
+    """Test initialization of AudioLoader."""
+    loader, _ = setup_audio_loader
+    assert loader is not None
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_load_data_from_url(setup_audio_loader):
+    loader, mock_client = setup_audio_loader
+    url = "https://example.com/audio.mp3"
+    expected_content = "This is a test audio transcript."
+
+    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
+    mock_client.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
+
+    result = loader.load_data(url)
+
+    doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
+    expected_result = {
+        "doc_id": doc_id,
+        "data": [
+            {
+                "content": expected_content,
+                "meta_data": {"url": url},
+            }
+        ],
+    }
+
+    assert result == expected_result
+    mock_client.listen.prerecorded.v.assert_called_once_with("1")
+    mock_client.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
+        {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
+    )
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
+)  # as `match` statement was introduced in python 3.10
+def test_load_data_from_file(setup_audio_loader):
+    loader, mock_client = setup_audio_loader
+    file_path = "local_audio.mp3"
+    expected_content = "This is a test audio transcript."
+
+    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
+    mock_client.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
+
+    # Mock the file reading functionality
+    with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
+        result = loader.load_data(file_path)
+
+    doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
+    expected_result = {
+        "doc_id": doc_id,
+        "data": [
+            {
+                "content": expected_content,
+                "meta_data": {"url": file_path},
+            }
+        ],
+    }
+
+    assert result == expected_result
+    mock_client.listen.prerecorded.v.assert_called_once_with("1")
+    mock_client.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
+        {"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
+    )