You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.9 KiB

2 years ago
  1. #!/usr/bin/env python
  2. # coding: utf8
  3. """ Unit testing for Separator class. """
  4. __email__ = 'research@deezer.com'
  5. __author__ = 'Deezer Research'
  6. __license__ = 'MIT License'
  7. import json
  8. import os
  9. from os import makedirs
  10. from os.path import join
  11. from tempfile import TemporaryDirectory
  12. import numpy as np
  13. import pandas as pd
  14. from spleeter.audio.adapter import AudioAdapter
  15. from spleeter.__main__ import spleeter
  16. from typer.testing import CliRunner
  17. TRAIN_CONFIG = {
  18. 'mix_name': 'mix',
  19. 'instrument_list': ['vocals', 'other'],
  20. 'sample_rate': 44100,
  21. 'frame_length': 4096,
  22. 'frame_step': 1024,
  23. 'T': 128,
  24. 'F': 128,
  25. 'n_channels': 2,
  26. 'chunk_duration': 4,
  27. 'n_chunks_per_song': 1,
  28. 'separation_exponent': 2,
  29. 'mask_extension': 'zeros',
  30. 'learning_rate': 1e-4,
  31. 'batch_size': 2,
  32. 'train_max_steps': 10,
  33. 'throttle_secs': 20,
  34. 'save_checkpoints_steps': 100,
  35. 'save_summary_steps': 5,
  36. 'random_seed': 0,
  37. 'model': {
  38. 'type': 'unet.unet',
  39. 'params': {
  40. 'conv_activation': 'ELU',
  41. 'deconv_activation': 'ELU'
  42. }
  43. }
  44. }
  45. def generate_fake_training_dataset(path,
  46. instrument_list=['vocals', 'other'],
  47. n_channels=2,
  48. n_songs = 2,
  49. fs = 44100,
  50. duration = 6,
  51. ):
  52. """
  53. generates a fake training dataset in path:
  54. - generates audio files
  55. - generates a csv file describing the dataset
  56. """
  57. aa = AudioAdapter.default()
  58. rng = np.random.RandomState(seed=0)
  59. dataset_df = pd.DataFrame(
  60. columns=['mix_path'] + [
  61. f'{instr}_path' for instr in instrument_list] + ['duration'])
  62. for song in range(n_songs):
  63. song_path = join(path, 'train', f'song{song}')
  64. makedirs(song_path, exist_ok=True)
  65. dataset_df.loc[song, f'duration'] = duration
  66. for instr in instrument_list+['mix']:
  67. filename = join(song_path, f'{instr}.wav')
  68. data = rng.rand(duration*fs, n_channels)-0.5
  69. aa.save(filename, data, fs)
  70. dataset_df.loc[song, f'{instr}_path'] = join(
  71. 'train',
  72. f'song{song}',
  73. f'{instr}.wav')
  74. dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
  75. def test_train():
  76. with TemporaryDirectory() as path:
  77. # generate training dataset
  78. for n_channels in [1,2]:
  79. TRAIN_CONFIG["n_channels"] = n_channels
  80. generate_fake_training_dataset(path,
  81. n_channels=n_channels,
  82. fs=TRAIN_CONFIG["sample_rate"]
  83. )
  84. # set training command arguments
  85. runner = CliRunner()
  86. model_dir = join(path, f'model_{n_channels}')
  87. train_dir = join(path, f'train')
  88. cache_dir = join(path, f'cache_{n_channels}')
  89. TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
  90. TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
  91. TRAIN_CONFIG['model_dir'] = model_dir
  92. TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
  93. TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
  94. with open('useless_config.json', 'w') as stream:
  95. json.dump(TRAIN_CONFIG, stream)
  96. # execute training
  97. result = runner.invoke(spleeter, [
  98. 'train',
  99. '-p', 'useless_config.json',
  100. '-d', path,
  101. "--verbose"
  102. ])
  103. # assert that model checkpoint was created.
  104. assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
  105. assert os.path.exists(join(model_dir, 'checkpoint'))
  106. assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
  107. assert result.exit_code == 0