#!/usr/bin/env python
# coding: utf8
"""
Module for building a data preprocessing pipeline using the tensorflow
data API. Data preprocessing steps such as audio loading, spectrogram
computation, cropping, feature caching and data augmentation are applied
through a tensorflow dataset object that outputs tuples (input_, output)
where:

- input_ is a dictionary with a single key that contains the (batched)
  mix spectrogram of the audio samples
- output is a dictionary of spectrograms of the isolated tracks
  (ground truth)
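
Example (a minimal sketch; the CSV layout, adapter construction and the
paths below are illustrative assumptions, not part of this module):

>>> params = dict(
...     DEFAULT_AUDIO_PARAMS, n_channels=2, batch_size=8,
...     train_csv="path/to/train.csv")
>>> dataset = get_training_dataset(params, some_audio_adapter, "audio/root")
>>> inputs, outputs = next(iter(dataset))
>>> inputs["mix_spectrogram"].shape      # (batch, T, F, n_channels)
>>> outputs["vocals_spectrogram"].shape  # ground truth, same shape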
"""
import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply,
)

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
    "instrument_list": ("vocals", "accompaniment"),
    "mix_name": "mix",
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
}
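
# Sanity notes on the defaults above (derived arithmetic, not used at
# runtime): each STFT frame spans 4096 / 44100 ~ 93 ms with a hop of
# 1024 / 44100 ~ 23 ms, so T=512 frames cover about 512 * 1024 / 44100
# ~ 11.9 s of audio, and F=1024 keeps the lowest 1024 of the available
# 4096 / 2 + 1 = 2049 frequency bins.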


def get_training_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds training dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=audio_params.get("chunk_duration", 20.0),
        random_seed=audio_params.get("random_seed", 0),
    )
    return builder.build(
        audio_params.get("train_csv"),
        cache_directory=audio_params.get("training_cache"),
        batch_size=audio_params.get("batch_size"),
        n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False,
    )


def get_validation_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds validation dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, chunk_duration=12.0
    )
    return builder.build(
        audio_params.get("validation_csv"),
        batch_size=audio_params.get("batch_size"),
        cache_directory=audio_params.get("validation_cache"),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # should not perform data augmentation for eval:
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False,
    )


class InstrumentDatasetBuilder(object):
    """ Instrument-based filter and mapper provider. """

    def __init__(self, parent, instrument) -> None:
        """
        Default constructor.

        Parameters:
            parent:
                Parent dataset builder.
            instrument:
                Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f"{instrument}_spectrogram"
        self._min_spectrogram_key = f"min_{instrument}_spectrogram"
        self._max_spectrogram_key = f"max_{instrument}_spectrogram"

    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        return dict(
            sample,
            **self._parent._audio_adapter.load_tf_waveform(
                sample[f"{self._instrument}_path"],
                offset=sample["start"],
                duration=self._parent._chunk_duration,
                sample_rate=self._parent._sample_rate,
                waveform_name="waveform",
            ),
        )

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: compute_spectrogram_tf(
                    sample["waveform"],
                    frame_length=self._parent._frame_length,
                    frame_step=self._parent._frame_step,
                    spec_exponent=1.0,
                    window_exponent=1.0,
                )
            },
        )

    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    :, : self._parent._F, :
                ]
            },
        )

    def convert_to_uint(self, sample):
        """ Convert given sample from float to uint. """
        return dict(
            sample,
            **spectrogram_to_db_uint(
                sample[self._spectrogram_key],
                tensor_key=self._spectrogram_key,
                min_key=self._min_spectrogram_key,
                max_key=self._max_spectrogram_key,
            ),
        )

    def filter_infinity(self, sample):
        """ Filter out samples with an infinite spectrogram minimum. """
        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))

    def convert_to_float32(self, sample):
        """ Convert given sample from uint back to float32. """
        return dict(
            sample,
            **{
                self._spectrogram_key: db_uint_spectrogram_to_gain(
                    sample[self._spectrogram_key],
                    sample[self._min_spectrogram_key],
                    sample[self._max_spectrogram_key],
                )
            },
        )

    def time_crop(self, sample):
        """ Crop the central T frames of the spectrogram. """

        def start(sample):
            """ mid_segment_start """
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0] / 2
                    - self._parent._T / 2,
                    0,
                ),
                tf.int32,
            )

        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    start(sample) : start(sample) + self._parent._T, :, :
                ]
            },
        )

    def filter_shape(self, sample):
        """ Filter out badly shaped samples. """
        return check_tensor_shape(
            sample[self._spectrogram_key],
            (self._parent._T, self._parent._F, self._parent._n_channels),
        )

    def reshape_spectrogram(self, sample):
        """ Set the static shape of the given sample's spectrogram. """
        return dict(
            sample,
            **{
                self._spectrogram_key: set_tensor_shape(
                    sample[self._spectrogram_key],
                    (self._parent._T, self._parent._F, self._parent._n_channels),
                )
            },
        )


class DatasetBuilder(object):
    """
    Builds a tensorflow dataset of (mix, isolated tracks) spectrogram
    pairs from a CSV listing audio files, using the instrument-level
    operations provided by InstrumentDatasetBuilder.
    """

    MARGIN: float = 0.5
    """ Margin at beginning and end of songs in seconds. """

    WAIT_PERIOD: int = 60
    """ Wait period for cache (in seconds). """

    def __init__(
        self,
        audio_params: Dict,
        audio_adapter: AudioAdapter,
        audio_path: str,
        random_seed: int = 0,
        chunk_duration: float = 20.0,
    ) -> None:
        """
        Default constructor.

        NOTE: Probably need for AudioAdapter.

        Parameters:
            audio_params (Dict):
                Audio parameters to use.
            audio_adapter (AudioAdapter):
                Audio adapter to use.
            audio_path (str):
                Path of directory containing audio.
            random_seed (int):
                Seed to use for random operations.
            chunk_duration (float):
                Duration of each audio chunk, in seconds.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params["T"]
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params["F"]
        self._sample_rate = audio_params["sample_rate"]
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
        self._n_channels = audio_params["n_channels"]
        # list() so that a tuple-valued instrument_list (as in the
        # module defaults) also works.
        self._instruments = [self._mix_name] + list(audio_params["instrument_list"])
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed
        self.check_parameters_compatibility()

    def check_parameters_compatibility(self):
        """ Check that STFT parameters, T, F and chunk duration are consistent. """
        if self._frame_length / 2 + 1 < self._F:
            raise ValueError(
                "F is too large and must be set to at most frame_length/2+1. "
                "Decrease F or increase frame_length to fix."
            )
        if (
            self._chunk_duration * self._sample_rate - self._frame_length
        ) / self._frame_step < self._T:
            raise ValueError(
                "T is too large considering STFT parameters and chunk duration. "
                "Make sure the spectrogram time dimension of chunks is larger "
                "than T (for instance by reducing T or frame_step, or by "
                "increasing chunk duration)."
            )
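
    # Worked example of the checks above, using the module defaults:
    # F=1024 <= 4096 / 2 + 1 = 2049, and a 20 s chunk yields
    # (20 * 44100 - 4096) / 1024 ~ 857 STFT frames >= T=512,
    # so both conditions hold.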

    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )

    def filter_error(self, sample):
        """ Filter out samples that errored while loading. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Remove the waveform from the sample. """
        return {k: v for k, v in sample.items() if k != "waveform"}

    def harmonize_spectrogram(self, sample):
        """ Ensure same size for all instrument spectrograms. """

        def _reduce(sample):
            return tf.reduce_min(
                [
                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
                    for instrument in self._instruments
                ]
            )

        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    : _reduce(sample), :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """ Filter out too-short segments. """
        return tf.reduce_any(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )

    def random_time_crop(self, sample):
        """ Random time crop of 11.88s. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, self._n_channels),
                    seed=self._random_seed,
                ),
            ),
        )

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )

    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)

    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Parameters:
            dataset (Any):
                Dataset to compute segments for.
            n_chunks_per_song (int):
                Number of segments per song to compute.

        Returns:
            Any:
                Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                # Spread chunks evenly between both margins of the song.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        Yields:
            Any:
                InstrumentDatasetBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument)
                )
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. If the provided wait
        flag is `True`, waits until the cache is available (useful if
        another process is already computing it).

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (str):
                Path of cache directory to be used, None if no cache.
            wait (bool):
                If caching is enabled, True if the cache should be waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f"{cache}.index"):
                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
            return dataset.cache(cache)
        return dataset

    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: int = 2,
    ) -> Any:
        """
        Builds the dataset: reads the CSV, computes per-song segments,
        loads audio, computes spectrograms, and applies the filtering,
        caching, cropping, augmentation and batching steps configured
        through the parameters above.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached:
                reshuffle_each_iteration=True,
            )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # For each instrument: load waveform and compute spectrogram, then
        # filter out errored samples, frequencies above F bins, and the
        # raw waveform itself.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely.
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too-short segments.
        # NOTE: could be done before caching?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of 11.88s.
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # take central segment (for validation)
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post-cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel calls post caching.
        # Augmentations must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # error due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset