You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

461 lines
16 KiB

2 years ago
  1. #!/usr/bin/env python
  2. # coding: utf8
  3. """
  4. Module that provides a class wrapper for source separation.
  5. Examples:
  6. ```python
  7. >>> from spleeter.separator import Separator
  8. >>> separator = Separator('spleeter:2stems')
  9. >>> prediction = separator.separate(waveform)
  10. >>> separator.separate_to_file(...)
  11. ```
  12. """
  13. import atexit
  14. import os
  15. from multiprocessing import Pool
  16. from os.path import basename, dirname, join, splitext
  17. from typing import Dict, Generator, Optional
  18. # pyright: reportMissingImports=false
  19. # pylint: disable=import-error
  20. import numpy as np
  21. import tensorflow as tf
  22. from librosa.core import istft, stft
  23. from scipy.signal.windows import hann
  24. from spleeter.model.provider import ModelProvider
  25. from . import SpleeterError
  26. from .audio import Codec, STFTBackend
  27. from .audio.adapter import AudioAdapter
  28. from .audio.convertor import to_stereo
  29. from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
  30. from .model.provider import ModelProvider
  31. from .types import AudioDescriptor
  32. from .utils.configuration import load_configuration
  33. # pylint: enable=import-error
  34. __email__ = "spleeter@deezer.com"
  35. __author__ = "Deezer Research"
  36. __license__ = "MIT License"
  37. class DataGenerator(object):
  38. """
  39. Generator object that store a sample and generate it once while called.
  40. Used to feed a tensorflow estimator without knowing the whole data at
  41. build time.
  42. """
  43. def __init__(self) -> None:
  44. """ Default constructor. """
  45. self._current_data = None
  46. def update_data(self, data) -> None:
  47. """ Replace internal data. """
  48. self._current_data = data
  49. def __call__(self) -> Generator:
  50. """ Generation process. """
  51. buffer = self._current_data
  52. while buffer:
  53. yield buffer
  54. buffer = self._current_data
  55. def create_estimator(params, MWF):
  56. """
  57. Initialize tensorflow estimator that will perform separation
  58. Params:
  59. - params: a dictionary of parameters for building the model
  60. Returns:
  61. a tensorflow estimator
  62. """
  63. # Load model.
  64. provider: ModelProvider = ModelProvider.default()
  65. params["model_dir"] = provider.get(params["model_dir"])
  66. params["MWF"] = MWF
  67. # Setup config
  68. session_config = tf.compat.v1.ConfigProto()
  69. session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
  70. config = tf.estimator.RunConfig(session_config=session_config)
  71. # Setup estimator
  72. estimator = tf.estimator.Estimator(
  73. model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
  74. )
  75. return estimator
  76. class Separator(object):
  77. """ A wrapper class for performing separation. """
  78. def __init__(
  79. self,
  80. params_descriptor: str,
  81. MWF: bool = False,
  82. stft_backend: STFTBackend = STFTBackend.AUTO,
  83. multiprocess: bool = True,
  84. ) -> None:
  85. """
  86. Default constructor.
  87. Parameters:
  88. params_descriptor (str):
  89. Descriptor for TF params to be used.
  90. MWF (bool):
  91. (Optional) `True` if MWF should be used, `False` otherwise.
  92. """
  93. self._params = load_configuration(params_descriptor)
  94. self._sample_rate = self._params["sample_rate"]
  95. self._MWF = MWF
  96. self._tf_graph = tf.Graph()
  97. self._prediction_generator = None
  98. self._input_provider = None
  99. self._builder = None
  100. self._features = None
  101. self._session = None
  102. if multiprocess:
  103. self._pool = Pool()
  104. atexit.register(self._pool.close)
  105. else:
  106. self._pool = None
  107. self._tasks = []
  108. self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
  109. self._data_generator = DataGenerator()
  110. def _get_prediction_generator(self) -> Generator:
  111. """
  112. Lazy loading access method for internal prediction generator
  113. returned by the predict method of a tensorflow estimator.
  114. Returns:
  115. Generator:
  116. Generator of prediction.
  117. """
  118. if self._prediction_generator is None:
  119. estimator = create_estimator(self._params, self._MWF)
  120. def get_dataset():
  121. return tf.data.Dataset.from_generator(
  122. self._data_generator,
  123. output_types={"waveform": tf.float32, "audio_id": tf.string},
  124. output_shapes={"waveform": (None, 2), "audio_id": ()},
  125. )
  126. self._prediction_generator = estimator.predict(
  127. get_dataset, yield_single_examples=False
  128. )
  129. return self._prediction_generator
  130. def join(self, timeout: int = 200) -> None:
  131. """
  132. Wait for all pending tasks to be finished.
  133. Parameters:
  134. timeout (int):
  135. (Optional) task waiting timeout.
  136. """
  137. while len(self._tasks) > 0:
  138. task = self._tasks.pop()
  139. task.get()
  140. task.wait(timeout=timeout)
  141. def _stft(
  142. self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
  143. ) -> np.ndarray:
  144. """
  145. Single entrypoint for both stft and istft. This computes stft and
  146. istft with librosa on stereo data. The two channels are processed
  147. separately and are concatenated together in the result. The
  148. expected input formats are: (n_samples, 2) for stft and (T, F, 2)
  149. for istft.
  150. Parameters:
  151. data (numpy.array):
  152. Array with either the waveform or the complex spectrogram
  153. depending on the parameter inverse
  154. inverse (bool):
  155. (Optional) Should a stft or an istft be computed.
  156. length (Optional[int]):
  157. Returns:
  158. numpy.ndarray:
  159. Stereo data as numpy array for the transform. The channels
  160. are stored in the last dimension.
  161. """
  162. assert not (inverse and length is None)
  163. data = np.asfortranarray(data)
  164. N = self._params["frame_length"]
  165. H = self._params["frame_step"]
  166. win = hann(N, sym=False)
  167. fstft = istft if inverse else stft
  168. win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
  169. n_channels = data.shape[-1]
  170. out = []
  171. for c in range(n_channels):
  172. d = (
  173. np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
  174. if not inverse
  175. else data[:, :, c].T
  176. )
  177. s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
  178. if inverse:
  179. s = s[N : N + length]
  180. s = np.expand_dims(s.T, 2 - inverse)
  181. out.append(s)
  182. if len(out) == 1:
  183. return out[0]
  184. return np.concatenate(out, axis=2 - inverse)
  185. def _get_input_provider(self):
  186. if self._input_provider is None:
  187. self._input_provider = InputProviderFactory.get(self._params)
  188. return self._input_provider
  189. def _get_features(self):
  190. if self._features is None:
  191. provider = self._get_input_provider()
  192. self._features = provider.get_input_dict_placeholders()
  193. return self._features
  194. def _get_builder(self):
  195. if self._builder is None:
  196. self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
  197. return self._builder
  198. def _get_session(self):
  199. if self._session is None:
  200. saver = tf.compat.v1.train.Saver()
  201. provider = ModelProvider.default()
  202. model_directory: str = provider.get(self._params["model_dir"])
  203. latest_checkpoint = tf.train.latest_checkpoint(model_directory)
  204. self._session = tf.compat.v1.Session()
  205. saver.restore(self._session, latest_checkpoint)
  206. return self._session
  207. def _separate_librosa(
  208. self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
  209. ) -> Dict:
  210. """
  211. Performs separation with librosa backend for STFT.
  212. Parameters:
  213. waveform (numpy.ndarray):
  214. Waveform to be separated (as a numpy array)
  215. audio_descriptor (AudioDescriptor):
  216. """
  217. with self._tf_graph.as_default():
  218. out = {}
  219. features = self._get_features()
  220. # TODO: fix the logic, build sometimes return,
  221. # sometimes set attribute.
  222. outputs = self._get_builder().outputs
  223. stft = self._stft(waveform)
  224. if stft.shape[-1] == 1:
  225. stft = np.concatenate([stft, stft], axis=-1)
  226. elif stft.shape[-1] > 2:
  227. stft = stft[:, :2]
  228. sess = self._get_session()
  229. outputs = sess.run(
  230. outputs,
  231. feed_dict=self._get_input_provider().get_feed_dict(
  232. features, stft, audio_descriptor
  233. ),
  234. )
  235. for inst in self._get_builder().instruments:
  236. out[inst] = self._stft(
  237. outputs[inst], inverse=True, length=waveform.shape[0]
  238. )
  239. return out
  240. def _separate_tensorflow(
  241. self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
  242. ) -> Dict:
  243. """
  244. Performs source separation over the given waveform with tensorflow
  245. backend.
  246. Parameters:
  247. waveform (numpy.ndarray):
  248. Waveform to be separated (as a numpy array)
  249. audio_descriptor (AudioDescriptor):
  250. Returns:
  251. Separated waveforms.
  252. """
  253. if not waveform.shape[-1] == 2:
  254. waveform = to_stereo(waveform)
  255. prediction_generator = self._get_prediction_generator()
  256. # NOTE: update data in generator before performing separation.
  257. self._data_generator.update_data(
  258. {"waveform": waveform, "audio_id": np.array(audio_descriptor)}
  259. )
  260. # NOTE: perform separation.
  261. prediction = next(prediction_generator)
  262. prediction.pop("audio_id")
  263. return prediction
  264. def separate(
  265. self, waveform: np.ndarray, audio_descriptor: Optional[str] = ""
  266. ) -> None:
  267. """
  268. Performs separation on a waveform.
  269. Parameters:
  270. waveform (numpy.ndarray):
  271. Waveform to be separated (as a numpy array)
  272. audio_descriptor (str):
  273. (Optional) string describing the waveform (e.g. filename).
  274. """
  275. backend: str = self._params["stft_backend"]
  276. if backend == STFTBackend.TENSORFLOW:
  277. return self._separate_tensorflow(waveform, audio_descriptor)
  278. elif backend == STFTBackend.LIBROSA:
  279. return self._separate_librosa(waveform, audio_descriptor)
  280. raise ValueError(f"Unsupported STFT backend {backend}")
  281. def separate_to_file(
  282. self,
  283. audio_descriptor: AudioDescriptor,
  284. destination: str,
  285. audio_adapter: Optional[AudioAdapter] = None,
  286. offset: int = 0,
  287. duration: float = 600.0,
  288. codec: Codec = Codec.WAV,
  289. bitrate: str = "128k",
  290. filename_format: str = "{filename}/{instrument}.{codec}",
  291. synchronous: bool = True,
  292. ) -> None:
  293. """
  294. Performs source separation and export result to file using
  295. given audio adapter.
  296. Filename format should be a Python formattable string that could
  297. use following parameters :
  298. - {instrument}
  299. - {filename}
  300. - {foldername}
  301. - {codec}.
  302. Parameters:
  303. audio_descriptor (AudioDescriptor):
  304. Describe song to separate, used by audio adapter to
  305. retrieve and load audio data, in case of file based
  306. audio adapter, such descriptor would be a file path.
  307. destination (str):
  308. Target directory to write output to.
  309. audio_adapter (Optional[AudioAdapter]):
  310. (Optional) Audio adapter to use for I/O.
  311. offset (int):
  312. (Optional) Offset of loaded song.
  313. duration (float):
  314. (Optional) Duration of loaded song (default: 600s).
  315. codec (Codec):
  316. (Optional) Export codec.
  317. bitrate (str):
  318. (Optional) Export bitrate.
  319. filename_format (str):
  320. (Optional) Filename format.
  321. synchronous (bool):
  322. (Optional) True is should by synchronous.
  323. """
  324. if audio_adapter is None:
  325. audio_adapter = AudioAdapter.default()
  326. waveform, _ = audio_adapter.load(
  327. audio_descriptor,
  328. offset=offset,
  329. duration=duration,
  330. sample_rate=self._sample_rate,
  331. )
  332. sources = self.separate(waveform, audio_descriptor)
  333. self.save_to_file(
  334. sources,
  335. audio_descriptor,
  336. destination,
  337. filename_format,
  338. codec,
  339. audio_adapter,
  340. bitrate,
  341. synchronous,
  342. )
  343. def save_to_file(
  344. self,
  345. sources: Dict,
  346. audio_descriptor: AudioDescriptor,
  347. destination: str,
  348. filename_format: str = "{filename}/{instrument}.{codec}",
  349. codec: Codec = Codec.WAV,
  350. audio_adapter: Optional[AudioAdapter] = None,
  351. bitrate: str = "128k",
  352. synchronous: bool = True,
  353. ) -> None:
  354. """
  355. Export dictionary of sources to files.
  356. Parameters:
  357. sources (Dict):
  358. Dictionary of sources to be exported. The keys are the name
  359. of the instruments, and the values are `N x 2` numpy arrays
  360. containing the corresponding intrument waveform, as
  361. returned by the separate method
  362. audio_descriptor (AudioDescriptor):
  363. Describe song to separate, used by audio adapter to
  364. retrieve and load audio data, in case of file based audio
  365. adapter, such descriptor would be a file path.
  366. destination (str):
  367. Target directory to write output to.
  368. filename_format (str):
  369. (Optional) Filename format.
  370. codec (Codec):
  371. (Optional) Export codec.
  372. audio_adapter (Optional[AudioAdapter]):
  373. (Optional) Audio adapter to use for I/O.
  374. bitrate (str):
  375. (Optional) Export bitrate.
  376. synchronous (bool):
  377. (Optional) True is should by synchronous.
  378. """
  379. if audio_adapter is None:
  380. audio_adapter = AudioAdapter.default()
  381. foldername = basename(dirname(audio_descriptor))
  382. filename = splitext(basename(audio_descriptor))[0]
  383. generated = []
  384. for instrument, data in sources.items():
  385. path = join(
  386. destination,
  387. filename_format.format(
  388. filename=filename,
  389. instrument=instrument,
  390. foldername=foldername,
  391. codec=codec,
  392. ),
  393. )
  394. directory = os.path.dirname(path)
  395. if not os.path.exists(directory):
  396. os.makedirs(directory)
  397. if path in generated:
  398. raise SpleeterError(
  399. (
  400. f"Separated source path conflict : {path},"
  401. "please check your filename format"
  402. )
  403. )
  404. generated.append(path)
  405. if self._pool:
  406. task = self._pool.apply_async(
  407. audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
  408. )
  409. self._tasks.append(task)
  410. else:
  411. audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
  412. if synchronous and self._pool:
  413. self.join()