|
|
- #!/usr/bin/env python
- # coding: utf8
-
- """ Unit testing for Separator class. """
-
- __email__ = 'spleeter@deezer.com'
- __author__ = 'Deezer Research'
- __license__ = 'MIT License'
-
- import itertools
- from os.path import splitext, basename, exists, join
- from tempfile import TemporaryDirectory
-
- import pytest
- import numpy as np
-
- import tensorflow as tf
-
- from spleeter import SpleeterError
- from spleeter.audio.adapter import AudioAdapter
- from spleeter.separator import Separator
-
# Instruments each pretrained model is expected to output.
MODEL_TO_INST = {
    'spleeter:2stems': ('vocals', 'accompaniment'),
    'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'),
    'spleeter:5stems': ('vocals', 'drums', 'bass', 'piano', 'other'),
}

# Audio files every parametrized test is run against.
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']

# Values passed as ``stft_backend`` when building a Separator.
BACKENDS = ['tensorflow', 'librosa']

# Model descriptors under test.
MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems']

# Cartesian product (audio file, model); the last factor varies fastest.
MODELS_AND_TEST_FILES = [
    (audio, model)
    for audio in TEST_AUDIO_DESCRIPTORS
    for model in MODELS
]

# Cartesian product (audio file, model, backend), same ordering rule.
TEST_CONFIGURATIONS = [
    (audio, model, backend)
    for audio in TEST_AUDIO_DESCRIPTORS
    for model in MODELS
    for backend in BACKENDS
]


# Emitted once at import time so test logs record the TensorFlow version.
print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
-
-
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
def test_separator_backends(test_file):
    """ Test STFT round trip and agreement between both STFT backends. """
    waveform, _ = AudioAdapter.default().load(test_file)

    librosa_separator = Separator(
        'spleeter:2stems', stft_backend='librosa', multiprocess=False)
    tensorflow_separator = Separator(
        'spleeter:2stems', stft_backend='tensorflow', multiprocess=False)

    # A forward STFT followed by its inverse should give back the input
    # waveform, up to the absolute tolerance used in the assertion.
    spectrogram = librosa_separator._stft(waveform)
    roundtrip = librosa_separator._stft(
        spectrogram, inverse=True, length=waveform.shape[0])
    assert np.allclose(roundtrip, waveform, atol=3e-2)

    # Separate the same waveform through each backend specific code path.
    tensorflow_stems = tensorflow_separator._separate_tensorflow(
        waveform, test_file)
    librosa_stems = librosa_separator._separate_librosa(
        waveform, test_file)

    # Every stem from the librosa path must closely match the
    # corresponding stem from the tensorflow path.
    for instrument, librosa_stem in librosa_stems.items():
        assert np.allclose(
            tensorflow_stems[instrument], librosa_stem, atol=1e-5)
-
-
@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
    """ Test in-memory separation of an already loaded waveform. """
    expected_stems = MODEL_TO_INST[configuration]
    waveform, _ = AudioAdapter.default().load(test_file)
    separator = Separator(
        configuration, stft_backend=backend, multiprocess=False)
    prediction = separator.separate(waveform, test_file)

    # The prediction must hold exactly one entry per expected stem.
    assert len(prediction) == len(expected_stems)
    for stem in expected_stems:
        assert stem in prediction

    for stem in expected_stems:
        track = prediction[stem]
        # All dimensions but the last one must match the input waveform.
        assert waveform.shape[:-1] == track.shape[:-1]
        # A stem must differ from the input mix ...
        assert not np.allclose(waveform, track)
        # ... and from every other predicted stem.
        for other in expected_stems:
            if other == stem:
                continue
            assert not np.allclose(track, prediction[other])
-
-
@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend):
    """ Test separation writing results with the default file layout. """
    expected_stems = MODEL_TO_INST[configuration]
    separator = Separator(
        configuration, stft_backend=backend, multiprocess=False)
    # Input file name without directory nor extension.
    track_name, _ = splitext(basename(test_file))
    with TemporaryDirectory() as directory:
        separator.separate_to_file(test_file, directory)
        # One wav file per stem is expected under <directory>/<track_name>.
        for stem in expected_stems:
            relative_path = '{}/{}.wav'.format(track_name, stem)
            assert exists(join(directory, relative_path))
-
-
@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend):
    """ Test that a custom filename format drives the output layout. """
    expected_stems = MODEL_TO_INST[configuration]
    separator = Separator(
        configuration, stft_backend=backend, multiprocess=False)
    # Input file name without directory nor extension.
    track_name, _ = splitext(basename(test_file))
    with TemporaryDirectory() as directory:
        separator.separate_to_file(
            test_file,
            directory,
            filename_format='export/{filename}/{instrument}.{codec}')
        # Outputs are expected under the extra 'export' prefix.
        for stem in expected_stems:
            relative_path = 'export/{}/{}.wav'.format(track_name, stem)
            assert exists(join(directory, relative_path))
-
-
@pytest.mark.parametrize(
    'test_file, configuration',
    MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
    """ Test that a placeholder-free filename format raises an error. """
    separator = Separator(configuration, multiprocess=False)
    # The format below has no placeholder, so the call is expected to
    # fail with SpleeterError.
    with TemporaryDirectory() as directory, pytest.raises(SpleeterError):
        separator.separate_to_file(
            test_file,
            directory,
            filename_format='I wanna be your lover')
|