#!/usr/bin/env python
# coding: utf8

"""
Module that provides a class wrapper for source separation.

Examples:

```python
>>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...)
```
"""

import atexit
import os
from multiprocessing import Pool
from os.path import basename, dirname, join, splitext
from typing import Dict, Generator, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import istft, stft
from scipy.signal.windows import hann

from . import SpleeterError
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"


class DataGenerator(object):
    """
    Generator object that store a sample and generate it once while called.
    Used to feed a tensorflow estimator without knowing the whole data at
    build time.
    """

    def __init__(self) -> None:
        """ Default constructor. """
        # Sample to be yielded on the next generator pass; `None` means
        # "no data yet" and terminates the generation loop.
        self._current_data = None

    def update_data(self, data) -> None:
        """ Replace internal data. """
        self._current_data = data

    def __call__(self) -> Generator:
        """
        Generation process.

        Yields the currently stored sample repeatedly, re-reading the
        attribute on each iteration so that `update_data` can inject new
        samples (or `None` to stop) while the estimator is consuming.
        """
        buffer = self._current_data
        while buffer:
            yield buffer
            buffer = self._current_data


def create_estimator(params, MWF):
    """
    Initialize tensorflow estimator that will perform separation

    Params:
    - params: a dictionary of parameters for building the model
    - MWF: bool, whether to use the Multichannel Wiener Filter.

    Returns: a tensorflow estimator
    """
    # Load model weights locally (downloads them if needed) and point the
    # estimator at the resolved directory.
    provider: ModelProvider = ModelProvider.default()
    params["model_dir"] = provider.get(params["model_dir"])
    params["MWF"] = MWF
    # Setup config: cap GPU memory so separation can coexist with other
    # GPU consumers.
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config = tf.estimator.RunConfig(session_config=session_config)
    # Setup estimator
    estimator = tf.estimator.Estimator(
        model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
    )
    return estimator


class Separator(object):
    """ A wrapper class for performing separation. """

    def __init__(
        self,
        params_descriptor: str,
        MWF: bool = False,
        stft_backend: STFTBackend = STFTBackend.AUTO,
        multiprocess: bool = True,
    ) -> None:
        """
        Default constructor.

        Parameters:
            params_descriptor (str):
                Descriptor for TF params to be used.
            MWF (bool):
                (Optional) `True` if MWF should be used, `False` otherwise.
            stft_backend (STFTBackend):
                (Optional) Which STFT implementation to use
                (tensorflow / librosa / auto).
            multiprocess (bool):
                (Optional) Enable a process pool for asynchronous export.
        """
        self._params = load_configuration(params_descriptor)
        self._sample_rate = self._params["sample_rate"]
        self._MWF = MWF
        self._tf_graph = tf.Graph()
        # Lazily-built TF objects (see the corresponding _get_* accessors).
        self._prediction_generator = None
        self._input_provider = None
        self._builder = None
        self._features = None
        self._session = None
        if multiprocess:
            self._pool = Pool()
            # Make sure worker processes are released at interpreter exit.
            atexit.register(self._pool.close)
        else:
            self._pool = None
        # Pending async save tasks, consumed by `join`.
        self._tasks = []
        self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
        self._data_generator = DataGenerator()

    def _get_prediction_generator(self) -> Generator:
        """
        Lazy loading access method for internal prediction generator
        returned by the predict method of a tensorflow estimator.

        Returns:
            Generator:
                Generator of prediction.
        """
        if self._prediction_generator is None:
            estimator = create_estimator(self._params, self._MWF)

            def get_dataset():
                # Dataset fed by `self._data_generator`, which is updated
                # with a new waveform before every separation request.
                return tf.data.Dataset.from_generator(
                    self._data_generator,
                    output_types={"waveform": tf.float32, "audio_id": tf.string},
                    output_shapes={"waveform": (None, 2), "audio_id": ()},
                )

            self._prediction_generator = estimator.predict(
                get_dataset, yield_single_examples=False
            )
        return self._prediction_generator

    def join(self, timeout: int = 200) -> None:
        """
        Wait for all pending tasks to be finished.

        Parameters:
            timeout (int):
                (Optional) task waiting timeout.
        """
        while len(self._tasks) > 0:
            task = self._tasks.pop()
            task.get()
            task.wait(timeout=timeout)

    def _stft(
        self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
    ) -> np.ndarray:
        """
        Single entrypoint for both stft and istft.

        This computes stft and istft with librosa on stereo data. The two
        channels are processed separately and are concatenated together in
        the result. The expected input formats are: (n_samples, 2) for stft
        and (T, F, 2) for istft.

        Parameters:
            data (numpy.array):
                Array with either the waveform or the complex spectrogram
                depending on the parameter inverse
            inverse (bool):
                (Optional) Should a stft or an istft be computed.
            length (Optional[int]):
                Target sample length of the inverse transform (required
                when `inverse` is `True`).

        Returns:
            numpy.ndarray:
                Stereo data as numpy array for the transform. The channels
                are stored in the last dimension.
        """
        assert not (inverse and length is None)
        data = np.asfortranarray(data)
        N = self._params["frame_length"]
        H = self._params["frame_step"]
        win = hann(N, sym=False)
        fstft = istft if inverse else stft
        win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
        n_channels = data.shape[-1]
        out = []
        for c in range(n_channels):
            # Forward: zero-pad N samples on both sides so that `center=False`
            # still covers the signal edges; inverse: transpose to (F, T) as
            # expected by librosa's istft.
            d = (
                np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
                if not inverse
                else data[:, :, c].T
            )
            s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
            if inverse:
                # Drop the padding added before the forward transform and
                # trim to the requested length.
                s = s[N : N + length]
            s = np.expand_dims(s.T, 2 - inverse)
            out.append(s)
        if len(out) == 1:
            return out[0]
        return np.concatenate(out, axis=2 - inverse)

    def _get_input_provider(self):
        """ Lazy loading access method for the input provider. """
        if self._input_provider is None:
            self._input_provider = InputProviderFactory.get(self._params)
        return self._input_provider

    def _get_features(self):
        """ Lazy loading access method for model input placeholders. """
        if self._features is None:
            provider = self._get_input_provider()
            self._features = provider.get_input_dict_placeholders()
        return self._features

    def _get_builder(self):
        """ Lazy loading access method for the estimator spec builder. """
        if self._builder is None:
            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
        return self._builder

    def _get_session(self):
        """
        Lazy loading access method for a TF session restored from the
        latest model checkpoint.
        """
        if self._session is None:
            saver = tf.compat.v1.train.Saver()
            provider = ModelProvider.default()
            model_directory: str = provider.get(self._params["model_dir"])
            latest_checkpoint = tf.train.latest_checkpoint(model_directory)
            self._session = tf.compat.v1.Session()
            saver.restore(self._session, latest_checkpoint)
        return self._session

    def _separate_librosa(
        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
    ) -> Dict:
        """
        Performs separation with librosa backend for STFT.

        Parameters:
            waveform (numpy.ndarray):
                Waveform to be separated (as a numpy array)
            audio_descriptor (AudioDescriptor):
                Descriptor (e.g. filename) identifying the waveform.

        Returns:
            Dict:
                Mapping of instrument name to separated waveform.
        """
        with self._tf_graph.as_default():
            out = {}
            features = self._get_features()
            # TODO: fix the logic, build sometimes return,
            # sometimes set attribute.
            outputs = self._get_builder().outputs
            stft = self._stft(waveform)
            if stft.shape[-1] == 1:
                # Mono input: duplicate the channel to fake stereo.
                stft = np.concatenate([stft, stft], axis=-1)
            elif stft.shape[-1] > 2:
                # Keep only the first two channels. NOTE: the original code
                # read `stft[:, :2]`, which sliced the frequency axis of the
                # (T, F, C) array instead of the channel axis.
                stft = stft[:, :, :2]
            sess = self._get_session()
            outputs = sess.run(
                outputs,
                feed_dict=self._get_input_provider().get_feed_dict(
                    features, stft, audio_descriptor
                ),
            )
            for inst in self._get_builder().instruments:
                out[inst] = self._stft(
                    outputs[inst], inverse=True, length=waveform.shape[0]
                )
            return out

    def _separate_tensorflow(
        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
    ) -> Dict:
        """
        Performs source separation over the given waveform with tensorflow
        backend.

        Parameters:
            waveform (numpy.ndarray):
                Waveform to be separated (as a numpy array)
            audio_descriptor (AudioDescriptor):
                Descriptor (e.g. filename) identifying the waveform.

        Returns:
            Separated waveforms.
        """
        if not waveform.shape[-1] == 2:
            waveform = to_stereo(waveform)
        prediction_generator = self._get_prediction_generator()
        # NOTE: update data in generator before performing separation.
        self._data_generator.update_data(
            {"waveform": waveform, "audio_id": np.array(audio_descriptor)}
        )
        # NOTE: perform separation.
        prediction = next(prediction_generator)
        prediction.pop("audio_id")
        return prediction

    def separate(
        self, waveform: np.ndarray, audio_descriptor: Optional[str] = ""
    ) -> Dict:
        """
        Performs separation on a waveform.

        Parameters:
            waveform (numpy.ndarray):
                Waveform to be separated (as a numpy array)
            audio_descriptor (str):
                (Optional) string describing the waveform (e.g. filename).

        Returns:
            Dict:
                Mapping of instrument name to separated waveform.

        Raises:
            ValueError:
                If the configured STFT backend is not supported.
        """
        backend: str = self._params["stft_backend"]
        if backend == STFTBackend.TENSORFLOW:
            return self._separate_tensorflow(waveform, audio_descriptor)
        elif backend == STFTBackend.LIBROSA:
            return self._separate_librosa(waveform, audio_descriptor)
        raise ValueError(f"Unsupported STFT backend {backend}")

    def separate_to_file(
        self,
        audio_descriptor: AudioDescriptor,
        destination: str,
        audio_adapter: Optional[AudioAdapter] = None,
        offset: int = 0,
        duration: float = 600.0,
        codec: Codec = Codec.WAV,
        bitrate: str = "128k",
        filename_format: str = "{filename}/{instrument}.{codec}",
        synchronous: bool = True,
    ) -> None:
        """
        Performs source separation and export result to file using given
        audio adapter.

        Filename format should be a Python formattable string that could
        use following parameters :

        - {instrument}
        - {filename}
        - {foldername}
        - {codec}.

        Parameters:
            audio_descriptor (AudioDescriptor):
                Describe song to separate, used by audio adapter to
                retrieve and load audio data, in case of file based audio
                adapter, such descriptor would be a file path.
            destination (str):
                Target directory to write output to.
            audio_adapter (Optional[AudioAdapter]):
                (Optional) Audio adapter to use for I/O.
            offset (int):
                (Optional) Offset of loaded song.
            duration (float):
                (Optional) Duration of loaded song (default: 600s).
            codec (Codec):
                (Optional) Export codec.
            bitrate (str):
                (Optional) Export bitrate.
            filename_format (str):
                (Optional) Filename format.
            synchronous (bool):
                (Optional) True if should be synchronous.
        """
        if audio_adapter is None:
            audio_adapter = AudioAdapter.default()
        waveform, _ = audio_adapter.load(
            audio_descriptor,
            offset=offset,
            duration=duration,
            sample_rate=self._sample_rate,
        )
        sources = self.separate(waveform, audio_descriptor)
        self.save_to_file(
            sources,
            audio_descriptor,
            destination,
            filename_format,
            codec,
            audio_adapter,
            bitrate,
            synchronous,
        )

    def save_to_file(
        self,
        sources: Dict,
        audio_descriptor: AudioDescriptor,
        destination: str,
        filename_format: str = "{filename}/{instrument}.{codec}",
        codec: Codec = Codec.WAV,
        audio_adapter: Optional[AudioAdapter] = None,
        bitrate: str = "128k",
        synchronous: bool = True,
    ) -> None:
        """
        Export dictionary of sources to files.

        Parameters:
            sources (Dict):
                Dictionary of sources to be exported. The keys are the name
                of the instruments, and the values are `N x 2` numpy arrays
                containing the corresponding instrument waveform, as
                returned by the separate method
            audio_descriptor (AudioDescriptor):
                Describe song to separate, used by audio adapter to
                retrieve and load audio data, in case of file based audio
                adapter, such descriptor would be a file path.
            destination (str):
                Target directory to write output to.
            filename_format (str):
                (Optional) Filename format.
            codec (Codec):
                (Optional) Export codec.
            audio_adapter (Optional[AudioAdapter]):
                (Optional) Audio adapter to use for I/O.
            bitrate (str):
                (Optional) Export bitrate.
            synchronous (bool):
                (Optional) True if should be synchronous.

        Raises:
            SpleeterError:
                If two instruments resolve to the same output path.
        """
        if audio_adapter is None:
            audio_adapter = AudioAdapter.default()
        foldername = basename(dirname(audio_descriptor))
        filename = splitext(basename(audio_descriptor))[0]
        generated = []
        for instrument, data in sources.items():
            path = join(
                destination,
                filename_format.format(
                    filename=filename,
                    instrument=instrument,
                    foldername=foldername,
                    codec=codec,
                ),
            )
            directory = os.path.dirname(path)
            if not os.path.exists(directory):
                os.makedirs(directory)
            if path in generated:
                raise SpleeterError(
                    (
                        f"Separated source path conflict : {path},"
                        "please check your filename format"
                    )
                )
            generated.append(path)
            if self._pool:
                # Export asynchronously through the process pool; caller
                # can wait on completion via `join`.
                task = self._pool.apply_async(
                    audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
                )
                self._tasks.append(task)
            else:
                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
        if synchronous and self._pool:
            self.join()