You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

139 lines
3.6 KiB

#!/usr/bin/env python
# coding: utf8
""" This module provides audio data conversion functions. """
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
    """
    Force a waveform tensor to carry exactly `n_channels` channels
    (in tensorflow).

    Channels are read along axis 1. When the waveform already has at
    least `n_channels` channels, the extra ones are dropped; otherwise
    the waveform is tiled along axis 1 and then cut to `n_channels`.

    Parameters:
        waveform (tensorflow.Tensor):
            Waveform to transform. The indexing and tile multiples used
            here assume a rank-2 tensor with channels on axis 1.
        n_channels (int):
            Number of channels the result must have.

    Returns:
        tensorflow.Tensor:
            Waveform with `n_channels` entries along axis 1.
    """

    def _drop_extra_channels() -> tf.Tensor:
        # Enough channels available: keep only the leading ones.
        return waveform[:, :n_channels]

    def _repeat_channels() -> tf.Tensor:
        # Too few channels: repeat the channel axis, then trim the excess.
        tiled = tf.tile(waveform, [1, n_channels])
        return tiled[:, :n_channels]

    # The channel count is read from the dynamic shape, so the choice
    # between both branches is made by tf.cond rather than a Python `if`.
    has_enough_channels = tf.shape(waveform)[1] >= n_channels
    return tf.cond(
        has_enough_channels,
        true_fn=_drop_extra_channels,
        false_fn=_repeat_channels,
    )
def to_stereo(waveform: np.ndarray) -> np.ndarray:
    """
    Convert a waveform to stereo by duplicating if mono, or truncating
    if too many channels.

    Parameters:
        waveform (numpy.ndarray):
            a `(N, d)` numpy array, where `d` is the number of
            channels. A 1-D `(N,)` array is also accepted and is
            treated as a mono signal of `N` samples.

    Returns:
        numpy.ndarray:
            A stereo waveform as a `(N, 2)` numpy array. An input that
            is already `(N, 2)` is returned as is (same object, not a
            copy).
    """
    if waveform.ndim == 1:
        # A flat array has no channel axis to inspect: view it as a
        # single-channel `(N, 1)` signal so it follows the mono path.
        waveform = waveform[:, np.newaxis]
    n_channels = waveform.shape[1]
    if n_channels == 1:
        # Mono: duplicate the only channel on the last axis.
        return np.repeat(waveform, 2, axis=-1)
    if n_channels > 2:
        # More than two channels: keep the first two.
        return waveform[:, :2]
    return waveform
def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
    """
    Convert from gain to decibel in tensorflow.

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor of gain values to convert.
        espilon (float):
            Lower bound applied to `tensor` before taking the
            logarithm, so that the logarithm is never evaluated below
            this value. The parameter name keeps its existing spelling
            so that keyword callers keep working.

    Returns:
        tensorflow.Tensor:
            Converted tensor, in decibel.
    """
    # 20 * log10(x) expressed with the natural logarithm:
    # 20 / ln(10) * ln(x).
    log_to_db_scale = 20.0 / np.log(10)
    floored = tf.maximum(tensor, espilon)
    return log_to_db_scale * tf.math.log(floored)
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
    """
    Convert from decibel to gain in tensorflow.

    Computes `10 ** (tensor / 20)` element-wise.

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor of decibel values to convert.

    Returns:
        tensorflow.Tensor:
            Converted tensor, as gain.
    """
    exponent = tensor / 20.0
    return tf.pow(10.0, exponent)
def spectrogram_to_db_uint(
    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
) -> tf.Tensor:
    """
    Encodes given spectrogram into uint8 using decibel scale.

    The spectrogram is converted to decibel, floored at `db_range`
    below its maximum decibel value, and then handed to
    `from_float32_to_uint8` for the encoding itself.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Spectrogram to be encoded as TF float tensor.
        db_range (float):
            Range in decibel for encoding: values more than `db_range`
            below the maximum are raised to that floor.
        **kwargs:
            Forwarded unchanged to `from_float32_to_uint8`.

    Returns:
        tensorflow.Tensor:
            The value returned by `from_float32_to_uint8` for the
            floored decibel spectrogram (expected to hold the `uint8`
            encoding; confirm the exact structure against that helper).
    """
    in_db = gain_to_db(spectrogram)
    # Lowest decibel value kept, relative to the loudest value.
    floor_db = tf.reduce_max(in_db) - db_range
    floored_db = tf.maximum(in_db, floor_db)
    return from_float32_to_uint8(floored_db, **kwargs)
def db_uint_spectrogram_to_gain(
    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
) -> tf.Tensor:
    """
    Decode spectrogram from uint8 decibel scale.

    The encoded spectrogram is first passed to `from_uint8_to_float32`
    together with both bounds, and the result is then converted from
    decibel to gain.

    Parameters:
        db_uint_spectrogram (tensorflow.Tensor):
            Decibel spectrogram to decode.
        min_db (tensorflow.Tensor):
            Lower bound limit for decoding.
        max_db (tensorflow.Tensor):
            Upper bound limit for decoding.

    Returns:
        tensorflow.Tensor:
            Decoded spectrogram, as gain.
    """
    decoded_db = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
    return db_to_gain(decoded_db)