#!/usr/bin/env python
# coding: utf8
"""
Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
Notes:
All critical import involving TF, numpy or Pandas are deported to
command function scope to avoid heavy import on CLI evaluation,
leading to large bootstraping time.
"""
import json
from functools import partial
from glob import glob
from itertools import product
from os.path import join
from pathlib import Path
from typing import Container, Dict, List, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
from . import SpleeterError
from .options import *
from .utils.logging import configure_logger, logger
# pylint: enable=import-error
spleeter: Typer = Typer(add_completion=False, no_args_is_help=True, short_help="-h")
""" CLI application. """
@spleeter.callback()
def default(
    version: bool = VersionOption,
) -> None:
    pass


@spleeter.command(no_args_is_help=True)
def train(
    adapter: str = AudioAdapterOption,
    data: Path = TrainingDataDirectoryOption,
    params_filename: str = ModelParametersOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Train a source separation model
    """
    import tensorflow as tf

    from .audio.adapter import AudioAdapter
    from .dataset import get_training_dataset, get_validation_dataset
    from .model import model_fn
    from .model.provider import ModelProvider
    from .utils.configuration import load_configuration

    configure_logger(verbose)
    audio_adapter = AudioAdapter.get(adapter)
    audio_path = str(data)
    params = load_configuration(params_filename)
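    # Cap per-process GPU memory so that the training and evaluation sessions
    # driven by train_and_evaluate can presumably share a single device.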
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params["model_dir"],
        params=params,
        config=tf.estimator.RunConfig(
            save_checkpoints_steps=params["save_checkpoints_steps"],
            tf_random_seed=params["random_seed"],
            save_summary_steps=params["save_summary_steps"],
            session_config=session_config,
            log_step_count_steps=10,
            keep_checkpoint_max=2,
        ),
    )
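    # The datasets are built lazily: the partials below are only invoked by the
    # estimator as input_fn when the corresponding phase starts.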
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn, max_steps=params["train_max_steps"]
    )
    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
    )
    logger.info("Start model training")
    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
    ModelProvider.writeProbe(params["model_dir"])
    logger.info("Model training done")


@spleeter.command(no_args_is_help=True)
def separate(
    deprecated_files: Optional[str] = AudioInputOption,
    files: List[Path] = AudioInputArgument,
    adapter: str = AudioAdapterOption,
    bitrate: str = AudioBitrateOption,
    codec: Codec = AudioCodecOption,
    duration: float = AudioDurationOption,
    offset: float = AudioOffsetOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    filename_format: str = FilenameFormatOption,
    params_filename: str = ModelParametersOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> None:
    """
    Separate audio file(s)
    """
    from .audio.adapter import AudioAdapter
    from .separator import Separator

    configure_logger(verbose)
    if deprecated_files is not None:
        logger.error(
            "⚠️ -i option is not supported anymore, audio files must be supplied "
            "using input argument instead (see spleeter separate --help)"
        )
        raise Exit(20)
    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
    separator: Separator = Separator(
        params_filename, MWF=mwf, stft_backend=stft_backend
    )
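    # Each file is submitted for separation asynchronously (synchronous=False);
    # join() below blocks until all pending export tasks have completed.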
    for filename in files:
        separator.separate_to_file(
            str(filename),
            str(output_path),
            audio_adapter=audio_adapter,
            offset=offset,
            duration=duration,
            codec=codec,
            bitrate=bitrate,
            filename_format=filename_format,
            synchronous=False,
        )
    separator.join()

EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    Parameters:
        metrics_output_directory (str):
            Directory to get metrics from.

    Returns:
        Dict:
            Compiled metrics as dict.
    """
    import numpy as np
    import pandas as pd
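    # museval is expected to have written one JSON file per track under the
    # "test" sub-directory of the metrics output directory (see the glob below).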
    songs = glob(join(metrics_output_directory, "test/*.json"))
    index = pd.MultiIndex.from_tuples(
        product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
        names=["instrument", "metric"],
    )
    pd.DataFrame([], index=["config1", "config2"], columns=index)
    metrics = {
        instrument: {k: [] for k in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS
    }
    for song in songs:
        with open(song, "r") as stream:
            data = json.load(stream)
        for target in data["targets"]:
            instrument = target["name"]
            for metric in EVALUATION_METRICS:
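                # Median of the metric over all frames of the track, ignoring
                # NaN frames (e.g. silent segments).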
                sdr_med = np.median(
                    [
                        frame["metrics"][metric]
                        for frame in target["frames"]
                        if not np.isnan(frame["metrics"][metric])
                    ]
                )
                metrics[instrument][metric].append(sdr_med)
    return metrics


@spleeter.command(no_args_is_help=True)
def evaluate(
    adapter: str = AudioAdapterOption,
    output_path: Path = AudioOutputOption,
    stft_backend: STFTBackend = AudioSTFTBackendOption,
    params_filename: str = ModelParametersOption,
    mus_dir: Path = MUSDBDirectoryOption,
    mwf: bool = MWFOption,
    verbose: bool = VerboseOption,
) -> Dict:
    """
    Evaluate a model on the musdb test dataset
    """
    import numpy as np

    configure_logger(verbose)
    try:
        import musdb
        import museval
    except ImportError:
        logger.error("Extra dependencies musdb and museval not found")
        logger.error("Please install musdb and museval first, aborting")
        raise Exit(10)
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
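    # Reuse the separate command as a plain function on every test mixture; the
    # 600s duration is presumably enough to cover full-length musdb tracks.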
    separate(
        deprecated_files=None,
        files=mixtures,
        adapter=adapter,
        bitrate="128k",
        codec=Codec.WAV,
        duration=600.0,
        offset=0,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format="{foldername}/{instrument}.{codec}",
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose,
    )
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info("Starting musdb evaluation (this could be long) ...")
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory,
    )
    logger.info("musdb evaluation done")
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        logger.info(f"{instrument}:")
        for metric, value in instrument_metrics.items():
            logger.info(f"{metric}: {np.median(value):.3f}")
    return metrics


def entrypoint():
    """ Application entrypoint. """
    try:
        spleeter()
    except SpleeterError as e:
        logger.error(e)


if __name__ == "__main__":
    entrypoint()