#!/usr/bin/env python
|
|
# coding: utf8
|
|
|
|
""" Unit testing for the train command. """
|
|
|
|
__email__ = 'research@deezer.com'
|
|
__author__ = 'Deezer Research'
|
|
__license__ = 'MIT License'
|
|
|
|
import json
|
|
import os
|
|
|
|
from os import makedirs
|
|
from os.path import join
|
|
from tempfile import TemporaryDirectory
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from spleeter.audio.adapter import AudioAdapter
|
|
from spleeter.__main__ import spleeter
|
|
from typer.testing import CliRunner
|
|
|
|
|
|
# Base configuration serialized to JSON and handed to the train command.
# Path-related entries are filled in by the test itself.
TRAIN_CONFIG = {
    # Dataset description.
    "mix_name": "mix",
    "instrument_list": ["vocals", "other"],
    # Audio / spectrogram settings.
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 128,
    "F": 128,
    "n_channels": 2,
    "chunk_duration": 4,
    "n_chunks_per_song": 1,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    # Optimisation settings, kept tiny so that the test stays short.
    "learning_rate": 0.0001,
    "batch_size": 2,
    "train_max_steps": 10,
    "throttle_secs": 20,
    "save_checkpoints_steps": 100,
    "save_summary_steps": 5,
    "random_seed": 0,
    # Model description.
    "model": {
        "type": "unet.unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU",
        },
    },
}
|
|
|
|
|
|
def generate_fake_training_dataset(path,
                                   instrument_list=None,
                                   n_channels=2,
                                   n_songs=2,
                                   fs=44100,
                                   duration=6,
                                   ):
    """
    Generate a fake training dataset in ``path``:
    - generates one random-noise wav file per instrument (plus the mix)
      for each song, under ``<path>/train/song<N>/``
    - generates ``<path>/train/train.csv`` describing the dataset

    :param path: Root directory in which the ``train`` folder is created.
    :param instrument_list: Instrument names to generate a track for,
        defaults to ``['vocals', 'other']``. A ``mix`` track is always
        generated in addition to these.
    :param n_channels: Number of channels of each generated waveform.
    :param n_songs: Number of songs to generate.
    :param fs: Sample rate passed to the audio adapter when saving.
    :param duration: Duration of each track, in seconds. ``duration * fs``
        must be an integer since it is used as an array dimension.
    """
    # A fresh list on every call: avoids sharing a mutable default between
    # calls and lets callers pass any iterable (e.g. a tuple).
    if instrument_list is None:
        instrument_list = ['vocals', 'other']
    else:
        instrument_list = list(instrument_list)

    aa = AudioAdapter.default()
    # Fixed seed so that the generated noise is the same on every run.
    rng = np.random.RandomState(seed=0)
    dataset_df = pd.DataFrame(
        columns=['mix_path'] + [
            f'{instr}_path' for instr in instrument_list] + ['duration'])

    # Create the dataset folder up front so that the csv can be written
    # even when no song is generated (n_songs == 0).
    train_dir = join(path, 'train')
    makedirs(train_dir, exist_ok=True)

    n_samples = duration * fs
    for song in range(n_songs):
        song_path = join(train_dir, f'song{song}')
        makedirs(song_path, exist_ok=True)
        dataset_df.loc[song, 'duration'] = duration
        for instr in instrument_list + ['mix']:
            filename = join(song_path, f'{instr}.wav')
            # Uniform noise centered on zero, in [-0.5, 0.5).
            data = rng.rand(n_samples, n_channels) - 0.5
            aa.save(filename, data, fs)
            # Paths are stored relative to ``path``, not as absolute paths.
            dataset_df.loc[song, f'{instr}_path'] = join(
                'train',
                f'song{song}',
                f'{instr}.wav')

    dataset_df.to_csv(join(train_dir, 'train.csv'), index=False)
|
|
|
|
|
|
def test_train():
    """
    Run the ``train`` command on a fake dataset and check its outputs.

    For both mono and stereo configurations, a fake dataset is generated in
    a temporary directory, a JSON configuration derived from TRAIN_CONFIG is
    written next to it, and the command is invoked through the CLI runner.
    The test then checks that the command succeeded and that the expected
    checkpoint files exist in the model directory.
    """
    with TemporaryDirectory() as path:
        train_dir = join(path, 'train')
        for n_channels in [1, 2]:
            # generate training dataset
            generate_fake_training_dataset(
                path,
                n_channels=n_channels,
                fs=TRAIN_CONFIG['sample_rate'])

            # set training command arguments
            runner = CliRunner()
            model_dir = join(path, f'model_{n_channels}')
            cache_dir = join(path, f'cache_{n_channels}')

            # Work on a copy so that the module-level TRAIN_CONFIG is not
            # altered by the test (no state leaking between iterations).
            config = {
                **TRAIN_CONFIG,
                'n_channels': n_channels,
                'train_csv': join(train_dir, 'train.csv'),
                'validation_csv': join(train_dir, 'train.csv'),
                'model_dir': model_dir,
                'training_cache': join(cache_dir, 'training'),
                'validation_cache': join(cache_dir, 'validation'),
            }

            # Write the configuration inside the temporary directory rather
            # than the current working directory, so that nothing is left
            # behind once the test is over.
            config_path = join(path, 'useless_config.json')
            with open(config_path, 'w') as stream:
                json.dump(config, stream)

            # execute training
            result = runner.invoke(spleeter, [
                'train',
                '-p', config_path,
                '-d', path,
                '--verbose'
            ])

            # Check the exit code first: if the command failed, the error it
            # raised is more useful than a missing checkpoint file.
            assert result.exit_code == 0, (
                f'train command failed: {result.exception!r}')

            # assert that model checkpoint was created.
            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
            assert os.path.exists(join(model_dir, 'checkpoint'))
            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
|