You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

140 lines
4.8 KiB

2 years ago
  1. #!/usr/bin/env python
  2. # coding: utf8
  3. """ Unit testing for Separator class. """
  4. __email__ = 'spleeter@deezer.com'
  5. __author__ = 'Deezer Research'
  6. __license__ = 'MIT License'
  7. import itertools
  8. from os.path import splitext, basename, exists, join
  9. from tempfile import TemporaryDirectory
  10. import pytest
  11. import numpy as np
  12. import tensorflow as tf
  13. from spleeter import SpleeterError
  14. from spleeter.audio.adapter import AudioAdapter
  15. from spleeter.separator import Separator
  # Audio fixtures every parametrized test runs against.
  TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
  # STFT backends passed to ``Separator(stft_backend=...)``.
  BACKENDS = ["tensorflow", "librosa"]
  # Pretrained model descriptors under test.
  MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems']
  # Instruments (stems) each model is expected to produce.
  MODEL_TO_INST = {
      'spleeter:2stems': ('vocals', 'accompaniment'),
      'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'),
      'spleeter:5stems': ('vocals', 'drums', 'bass', 'piano', 'other'),
  }
  # Every (test_file, model) pair, test_file varying slowest.
  MODELS_AND_TEST_FILES = [
      (test_file, model)
      for test_file in TEST_AUDIO_DESCRIPTORS
      for model in MODELS
  ]
  # Every (test_file, model, backend) triple, backend varying fastest.
  TEST_CONFIGURATIONS = [
      (test_file, model, backend)
      for test_file in TEST_AUDIO_DESCRIPTORS
      for model in MODELS
      for backend in BACKENDS
  ]
  # Printed at import time so the test log shows which TensorFlow version ran.
  print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
  @pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
  def test_separator_backends(test_file):
      """ Check that the librosa and tensorflow STFT backends agree.

      First checks that the librosa STFT followed by its inverse approximately
      reconstructs the input waveform (within ``atol=3e-2``; this is not an
      exact reconstruction). Then separates the same waveform with a 2-stems
      Separator on each backend. Both must return the same instruments, with
      outputs close to each other (``atol=1e-5``).
      """
      adapter = AudioAdapter.default()
      waveform, _ = adapter.load(test_file)
      separator_lib = Separator(
          "spleeter:2stems", stft_backend="librosa", multiprocess=False)
      separator_tf = Separator(
          "spleeter:2stems", stft_backend="tensorflow", multiprocess=False)
      # STFT -> inverse STFT round trip must approximately rebuild the input.
      stft_matrix = separator_lib._stft(waveform)
      reconstructed = separator_lib._stft(
          stft_matrix, inverse=True, length=waveform.shape[0])
      assert np.allclose(reconstructed, waveform, atol=3e-2)
      # Compare both separations: same stems, near-identical content.
      out_tf = separator_tf._separate_tensorflow(waveform, test_file)
      out_lib = separator_lib._separate_librosa(waveform, test_file)
      # Without this, an extra stem from the TF backend would go unnoticed.
      assert set(out_tf) == set(out_lib)
      for instrument, lib_track in out_lib.items():
          assert np.allclose(out_tf[instrument], lib_track, atol=1e-5)
  @pytest.mark.parametrize(
      'test_file, configuration, backend',
      TEST_CONFIGURATIONS)
  def test_separate(test_file, configuration, backend):
      """ Test separation from raw data.

      Separates an in-memory waveform and checks that:
      - exactly the model's instruments are returned;
      - each stem has the input's shape on all but the last axis;
      - each stem differs from the input and from every other stem.
      """
      expected = MODEL_TO_INST[configuration]
      waveform, _ = AudioAdapter.default().load(test_file)
      separator = Separator(
          configuration, stft_backend=backend, multiprocess=False)
      prediction = separator.separate(waveform, test_file)

      assert len(prediction) == len(expected)
      missing = [name for name in expected if name not in prediction]
      assert not missing

      for name in expected:
          stem = prediction[name]
          assert waveform.shape[:-1] == stem.shape[:-1]
          assert not np.allclose(waveform, stem)

      # np.allclose is not symmetric, so check every ordered pair of
      # distinct stems.
      for first, second in itertools.permutations(expected, 2):
          assert not np.allclose(prediction[first], prediction[second])
  @pytest.mark.parametrize(
      'test_file, configuration, backend',
      TEST_CONFIGURATIONS)
  def test_separate_to_file(test_file, configuration, backend):
      """ Test file based separation.

      No ``filename_format`` is passed. Each instrument is expected at
      ``<directory>/<audio basename without extension>/<instrument>.wav``.
      """
      instruments = MODEL_TO_INST[configuration]
      separator = Separator(
          configuration, stft_backend=backend, multiprocess=False)
      name = splitext(basename(test_file))[0]
      with TemporaryDirectory() as directory:
          separator.separate_to_file(test_file, directory)
          # Checked inside the ``with``: the directory is deleted on exit.
          for instrument in instruments:
              # Join path components rather than embedding '/' in a string.
              assert exists(join(directory, name, '{}.wav'.format(instrument)))
  @pytest.mark.parametrize(
      'test_file, configuration, backend',
      TEST_CONFIGURATIONS)
  def test_filename_format(test_file, configuration, backend):
      """ Test custom filename format.

      With ``filename_format='export/{filename}/{instrument}.{codec}'``, each
      instrument is expected at
      ``<directory>/export/<audio basename without extension>/<instrument>.wav``.
      """
      instruments = MODEL_TO_INST[configuration]
      separator = Separator(
          configuration, stft_backend=backend, multiprocess=False)
      name = splitext(basename(test_file))[0]
      with TemporaryDirectory() as directory:
          separator.separate_to_file(
              test_file,
              directory,
              filename_format='export/{filename}/{instrument}.{codec}')
          # Checked inside the ``with``: the directory is deleted on exit.
          for instrument in instruments:
              # Join path components rather than embedding '/' in a string.
              assert exists(
                  join(directory, 'export', name, '{}.wav'.format(instrument)))
  @pytest.mark.parametrize(
      'test_file, configuration',
      MODELS_AND_TEST_FILES)
  def test_filename_conflict(test_file, configuration):
      """ Test error handling with static pattern.

      A ``filename_format`` with no placeholders must make
      ``separate_to_file`` raise ``SpleeterError``. Presumably this is because
      every instrument would map to the same output path; confirm in
      ``Separator``.
      """
      static_format = 'I wanna be your lover'
      separator = Separator(configuration, multiprocess=False)
      with TemporaryDirectory() as output_dir, pytest.raises(SpleeterError):
          separator.separate_to_file(
              test_file, output_dir, filename_format=static_format)