|
|
- #!/usr/bin/env python
- # coding: utf8
-
- """ This module provides audio data conversion functions. """
-
- # pyright: reportMissingImports=false
- # pylint: disable=import-error
- import numpy as np
- import tensorflow as tf
-
- from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
-
- # pylint: enable=import-error
-
- __email__ = "spleeter@deezer.com"
- __author__ = "Deezer Research"
- __license__ = "MIT License"
-
-
def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
    """
    Force a waveform to have exactly `n_channels` entries on its second
    axis, dropping extra channels or duplicating existing ones (in
    tensorflow).

    Parameters:
        waveform (tensorflow.Tensor):
            Waveform to transform. It is indexed as `waveform[:, :k]`, so
            it is presumably laid out as `(samples, channels)` -- confirm
            against callers.
        n_channels (int):
            Number of channels the returned waveform must have.

    Returns:
        tensorflow.Tensor:
            Waveform whose second axis is cut down to `n_channels` entries.
    """

    def _keep_leading_channels() -> tf.Tensor:
        # Enough channels are already there: keep the first `n_channels`.
        return waveform[:, :n_channels]

    def _tile_then_keep_leading() -> tf.Tensor:
        # Too few channels: repeat the channel axis `n_channels` times,
        # then keep the first `n_channels` of the tiled result.
        tiled = tf.tile(waveform, [1, n_channels])
        return tiled[:, :n_channels]

    # The channel count is read with tf.shape and branched on with tf.cond,
    # i.e. it is evaluated as a tensor rather than as a Python value.
    has_enough_channels = tf.shape(waveform)[1] >= n_channels
    return tf.cond(
        has_enough_channels,
        true_fn=_keep_leading_channels,
        false_fn=_tile_then_keep_leading,
    )
-
-
def to_stereo(waveform: np.ndarray) -> np.ndarray:
    """
    Convert a waveform to stereo by duplicating if mono, or truncating
    if too many channels.

    Parameters:
        waveform (numpy.ndarray):
            a `(N, d)` numpy array, or a flat `(N,)` array which is
            treated as a single (mono) channel.

    Returns:
        numpy.ndarray:
            A stereo waveform as a `(N, 2)` numpy array. An input that
            already has two channels is returned as is (same object, no
            copy).
    """
    if waveform.ndim == 1:
        # A flat array has no channel axis to inspect; view it as
        # `(N, 1)` so that it follows the regular mono path below instead
        # of failing on `shape[1]`.
        waveform = waveform[:, np.newaxis]
    n_channels = waveform.shape[1]
    if n_channels == 1:
        # Mono: duplicate the single channel along the last axis.
        return np.repeat(waveform, 2, axis=-1)
    if n_channels > 2:
        # More than two channels: keep only the first two.
        return waveform[:, :2]
    return waveform
-
-
def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
    """
    Express a gain tensor in decibels (in tensorflow), computed as
    `20 * log10(max(tensor, espilon))`.

    Parameters:
        tensor (tensorflow.Tensor):
            Gain values to convert.
        espilon (float):
            Lower bound applied to `tensor` before taking the logarithm.
            The keyword is spelled `espilon` (sic) in the signature. The
            default `10e-10` is equal to `1e-9`.

    Returns:
        tensorflow.Tensor:
            Tensor of decibel values.
    """
    # Clamp from below first so the logarithm never receives a value
    # smaller than `espilon`.
    floored = tf.maximum(tensor, espilon)
    # 20 / ln(10) turns the natural logarithm into 20 * log10.
    natural_log_to_db = 20.0 / np.log(10)
    return natural_log_to_db * tf.math.log(floored)
-
-
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
    """
    Turn decibel values back into gain (in tensorflow), computed as
    `10 ** (tensor / 20)`.

    Parameters:
        tensor (tensorflow.Tensor):
            Decibel values to convert.

    Returns:
        tensorflow.Tensor:
            Tensor of gain values.
    """
    exponent = tensor / 20.0
    return tf.pow(10.0, exponent)
-
-
def spectrogram_to_db_uint(
    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
) -> tf.Tensor:
    """
    Convert a spectrogram to decibels, limit its dynamic range to
    `db_range` below its maximum, and hand the result to
    `from_float32_to_uint8` for encoding.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Spectrogram to be encoded as TF float tensor.
        db_range (float):
            Width in decibel of the range kept below the maximum value;
            anything lower is raised to `max - db_range`.
        **kwargs:
            Forwarded unchanged to `from_float32_to_uint8`.

    Returns:
        tensorflow.Tensor:
            Whatever `from_float32_to_uint8` returns for the clamped
            decibel spectrogram (presumably the `uint8` encoding; that
            helper is defined in another module -- confirm there).
    """
    decibels = gain_to_db(spectrogram)
    # The floor is relative to the loudest value of this spectrogram.
    floor_db = tf.reduce_max(decibels) - db_range
    clamped = tf.maximum(decibels, floor_db)
    return from_float32_to_uint8(clamped, **kwargs)
-
-
def db_uint_spectrogram_to_gain(
    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
) -> tf.Tensor:
    """
    Decode a spectrogram stored on the uint8 decibel scale back to gain.

    Parameters:
        db_uint_spectrogram (tensorflow.Tensor):
            Encoded decibel spectrogram to decode.
        min_db (tensorflow.Tensor):
            Lower bound limit for decoding, passed to
            `from_uint8_to_float32`.
        max_db (tensorflow.Tensor):
            Upper bound limit for decoding, passed to
            `from_uint8_to_float32`.

    Returns:
        tensorflow.Tensor:
            Result of `db_to_gain` applied to the decoded decibel
            spectrogram.
    """
    # Two steps: undo the uint8 encoding, then leave the decibel scale.
    return db_to_gain(
        from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
    )
|