#!/usr/bin/env python
# coding: utf8

"""
Module for building the data preprocessing pipeline using the tensorflow
data API. Data preprocessing steps such as audio loading, spectrogram
computation, cropping, feature caching or data augmentation are done
using a tensorflow dataset object that outputs a tuple (input_, output)
where:

- input_ is a dictionary with a single key that contains the (batched)
  mix spectrogram of audio samples
- output is a dictionary of spectrograms of the isolated tracks
  (ground truth)
"""

import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply,
)

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
    "instrument_list": ("vocals", "accompaniment"),
    "mix_name": "mix",
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
}

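# Illustrative usage sketch (comment only, not part of the module API). The
# extra keys below ("n_channels", "train_csv", "batch_size", "training_cache")
# are read by DatasetBuilder and builder.build(); the CSV path, audio
# directory and the `AudioAdapter.default()` factory are illustrative
# assumptions.
#
#     from spleeter.audio.adapter import AudioAdapter
#
#     params = dict(
#         DEFAULT_AUDIO_PARAMS,
#         instrument_list=["vocals", "accompaniment"],
#         n_channels=2,  # stereo spectrograms
#         train_csv="configs/musdb_train.csv",
#         batch_size=4,
#         training_cache=None,  # disable feature caching in this sketch
#     )
#     dataset = get_training_dataset(params, AudioAdapter.default(), "/data/musdb")
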
def get_training_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds training dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=audio_params.get("chunk_duration", 20.0),
        random_seed=audio_params.get("random_seed", 0),
    )
    return builder.build(
        audio_params.get("train_csv"),
        cache_directory=audio_params.get("training_cache"),
        batch_size=audio_params.get("batch_size"),
        n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False,
    )

def get_validation_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds validation dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, chunk_duration=12.0
    )
    return builder.build(
        audio_params.get("validation_csv"),
        batch_size=audio_params.get("batch_size"),
        cache_directory=audio_params.get("validation_cache"),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # should not perform data augmentation for eval:
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False,
    )

class InstrumentDatasetBuilder(object):
    """ Instrument-based filter and mapper provider. """

    def __init__(self, parent, instrument) -> None:
        """
        Default constructor.

        Parameters:
            parent:
                Parent dataset builder.
            instrument:
                Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f"{instrument}_spectrogram"
        self._min_spectrogram_key = f"min_{instrument}_spectrogram"
        self._max_spectrogram_key = f"max_{instrument}_spectrogram"

    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        return dict(
            sample,
            **self._parent._audio_adapter.load_tf_waveform(
                sample[f"{self._instrument}_path"],
                offset=sample["start"],
                duration=self._parent._chunk_duration,
                sample_rate=self._parent._sample_rate,
                waveform_name="waveform",
            ),
        )

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: compute_spectrogram_tf(
                    sample["waveform"],
                    frame_length=self._parent._frame_length,
                    frame_step=self._parent._frame_step,
                    spec_exponent=1.0,
                    window_exponent=1.0,
                )
            },
        )

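    # Rough order of magnitude with DEFAULT_AUDIO_PARAMS (sample_rate=44100,
    # frame_length=4096, frame_step=1024) and a 20 s chunk: the STFT yields
    # about (20 * 44100 - 4096) / 1024 + 1 ~= 858 frames and
    # 4096 / 2 + 1 = 2049 frequency bins per channel, before filter_frequencies
    # below keeps only the first F bins.
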
    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    :, : self._parent._F, :
                ]
            },
        )

    def convert_to_uint(self, sample):
        """ Convert given sample spectrogram from float to uint. """
        return dict(
            sample,
            **spectrogram_to_db_uint(
                sample[self._spectrogram_key],
                tensor_key=self._spectrogram_key,
                min_key=self._min_spectrogram_key,
                max_key=self._max_spectrogram_key,
            ),
        )

    def filter_infinity(self, sample):
        """ Filter out samples with an infinite spectrogram minimum. """
        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))

    def convert_to_float32(self, sample):
        """ Convert given sample spectrogram from uint back to float32. """
        return dict(
            sample,
            **{
                self._spectrogram_key: db_uint_spectrogram_to_gain(
                    sample[self._spectrogram_key],
                    sample[self._min_spectrogram_key],
                    sample[self._max_spectrogram_key],
                )
            },
        )

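    # Conceptual sketch (an assumption for documentation only, not the actual
    # implementation of spectrogram_to_db_uint / db_uint_spectrogram_to_gain):
    # magnitudes are mapped to a dB scale, min/max-normalized to [0, 255] and
    # stored as uint8 so cached features stay small, then mapped back to gain
    # after the cache stage.
    #
    #     import numpy as np
    #
    #     def to_db_uint8(spec, eps=1e-10):
    #         db = 20.0 * np.log10(np.maximum(spec, eps))
    #         lo, hi = db.min(), db.max()
    #         return np.round(255.0 * (db - lo) / (hi - lo)).astype(np.uint8), lo, hi
    #
    #     def from_db_uint8(q, lo, hi):
    #         return 10.0 ** ((lo + (hi - lo) * q.astype(np.float32) / 255.0) / 20.0)
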
    def time_crop(self, sample):
        """ Crop the central T frames of the spectrogram. """

        def start(sample):
            """ Compute start index of the central segment. """
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0] / 2
                    - self._parent._T / 2,
                    0,
                ),
                tf.int32,
            )

        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    start(sample) : start(sample) + self._parent._T, :, :
                ]
            },
        )

    def filter_shape(self, sample):
        """ Filter badly shaped sample. """
        return check_tensor_shape(
            sample[self._spectrogram_key],
            (self._parent._T, self._parent._F, self._parent._n_channels),
        )

    def reshape_spectrogram(self, sample):
        """ Reshape given sample to the expected (T, F, n_channels) shape. """
        return dict(
            sample,
            **{
                self._spectrogram_key: set_tensor_shape(
                    sample[self._spectrogram_key],
                    (self._parent._T, self._parent._F, self._parent._n_channels),
                )
            },
        )


class DatasetBuilder(object):
    """
    Builds the train and validation tensorflow datasets: loads audio listed
    in a CSV file, computes spectrograms, applies caching, cropping and
    optional data augmentation, and batches (input_, output) pairs.
    """

    MARGIN: float = 0.5
    """ Margin at beginning and end of songs in seconds. """

    WAIT_PERIOD: int = 60
    """ Wait period for cache (in seconds). """

    def __init__(
        self,
        audio_params: Dict,
        audio_adapter: AudioAdapter,
        audio_path: str,
        random_seed: int = 0,
        chunk_duration: float = 20.0,
    ) -> None:
        """
        Default constructor.

        NOTE: Probably need for AudioAdapter.

        Parameters:
            audio_params (Dict):
                Audio parameters to use.
            audio_adapter (AudioAdapter):
                Audio adapter to use.
            audio_path (str):
                Path of directory containing audio.
            random_seed (int):
                Seed used for random operations (shuffling, cropping).
            chunk_duration (float):
                Duration of the audio chunks to extract, in seconds.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params["T"]
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params["F"]
        self._sample_rate = audio_params["sample_rate"]
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
        self._n_channels = audio_params["n_channels"]
        self._instruments = [self._mix_name] + audio_params["instrument_list"]
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed

        self.check_parameters_compatibility()

    def check_parameters_compatibility(self):
        """ Check that STFT and chunk parameters are mutually consistent. """
        if self._frame_length / 2 + 1 < self._F:
            raise ValueError(
                "F is too large and must be set to at most frame_length/2+1."
                " Decrease F or increase frame_length to fix."
            )

        if (
            self._chunk_duration * self._sample_rate - self._frame_length
        ) / self._frame_step < self._T:
            raise ValueError(
                "T is too large considering STFT parameters and chunk duration."
                " Make sure the spectrogram time dimension of chunks is larger"
                " than T (for instance reduce T or frame_step, or increase"
                " chunk duration)."
            )

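    # Worked example with DEFAULT_AUDIO_PARAMS (sample_rate=44100,
    # frame_length=4096, frame_step=1024, T=512, F=1024) and the default
    # 20 s training chunks:
    #   frame_length / 2 + 1 = 2049 >= F = 1024
    #   (20.0 * 44100 - 4096) / 1024 ~= 857 >= T = 512
    # so both checks above pass.
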
    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )

    def filter_error(self, sample):
        """ Filter errored sample. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Remove waveform from sample once spectrograms are computed. """
        return {k: v for k, v in sample.items() if k != "waveform"}

    def harmonize_spectrogram(self, sample):
        """ Ensure same size for all instrument spectrograms. """

        def _reduce(sample):
            # Shortest time dimension across all instrument spectrograms.
            return tf.reduce_min(
                [
                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
                    for instrument in self._instruments
                ]
            )

        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    : _reduce(sample), :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """ Filter out segments shorter than T frames. """
        return tf.reduce_any(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )

    def random_time_crop(self, sample):
        """ Random time crop of T frames (11.88s with default parameters). """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                # Instrument spectrograms are concatenated along the frequency
                # axis so that the same random crop is applied to all of them.
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, self._n_channels),
                    seed=self._random_seed,
                ),
            ),
        )

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )

    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)

    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Parameters:
            dataset (Any):
                Dataset to compute segments for.
            n_chunks_per_song (int):
                Number of segments per song to compute.

        Returns:
            Any:
                Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                # Segments are evenly spread between MARGIN and the end of the
                # song (minus chunk duration and margin).
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

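    # Worked example of segment start times: for a 200 s song with the default
    # 20 s chunk_duration, MARGIN=0.5 and n_chunks_per_song=2, the formula
    # above gives start = 0.5 s for k=0 and start = 0.5 + (200 - 20 - 1) / 1 =
    # 179.5 s for k=1, i.e. one chunk near each end of the song.
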
    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        Yields:
            Any:
                InstrumentDatasetBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument)
                )
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Optionally waits for
        the cache to be available (useful if another process is already
        computing it) when the provided wait flag is `True`.

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (str):
                Path of cache directory to be used, None if no cache.
            wait (bool):
                If caching is enabled, True if the cache should be waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f"{cache}.index"):
                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
            return dataset.cache(cache)
        return dataset

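    # Note on the polling above: `dataset.cache(path)` materializes the cache
    # on disk as a `path.index` file plus data shard files, which is why cache
    # availability is detected by waiting for `path.index` to exist.
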
    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: int = 2,
    ) -> Any:
        """
        Builds the dataset from the given CSV file: computes segments, loads
        audio, computes spectrograms, caches, crops, optionally augments, and
        batches (input_, output) pairs.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached:
                reshuffle_each_iteration=True,
            )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveforms and compute spectrograms, then filter out errored
        # samples, frequency bins above F, and the raw waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen)
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching ?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too short segments.
        # NOTE: could be done before caching ?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of T frames (11.88s with default parameters)
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # Take the central segment (for validation).
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel calls post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # errors due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset
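
# Illustrative sketch of consuming a built dataset (assuming stereo audio,
# i.e. n_channels=2, the default batch_size of 8 and the default T/F values).
# Each element is an (input_, output) pair of spectrogram dictionaries:
#
#     for input_, output in dataset.take(1):
#         print(input_["mix_spectrogram"].shape)            # (8, 512, 1024, 2)
#         print(output["vocals_spectrogram"].shape)         # (8, 512, 1024, 2)
#         print(output["accompaniment_spectrogram"].shape)  # (8, 512, 1024, 2)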