You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
8.2 KiB

2 years ago
  1. #!/usr/bin/env python
  2. # coding: utf8
  3. """
  4. Python oneliner script usage.
  5. USAGE: python -m spleeter {train,evaluate,separate} ...
  6. Notes:
  7. All critical imports involving TF, numpy or Pandas are deferred to
  8. command function scope to avoid heavy imports on CLI evaluation,
  9. which would lead to a long bootstrapping time.
  10. """
  11. import json
  12. from functools import partial
  13. from glob import glob
  14. from itertools import product
  15. from os.path import join
  16. from pathlib import Path
  17. from typing import Container, Dict, List, Optional
  18. # pyright: reportMissingImports=false
  19. # pylint: disable=import-error
  20. from typer import Exit, Typer
  21. from . import SpleeterError
  22. from .options import *
  23. from .utils.logging import configure_logger, logger
  24. # pylint: enable=import-error
  25. spleeter: Typer = Typer(add_completion=False, no_args_is_help=True, short_help="-h")
  26. """ CLI application. """
  27. @spleeter.callback()
  28. def default(
  29. version: bool = VersionOption,
  30. ) -> None:
  31. pass
  32. @spleeter.command(no_args_is_help=True)
  33. def train(
  34. adapter: str = AudioAdapterOption,
  35. data: Path = TrainingDataDirectoryOption,
  36. params_filename: str = ModelParametersOption,
  37. verbose: bool = VerboseOption,
  38. ) -> None:
  39. """
  40. Train a source separation model
  41. """
  42. import tensorflow as tf
  43. from .audio.adapter import AudioAdapter
  44. from .dataset import get_training_dataset, get_validation_dataset
  45. from .model import model_fn
  46. from .model.provider import ModelProvider
  47. from .utils.configuration import load_configuration
  48. configure_logger(verbose)
  49. audio_adapter = AudioAdapter.get(adapter)
  50. audio_path = str(data)
  51. params = load_configuration(params_filename)
  52. session_config = tf.compat.v1.ConfigProto()
  53. session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
  54. estimator = tf.estimator.Estimator(
  55. model_fn=model_fn,
  56. model_dir=params["model_dir"],
  57. params=params,
  58. config=tf.estimator.RunConfig(
  59. save_checkpoints_steps=params["save_checkpoints_steps"],
  60. tf_random_seed=params["random_seed"],
  61. save_summary_steps=params["save_summary_steps"],
  62. session_config=session_config,
  63. log_step_count_steps=10,
  64. keep_checkpoint_max=2,
  65. ),
  66. )
  67. input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
  68. train_spec = tf.estimator.TrainSpec(
  69. input_fn=input_fn, max_steps=params["train_max_steps"]
  70. )
  71. input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
  72. evaluation_spec = tf.estimator.EvalSpec(
  73. input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
  74. )
  75. logger.info("Start model training")
  76. tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
  77. ModelProvider.writeProbe(params["model_dir"])
  78. logger.info("Model training done")
  79. @spleeter.command(no_args_is_help=True)
  80. def separate(
  81. deprecated_files: Optional[str] = AudioInputOption,
  82. files: List[Path] = AudioInputArgument,
  83. adapter: str = AudioAdapterOption,
  84. bitrate: str = AudioBitrateOption,
  85. codec: Codec = AudioCodecOption,
  86. duration: float = AudioDurationOption,
  87. offset: float = AudioOffsetOption,
  88. output_path: Path = AudioOutputOption,
  89. stft_backend: STFTBackend = AudioSTFTBackendOption,
  90. filename_format: str = FilenameFormatOption,
  91. params_filename: str = ModelParametersOption,
  92. mwf: bool = MWFOption,
  93. verbose: bool = VerboseOption,
  94. ) -> None:
  95. """
  96. Separate audio file(s)
  97. """
  98. from .audio.adapter import AudioAdapter
  99. from .separator import Separator
  100. configure_logger(verbose)
  101. if deprecated_files is not None:
  102. logger.error(
  103. "⚠️ -i option is not supported anymore, audio files must be supplied "
  104. "using input argument instead (see spleeter separate --help)"
  105. )
  106. raise Exit(20)
  107. audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
  108. separator: Separator = Separator(
  109. params_filename, MWF=mwf, stft_backend=stft_backend
  110. )
  111. for filename in files:
  112. separator.separate_to_file(
  113. str(filename),
  114. str(output_path),
  115. audio_adapter=audio_adapter,
  116. offset=offset,
  117. duration=duration,
  118. codec=codec,
  119. bitrate=bitrate,
  120. filename_format=filename_format,
  121. synchronous=False,
  122. )
  123. separator.join()
  124. EVALUATION_SPLIT: str = "test"
  125. EVALUATION_METRICS_DIRECTORY: str = "metrics"
  126. EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
  127. EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
  128. EVALUATION_MIXTURE: str = "mixture.wav"
  129. EVALUATION_AUDIO_DIRECTORY: str = "audio"
  130. def _compile_metrics(metrics_output_directory) -> Dict:
  131. """
  132. Compiles metrics from given directory and returns results as dict.
  133. Parameters:
  134. metrics_output_directory (str):
  135. Directory to get metrics from.
  136. Returns:
  137. Dict:
  138. Compiled metrics as dict.
  139. """
  140. import numpy as np
  141. import pandas as pd
  142. songs = glob(join(metrics_output_directory, "test/*.json"))
  143. index = pd.MultiIndex.from_tuples(
  144. product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
  145. names=["instrument", "metric"],
  146. )
  147. pd.DataFrame([], index=["config1", "config2"], columns=index)
  148. metrics = {
  149. instrument: {k: [] for k in EVALUATION_METRICS}
  150. for instrument in EVALUATION_INSTRUMENTS
  151. }
  152. for song in songs:
  153. with open(song, "r") as stream:
  154. data = json.load(stream)
  155. for target in data["targets"]:
  156. instrument = target["name"]
  157. for metric in EVALUATION_METRICS:
  158. sdr_med = np.median(
  159. [
  160. frame["metrics"][metric]
  161. for frame in target["frames"]
  162. if not np.isnan(frame["metrics"][metric])
  163. ]
  164. )
  165. metrics[instrument][metric].append(sdr_med)
  166. return metrics
  167. @spleeter.command(no_args_is_help=True)
  168. def evaluate(
  169. adapter: str = AudioAdapterOption,
  170. output_path: Path = AudioOutputOption,
  171. stft_backend: STFTBackend = AudioSTFTBackendOption,
  172. params_filename: str = ModelParametersOption,
  173. mus_dir: Path = MUSDBDirectoryOption,
  174. mwf: bool = MWFOption,
  175. verbose: bool = VerboseOption,
  176. ) -> Dict:
  177. """
  178. Evaluate a model on the musDB test dataset
  179. """
  180. import numpy as np
  181. configure_logger(verbose)
  182. try:
  183. import musdb
  184. import museval
  185. except ImportError:
  186. logger.error("Extra dependencies musdb and museval not found")
  187. logger.error("Please install musdb and museval first, abort")
  188. raise Exit(10)
  189. # Separate musdb sources.
  190. songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
  191. mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
  192. audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
  193. separate(
  194. deprecated_files=None,
  195. files=mixtures,
  196. adapter=adapter,
  197. bitrate="128k",
  198. codec=Codec.WAV,
  199. duration=600.0,
  200. offset=0,
  201. output_path=join(audio_output_directory, EVALUATION_SPLIT),
  202. stft_backend=stft_backend,
  203. filename_format="{foldername}/{instrument}.{codec}",
  204. params_filename=params_filename,
  205. mwf=mwf,
  206. verbose=verbose,
  207. )
  208. # Compute metrics with musdb.
  209. metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
  210. logger.info("Starting musdb evaluation (this could be long) ...")
  211. dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
  212. museval.eval_mus_dir(
  213. dataset=dataset,
  214. estimates_dir=audio_output_directory,
  215. output_dir=metrics_output_directory,
  216. )
  217. logger.info("musdb evaluation done")
  218. # Compute and pretty print median metrics.
  219. metrics = _compile_metrics(metrics_output_directory)
  220. for instrument, metric in metrics.items():
  221. logger.info(f"{instrument}:")
  222. for metric, value in metric.items():
  223. logger.info(f"{metric}: {np.median(value):.3f}")
  224. return metrics
  225. def entrypoint():
  226. """ Application entrypoint. """
  227. try:
  228. spleeter()
  229. except SpleeterError as e:
  230. logger.error(e)
  231. if __name__ == "__main__":
  232. entrypoint()