|
|
- #!/usr/bin/env python
- # coding: utf8
-
- """ Unit testing for the train command. """
-
- __email__ = 'research@deezer.com'
- __author__ = 'Deezer Research'
- __license__ = 'MIT License'
-
- import json
- import os
-
- from os import makedirs
- from os.path import join
- from tempfile import TemporaryDirectory
-
- import numpy as np
- import pandas as pd
-
- from spleeter.audio.adapter import AudioAdapter
- from spleeter.__main__ import spleeter
- from typer.testing import CliRunner
-
-
# Base configuration for a very short training run. test_train() completes it
# with dataset, model and cache paths before dumping it to the JSON file that
# is handed to the `train` command.
TRAIN_CONFIG = {
    # Dataset description.
    "mix_name": "mix",
    "instrument_list": ["vocals", "other"],
    # Signal parameters (exact semantics are defined by spleeter).
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 128,
    "F": 128,
    "n_channels": 2,
    "chunk_duration": 4,
    "n_chunks_per_song": 1,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    # Optimisation and checkpointing; kept tiny so the test stays fast.
    "learning_rate": 1e-4,
    "batch_size": 2,
    "train_max_steps": 10,
    "throttle_secs": 20,
    "save_checkpoints_steps": 100,
    "save_summary_steps": 5,
    "random_seed": 0,
    # Model selection.
    "model": {
        "type": "unet.unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU",
        },
    },
}
-
-
def generate_fake_training_dataset(path,
                                   instrument_list=None,
                                   n_channels=2,
                                   n_songs=2,
                                   fs=44100,
                                   duration=6,
                                   ):
    """
    Generate a fake training dataset in `path`.

    For each song, one `.wav` file of uniform random noise is saved per
    instrument, plus one for the mix, under `<path>/train/song<i>/`. A CSV
    file describing the dataset is then written to `<path>/train/train.csv`
    with one row per song: a `mix_path` column, one `<instrument>_path`
    column per instrument (paths relative to `path`) and a `duration`
    column.

    Parameters:
        path: Root directory of the generated dataset.
        instrument_list: Names of the instruments to generate; defaults to
            ['vocals', 'other'] when None.
        n_channels: Number of channels of every generated signal.
        n_songs: Number of songs to generate.
        fs: Sample rate handed to the audio adapter when saving.
        duration: Value written to the `duration` column; each signal
            holds `duration * fs` samples, so the product is expected to
            be an integer.
    """
    # None sentinel instead of a list literal: a list default would be a
    # single object shared by every call.
    if instrument_list is None:
        instrument_list = ['vocals', 'other']
    instruments = list(instrument_list)
    columns = (['mix_path']
               + [f'{instr}_path' for instr in instruments]
               + ['duration'])

    audio_adapter = AudioAdapter.default()
    # Fixed seed so that the generated audio is reproducible.
    rng = np.random.RandomState(seed=0)

    train_dir = join(path, 'train')
    # Created up front so the CSV can be written even when n_songs is 0.
    makedirs(train_dir, exist_ok=True)

    rows = []
    for song in range(n_songs):
        song_name = f'song{song}'
        song_path = join(train_dir, song_name)
        makedirs(song_path, exist_ok=True)
        row = {'duration': duration}
        for instr in instruments + ['mix']:
            # Uniform noise centred on zero, in [-0.5, 0.5).
            data = rng.rand(duration * fs, n_channels) - 0.5
            audio_adapter.save(join(song_path, f'{instr}.wav'), data, fs)
            row[f'{instr}_path'] = join('train', song_name, f'{instr}.wav')
        rows.append(row)

    # Build the frame once from the collected rows; `columns` fixes the
    # column order of the CSV.
    dataset_df = pd.DataFrame(rows, columns=columns)
    dataset_df.to_csv(join(train_dir, 'train.csv'), index=False)
-
-
def test_train():
    """
    Run the `train` command end to end on a generated fake dataset.

    The training is executed once with 1-channel and once with 2-channel
    audio. Each run must exit with code 0 and leave the expected
    checkpoint files in its own model directory.
    """
    runner = CliRunner()
    with TemporaryDirectory() as path:
        train_csv = join(path, 'train', 'train.csv')
        # The configuration file lives in the temporary directory so the
        # test leaves nothing behind in the current working directory.
        config_path = join(path, 'useless_config.json')

        for n_channels in [1, 2]:
            # generate training dataset
            generate_fake_training_dataset(
                path,
                n_channels=n_channels,
                fs=TRAIN_CONFIG['sample_rate'])

            # set training command arguments
            model_dir = join(path, f'model_{n_channels}')
            cache_dir = join(path, f'cache_{n_channels}')

            # Work on a copy so the module-level TRAIN_CONFIG is not
            # modified by running the test.
            config = dict(TRAIN_CONFIG)
            config.update({
                'n_channels': n_channels,
                'train_csv': train_csv,
                'validation_csv': train_csv,
                'model_dir': model_dir,
                'training_cache': join(cache_dir, 'training'),
                'validation_cache': join(cache_dir, 'validation'),
            })
            with open(config_path, 'w') as stream:
                json.dump(config, stream)

            # execute training
            result = runner.invoke(spleeter, [
                'train',
                '-p', config_path,
                '-d', path,
                '--verbose',
            ])

            # Check the exit code first and attach the captured output, so
            # a failed run is reported as such rather than as a missing
            # checkpoint file.
            assert result.exit_code == 0, result.output

            # assert that model checkpoint was created.
            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
            assert os.path.exists(join(model_dir, 'checkpoint'))
            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
|