Support for Audio Files (#1416)
This commit is contained in:
22
embedchain/chunkers/audio.py
Normal file
22
embedchain/chunkers/audio.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AudioChunker(BaseChunker):
|
||||
"""Chunker for audio."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
@@ -81,6 +81,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
|
||||
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
|
||||
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
|
||||
DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
|
||||
}
|
||||
|
||||
if data_type == DataType.CUSTOM or loader is not None:
|
||||
@@ -129,6 +130,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
|
||||
DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
|
||||
}
|
||||
|
||||
if chunker is not None:
|
||||
|
||||
51
embedchain/loaders/audio.py
Normal file
51
embedchain/loaders/audio.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import hashlib
|
||||
import validators
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
try:
|
||||
from deepgram import DeepgramClient, PrerecordedOptions
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
|
||||
) from None
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AudioLoader(BaseLoader):
|
||||
def __init__(self):
|
||||
if not os.environ.get("DEEPGRAM_API_KEY"):
|
||||
raise ValueError("DEEPGRAM_API_KEY is not set")
|
||||
|
||||
DG_KEY = os.environ.get("DEEPGRAM_API_KEY")
|
||||
self.client = DeepgramClient(DG_KEY)
|
||||
|
||||
def load_data(self, url: str):
|
||||
"""Load data from a audio file or URL."""
|
||||
|
||||
options = PrerecordedOptions(
|
||||
model="nova-2",
|
||||
smart_format=True,
|
||||
)
|
||||
if validators.url(url):
|
||||
source = {"url": url}
|
||||
response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
|
||||
else:
|
||||
with open(url, "rb") as audio:
|
||||
source = {"buffer": audio}
|
||||
response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
|
||||
content = response["results"]["channels"][0]["alternatives"][0]["transcript"]
|
||||
|
||||
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
|
||||
metadata = {"url": url}
|
||||
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
"data": [
|
||||
{
|
||||
"content": content,
|
||||
"meta_data": metadata,
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -41,6 +41,7 @@ class IndirectDataType(Enum):
|
||||
DROPBOX = "dropbox"
|
||||
TEXT_FILE = "text_file"
|
||||
EXCEL_FILE = "excel_file"
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
class SpecialDataType(Enum):
|
||||
@@ -81,3 +82,4 @@ class DataType(Enum):
|
||||
DROPBOX = IndirectDataType.DROPBOX.value
|
||||
TEXT_FILE = IndirectDataType.TEXT_FILE.value
|
||||
EXCEL_FILE = IndirectDataType.EXCEL_FILE.value
|
||||
AUDIO = IndirectDataType.AUDIO.value
|
||||
|
||||
@@ -237,6 +237,12 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
|
||||
return DataType.DOCX
|
||||
|
||||
if url.path.endswith(
|
||||
(".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
|
||||
):
|
||||
logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
|
||||
return DataType.AUDIO
|
||||
|
||||
if url.path.endswith(".yaml"):
|
||||
try:
|
||||
response = requests.get(source)
|
||||
|
||||
Reference in New Issue
Block a user