You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

200 lines
6.4 KiB

#!/usr/bin/env python
# coding: utf8
""" AudioAdapter class defintion. """
from abc import ABC, abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from spleeter.audio import Codec
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
_DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """
@abstractmethod
def load(
self,
audio_descriptor: AudioDescriptor,
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given audio descriptor and
returns it data as a waveform. Aims to be implemented by client.
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Returns:
Signal:
Loaded data as (wf, sample_rate) tuple.
"""
pass
def load_tf_waveform(
self,
audio_descriptor,
offset: float = 0.0,
duration: float = 1800.0,
sample_rate: int = 44100,
dtype: bytes = b"float32",
waveform_name: str = "waveform",
) -> Dict[str, Any]:
"""
Load the audio and convert it to a tensorflow waveform.
Parameters:
audio_descriptor ():
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (float):
Start offset to load from in seconds.
duration (float):
Duration to load in seconds.
sample_rate (float):
Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
waveform_name (str):
(Optional) Name of the key in output dict, default to
`'waveform'`.
Returns:
Dict[str, Any]:
TF output dict with waveform as `(T x chan numpy array)`
and a boolean that tells whether there were an error while
trying to load the waveform.
"""
# Cast parameters to TF format.
offset = tf.cast(offset, tf.float64)
duration = tf.cast(duration, tf.float64)
# Defined safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy(),
)
logger.info("Audio data loaded successfully")
return (data, False)
except Exception as e:
logger.exception("An error occurs while loading audio", exc_info=e)
return (np.float32(-1.0), True)
# Execute function and format results.
results = (
tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool),
),
)
waveform, error = results[0]
return {waveform_name: waveform, f"{waveform_name}_error": error}
@abstractmethod
def save(
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None,
) -> None:
"""
Save the given audio data to the file denoted by the given path.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
"""
pass
@classmethod
def default(cls: type) -> "AudioAdapter":
"""
Builds and returns a default audio adapter instance.
Returns:
AudioAdapter:
Default adapter instance to use.
"""
if cls._DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
cls._DEFAULT = FFMPEGProcessAudioAdapter()
return cls._DEFAULT
@classmethod
def get(cls: type, descriptor: str) -> "AudioAdapter":
"""
Load dynamically an AudioAdapter from given class descriptor.
Parameters:
descriptor (str):
Adapter class descriptor (module.Class)
Returns:
AudioAdapter:
Created adapter instance.
"""
if not descriptor:
return cls.default()
module_path: List[str] = descriptor.split(".")
adapter_class_name: str = module_path[-1]
module_path: str = ".".join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not issubclass(adapter_class, AudioAdapter):
raise SpleeterError(
f"{adapter_class_name} is not a valid AudioAdapter class"
)
return adapter_class()