#!/usr/bin/env python # coding: utf8 """ Module for building data preprocessing pipeline using the tensorflow data API. Data preprocessing such as audio loading, spectrogram computation, cropping, feature caching or data augmentation is done using a tensorflow dataset object that output a tuple (input_, output) where: - input is a dictionary with a single key that contains the (batched) mix spectrogram of audio samples - output is a dictionary of spectrogram of the isolated tracks (ground truth) """ import os import time from os.path import exists from os.path import sep as SEPARATOR from typing import Any, Dict, Optional # pyright: reportMissingImports=false # pylint: disable=import-error import tensorflow as tf from .audio.adapter import AudioAdapter from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint from .audio.spectrogram import ( compute_spectrogram_tf, random_pitch_shift, random_time_stretch, ) from .utils.logging import logger from .utils.tensor import ( check_tensor_shape, dataset_from_csv, set_tensor_shape, sync_apply, ) # pylint: enable=import-error __email__ = "spleeter@deezer.com" __author__ = "Deezer Research" __license__ = "MIT License" # Default audio parameters to use. DEFAULT_AUDIO_PARAMS: Dict = { "instrument_list": ("vocals", "accompaniment"), "mix_name": "mix", "sample_rate": 44100, "frame_length": 4096, "frame_step": 1024, "T": 512, "F": 1024, } def get_training_dataset( audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str ) -> Any: """ Builds training dataset. Parameters: audio_params (Dict): Audio parameters. audio_adapter (AudioAdapter): Adapter to load audio from. audio_path (str): Path of directory containing audio. Returns: Any: Built dataset. """ builder = DatasetBuilder( audio_params, audio_adapter, audio_path, chunk_duration=audio_params.get("chunk_duration", 20.0), random_seed=audio_params.get("random_seed", 0), ) return builder.build( audio_params.get("train_csv"), cache_directory=audio_params.get("training_cache"), batch_size=audio_params.get("batch_size"), n_chunks_per_song=audio_params.get("n_chunks_per_song", 2), random_data_augmentation=False, convert_to_uint=True, wait_for_cache=False, ) def get_validation_dataset( audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str ) -> Any: """ Builds validation dataset. Parameters: audio_params (Dict): Audio parameters. audio_adapter (AudioAdapter): Adapter to load audio from. audio_path (str): Path of directory containing audio. Returns: Any: Built dataset. """ builder = DatasetBuilder( audio_params, audio_adapter, audio_path, chunk_duration=12.0 ) return builder.build( audio_params.get("validation_csv"), batch_size=audio_params.get("batch_size"), cache_directory=audio_params.get("validation_cache"), convert_to_uint=True, infinite_generator=False, n_chunks_per_song=1, # should not perform data augmentation for eval: random_data_augmentation=False, random_time_crop=False, shuffle=False, ) class InstrumentDatasetBuilder(object): """ Instrument based filter and mapper provider. """ def __init__(self, parent, instrument) -> None: """ Default constructor. Parameters: parent: Parent dataset builder. instrument: Target instrument. """ self._parent = parent self._instrument = instrument self._spectrogram_key = f"{instrument}_spectrogram" self._min_spectrogram_key = f"min_{instrument}_spectrogram" self._max_spectrogram_key = f"max_{instrument}_spectrogram" def load_waveform(self, sample): """ Load waveform for given sample. """ return dict( sample, **self._parent._audio_adapter.load_tf_waveform( sample[f"{self._instrument}_path"], offset=sample["start"], duration=self._parent._chunk_duration, sample_rate=self._parent._sample_rate, waveform_name="waveform", ), ) def compute_spectrogram(self, sample): """ Compute spectrogram of the given sample. """ return dict( sample, **{ self._spectrogram_key: compute_spectrogram_tf( sample["waveform"], frame_length=self._parent._frame_length, frame_step=self._parent._frame_step, spec_exponent=1.0, window_exponent=1.0, ) }, ) def filter_frequencies(self, sample): """ """ return dict( sample, **{ self._spectrogram_key: sample[self._spectrogram_key][ :, : self._parent._F, : ] }, ) def convert_to_uint(self, sample): """ Convert given sample from float to unit. """ return dict( sample, **spectrogram_to_db_uint( sample[self._spectrogram_key], tensor_key=self._spectrogram_key, min_key=self._min_spectrogram_key, max_key=self._max_spectrogram_key, ), ) def filter_infinity(self, sample): """ Filter infinity sample. """ return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key])) def convert_to_float32(self, sample): """ Convert given sample from unit to float. """ return dict( sample, **{ self._spectrogram_key: db_uint_spectrogram_to_gain( sample[self._spectrogram_key], sample[self._min_spectrogram_key], sample[self._max_spectrogram_key], ) }, ) def time_crop(self, sample): """ """ def start(sample): """ mid_segment_start """ return tf.cast( tf.maximum( tf.shape(sample[self._spectrogram_key])[0] / 2 - self._parent._T / 2, 0, ), tf.int32, ) return dict( sample, **{ self._spectrogram_key: sample[self._spectrogram_key][ start(sample) : start(sample) + self._parent._T, :, : ] }, ) def filter_shape(self, sample): """ Filter badly shaped sample. """ return check_tensor_shape( sample[self._spectrogram_key], (self._parent._T, self._parent._F, self._parent._n_channels), ) def reshape_spectrogram(self, sample): """ Reshape given sample. """ return dict( sample, **{ self._spectrogram_key: set_tensor_shape( sample[self._spectrogram_key], (self._parent._T, self._parent._F, self._parent._n_channels), ) }, ) class DatasetBuilder(object): """ TO BE DOCUMENTED. """ MARGIN: float = 0.5 """ Margin at beginning and end of songs in seconds. """ WAIT_PERIOD: int = 60 """ Wait period for cache (in seconds). """ def __init__( self, audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str, random_seed: int = 0, chunk_duration: float = 20.0, ) -> None: """ Default constructor. NOTE: Probably need for AudioAdapter. Parameters: audio_params (Dict): Audio parameters to use. audio_adapter (AudioAdapter): Audio adapter to use. audio_path (str): random_seed (int): chunk_duration (float): """ # Length of segment in frames (if fs=22050 and # frame_step=512, then T=512 corresponds to 11.89s) self._T = audio_params["T"] # Number of frequency bins to be used (should # be less than frame_length/2 + 1) self._F = audio_params["F"] self._sample_rate = audio_params["sample_rate"] self._frame_length = audio_params["frame_length"] self._frame_step = audio_params["frame_step"] self._mix_name = audio_params["mix_name"] self._n_channels = audio_params["n_channels"] self._instruments = [self._mix_name] + audio_params["instrument_list"] self._instrument_builders = None self._chunk_duration = chunk_duration self._audio_adapter = audio_adapter self._audio_params = audio_params self._audio_path = audio_path self._random_seed = random_seed self.check_parameters_compatibility() def check_parameters_compatibility(self): if self._frame_length / 2 + 1 < self._F: raise ValueError( "F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix." ) if ( self._chunk_duration * self._sample_rate - self._frame_length ) / self._frame_step < self._T: raise ValueError( "T is too large considering STFT parameters and chunk duratoin. Make sure spectrogram time dimension of chunks is larger than T (for instance reducing T or frame_step or increasing chunk duration)." ) def expand_path(self, sample): """ Expands audio paths for the given sample. """ return dict( sample, **{ f"{instrument}_path": tf.strings.join( (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR ) for instrument in self._instruments }, ) def filter_error(self, sample): """ Filter errored sample. """ return tf.logical_not(sample["waveform_error"]) def filter_waveform(self, sample): """ Filter waveform from sample. """ return {k: v for k, v in sample.items() if not k == "waveform"} def harmonize_spectrogram(self, sample): """ Ensure same size for vocals and mix spectrograms. """ def _reduce(sample): return tf.reduce_min( [ tf.shape(sample[f"{instrument}_spectrogram"])[0] for instrument in self._instruments ] ) return dict( sample, **{ f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][ : _reduce(sample), :, : ] for instrument in self._instruments }, ) def filter_short_segments(self, sample): """ Filter out too short segment. """ return tf.reduce_any( [ tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T for instrument in self._instruments ] ) def random_time_crop(self, sample): """ Random time crop of 11.88s. """ return dict( sample, **sync_apply( { f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"] for instrument in self._instruments }, lambda x: tf.image.random_crop( x, (self._T, len(self._instruments) * self._F, self._n_channels), seed=self._random_seed, ), ), ) def random_time_stretch(self, sample): """ Randomly time stretch the given sample. """ return dict( sample, **sync_apply( { f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"] for instrument in self._instruments }, lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1), ), ) def random_pitch_shift(self, sample): """ Randomly pitch shift the given sample. """ return dict( sample, **sync_apply( { f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"] for instrument in self._instruments }, lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0), concat_axis=0, ), ) def map_features(self, sample): """ Select features and annotation of the given sample. """ input_ = { f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"] } output = { f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"] for instrument in self._audio_params["instrument_list"] } return (input_, output) def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any: """ Computes segments for each song of the dataset. Parameters: dataset (Any): Dataset to compute segments for. n_chunks_per_song (int): Number of segment per song to compute. Returns: Any: Segmented dataset. """ if n_chunks_per_song <= 0: raise ValueError("n_chunks_per_song must be positif") datasets = [] for k in range(n_chunks_per_song): if n_chunks_per_song > 1: datasets.append( dataset.map( lambda sample: dict( sample, start=tf.maximum( k * ( sample["duration"] - self._chunk_duration - 2 * self.MARGIN ) / (n_chunks_per_song - 1) + self.MARGIN, 0, ), ) ) ) elif n_chunks_per_song == 1: # Take central segment. datasets.append( dataset.map( lambda sample: dict( sample, start=tf.maximum( sample["duration"] / 2 - self._chunk_duration / 2, 0 ), ) ) ) dataset = datasets[-1] for d in datasets[:-1]: dataset = dataset.concatenate(d) return dataset @property def instruments(self) -> Any: """ Instrument dataset builder generator. Yields: Any: InstrumentBuilder instance. """ if self._instrument_builders is None: self._instrument_builders = [] for instrument in self._instruments: self._instrument_builders.append( InstrumentDatasetBuilder(self, instrument) ) for builder in self._instrument_builders: yield builder def cache(self, dataset: Any, cache: str, wait: bool) -> Any: """ Cache the given dataset if cache is enabled. Eventually waits for cache to be available (useful if another process is already computing cache) if provided wait flag is `True`. Parameters: dataset (Any): Dataset to be cached if cache is required. cache (str): Path of cache directory to be used, None if no cache. wait (bool): If caching is enabled, True is cache should be waited. Returns: Any: Cached dataset if needed, original dataset otherwise. """ if cache is not None: if wait: while not exists(f"{cache}.index"): logger.info(f"Cache not available, wait {self.WAIT_PERIOD}") time.sleep(self.WAIT_PERIOD) cache_path = os.path.split(cache)[0] os.makedirs(cache_path, exist_ok=True) return dataset.cache(cache) return dataset def build( self, csv_path: str, batch_size: int = 8, shuffle: bool = True, convert_to_uint: bool = True, random_data_augmentation: bool = False, random_time_crop: bool = True, infinite_generator: bool = True, cache_directory: Optional[str] = None, wait_for_cache: bool = False, num_parallel_calls: int = 4, n_chunks_per_song: float = 2, ) -> Any: """ TO BE DOCUMENTED. """ dataset = dataset_from_csv(csv_path) dataset = self.compute_segments(dataset, n_chunks_per_song) # Shuffle data if shuffle: dataset = dataset.shuffle( buffer_size=200000, seed=self._random_seed, # useless since it is cached : reshuffle_each_iteration=True, ) # Expand audio path. dataset = dataset.map(self.expand_path) # Load waveform, compute spectrogram, and filtering error, # K bins frequencies, and waveform. N = num_parallel_calls for instrument in self.instruments: dataset = ( dataset.map(instrument.load_waveform, num_parallel_calls=N) .filter(self.filter_error) .map(instrument.compute_spectrogram, num_parallel_calls=N) .map(instrument.filter_frequencies) ) dataset = dataset.map(self.filter_waveform) # Convert to uint before caching in order to save space. if convert_to_uint: for instrument in self.instruments: dataset = dataset.map(instrument.convert_to_uint) dataset = self.cache(dataset, cache_directory, wait_for_cache) # Check for INFINITY (should not happen) for instrument in self.instruments: dataset = dataset.filter(instrument.filter_infinity) # Repeat indefinitly if infinite_generator: dataset = dataset.repeat(count=-1) # Ensure same size for vocals and mix spectrograms. # NOTE: could be done before caching ? dataset = dataset.map(self.harmonize_spectrogram) # Filter out too short segment. # NOTE: could be done before caching ? dataset = dataset.filter(self.filter_short_segments) # Random time crop of 11.88s if random_time_crop: dataset = dataset.map(self.random_time_crop, num_parallel_calls=N) else: # frame_duration = 11.88/T # take central segment (for validation) for instrument in self.instruments: dataset = dataset.map(instrument.time_crop) # Post cache shuffling. Done where the data are the lightest: # after croping but before converting back to float. if shuffle: dataset = dataset.shuffle( buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True ) # Convert back to float32 if convert_to_uint: for instrument in self.instruments: dataset = dataset.map( instrument.convert_to_float32, num_parallel_calls=N ) M = 8 # Parallel call post caching. # Must be applied with the same factor on mix and vocals. if random_data_augmentation: dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map( self.random_pitch_shift, num_parallel_calls=M ) # Filter by shape (remove badly shaped tensors). for instrument in self.instruments: dataset = dataset.filter(instrument.filter_shape).map( instrument.reshape_spectrogram ) # Select features and annotation. dataset = dataset.map(self.map_features) # Make batch (done after selection to avoid # error due to unprocessed instrument spectrogram batching). dataset = dataset.batch(batch_size) return dataset