#!/usr/bin/env python
# coding: utf8

"""
Python oneliner script usage.

USAGE: python -m spleeter {train,evaluate,separate} ...

Notes:
    All critical imports involving TF, numpy or Pandas are deferred into
    command function scope to avoid heavy imports on CLI evaluation,
    which would otherwise lead to a long bootstrapping time.
"""

import json
from functools import partial
from glob import glob
from itertools import product
from os.path import join
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer

from . import SpleeterError
from .options import *
from .utils.logging import configure_logger, logger

# pylint: enable=import-error

spleeter: Typer = Typer(add_completion=False, no_args_is_help=True, short_help="-h")
""" CLI application. """


@spleeter.callback()
def default(
    version: bool = VersionOption,
) -> None:
    pass


@spleeter.command(no_args_is_help=True)
def train(
    adapter: str = AudioAdapterOption,
    data: Path = TrainingDataDirectoryOption,
    params_filename: str = ModelParametersOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Train a source separation model
    """
    import tensorflow as tf

    from .audio.adapter import AudioAdapter
    from .dataset import get_training_dataset, get_validation_dataset
    from .model import model_fn
    from .model.provider import ModelProvider
    from .utils.configuration import load_configuration

    configure_logger(verbose)
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)
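    # Cap per-process GPU memory usage: per_process_gpu_memory_fraction limits
    # TensorFlow to 45% of device memory, leaving headroom for other processes.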
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params["model_dir"],
        params=params,
        config=tf.estimator.RunConfig(
            save_checkpoints_steps=params["save_checkpoints_steps"],
            tf_random_seed=params["random_seed"],
            save_summary_steps=params["save_summary_steps"],
            session_config=session_config,
            log_step_count_steps=10,
            keep_checkpoint_max=2,
        ),
    )
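    # The Estimator API expects input_fn callables; partial(...) binds the
    # dataset builder arguments (params, adapter, path) up front.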
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn, max_steps=params["train_max_steps"]
    )
    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
    )
    logger.info("Start model training")
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
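    # Write a probe file into the model directory so the model provider can
    # later detect that a trained model is available there.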
    ModelProvider.writeProbe(params["model_dir"])
    logger.info("Model training done")


@spleeter.command(no_args_is_help=True)
def separate(
    deprecated_files: Optional[str] = AudioInputOption,
    files: List[Path] = AudioInputArgument,
    adapter: str = AudioAdapterOption,
    bitrate: str = AudioBitrateOption,
    codec: Codec = AudioCodecOption,
    duration: float = AudioDurationOption,
    offset: float = AudioOffsetOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    filename_format: str = FilenameFormatOption,
    params_filename: str = ModelParametersOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Separate audio file(s)
    """
    from .audio.adapter import AudioAdapter
    from .separator import Separator

    configure_logger(verbose)
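    # The legacy -i input option was removed; reject it explicitly with a
    # dedicated exit code rather than silently ignoring it.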
    if deprecated_files is not None:
        logger.error(
            "⚠️ -i option is not supported anymore, audio files must be supplied "
            "using input argument instead (see spleeter separate --help)"
        )
        raise Exit(20)
    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename, MWF=mwf, stft_backend=stft_backend
    )
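    # Queue every file for asynchronous separation (synchronous=False), then
    # wait for all pending tasks with join() once the loop completes.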
    for filename in files:
        separator.separate_to_file(
            str(filename),
            str(output_path),
            audio_adapter=audio_adapter,
            offset=offset,
            duration=duration,
            codec=codec,
            bitrate=bitrate,
            filename_format=filename_format,
            synchronous=False,
        )
    separator.join()


EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
# Annotated as Tuple rather than Container: these are iterated below
# (product(...) and dict comprehensions), which Container does not support.
EVALUATION_INSTRUMENTS: Tuple[str, ...] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Tuple[str, ...] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory: str) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    Parameters:
        metrics_output_directory (str):
            Directory to get metrics from.

    Returns:
        Dict:
            Compiled metrics as dict.
    """
    import numpy as np

    songs = glob(join(metrics_output_directory, EVALUATION_SPLIT, "*.json"))
    metrics = {
        instrument: {k: [] for k in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS
    }
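    # Each JSON file holds per-target, per-frame metric values; aggregate them
    # into one median value per song, instrument and metric, skipping NaNs.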
    for song in songs:
        with open(song, "r") as stream:
            data = json.load(stream)
        for target in data["targets"]:
            instrument = target["name"]
            for metric in EVALUATION_METRICS:
                metric_median = np.median(
                    [
                        frame["metrics"][metric]
                        for frame in target["frames"]
                        if not np.isnan(frame["metrics"][metric])
                    ]
                )
                metrics[instrument][metric].append(metric_median)
    return metrics


@spleeter.command(no_args_is_help=True)
def evaluate(
    adapter: str = AudioAdapterOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    params_filename: str = ModelParametersOption,
    mus_dir: Path = MUSDBDirectoryOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> Dict:
    """
    Evaluate a model on the musDB test dataset
    """
    import numpy as np

    configure_logger(verbose)
    try:
        import musdb
        import museval
    except ImportError:
        logger.error("Extra dependencies musdb and museval not found")
        logger.error("Please install musdb and museval first, aborting")
        raise Exit(10)
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
    mixtures = [Path(song) / EVALUATION_MIXTURE for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    separate(
        deprecated_files=None,
        files=mixtures,
        adapter=adapter,
        bitrate="128k",
        codec=Codec.WAV,
        duration=600.0,
        offset=0,
        output_path=Path(audio_output_directory) / EVALUATION_SPLIT,
        stft_backend=stft_backend,
        filename_format="{foldername}/{instrument}.{codec}",
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose,
    )
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info("Starting musdb evaluation (this could be long) ...")
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
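    # museval writes one JSON file per evaluated track into output_dir, which
    # _compile_metrics aggregates afterwards.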
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory,
    )
    logger.info("musdb evaluation done")
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        logger.info(f"{instrument}:")
        for metric, values in instrument_metrics.items():
            logger.info(f"{metric}: {np.median(values):.3f}")
    return metrics


def entrypoint():
    """ Application entrypoint. """
    try:
        spleeter()
    except SpleeterError as e:
        logger.error(e)


if __name__ == "__main__":
    entrypoint()