#!/usr/bin/env python
# coding: utf8

"""
Module for building the data preprocessing pipeline using the tensorflow
data API. Data preprocessing such as audio loading, spectrogram
computation, cropping, feature caching or data augmentation is done
using a tensorflow dataset object that outputs a tuple (input_, output)
where:

- input_ is a dictionary with a single key that contains the (batched)
  mix spectrogram of audio samples
- output is a dictionary of spectrograms of the isolated tracks
  (ground truth)
"""
import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply,
)

# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
    "instrument_list": ("vocals", "accompaniment"),
    "mix_name": "mix",
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
}
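
# With these defaults, T = 512 STFT frames spans roughly
# T * frame_step / sample_rate = 512 * 1024 / 44100 ~= 11.9 seconds of audio,
# which matches the "11.88s" crops mentioned further down. Note that
# DatasetBuilder also expects an "n_channels" key (e.g. 2 for stereo), which
# is not part of these defaults and must be supplied by the caller.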


def get_training_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds training dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=audio_params.get("chunk_duration", 20.0),
        random_seed=audio_params.get("random_seed", 0),
    )
    return builder.build(
        audio_params.get("train_csv"),
        cache_directory=audio_params.get("training_cache"),
        batch_size=audio_params.get("batch_size"),
        n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False,
    )


def get_validation_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds validation dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, chunk_duration=12.0
    )
    return builder.build(
        audio_params.get("validation_csv"),
        batch_size=audio_params.get("batch_size"),
        cache_directory=audio_params.get("validation_cache"),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # should not perform data augmentation for eval:
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False,
    )
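
# Usage sketch (assumptions: `AudioAdapter.default()` is assumed to return
# the package's default adapter, and "train.csv" to list per-instrument audio
# paths and durations in the format consumed by dataset_from_csv):
#
#   from spleeter.audio.adapter import AudioAdapter
#
#   params = dict(
#       DEFAULT_AUDIO_PARAMS, n_channels=2, train_csv="train.csv", batch_size=4
#   )
#   dataset = get_training_dataset(params, AudioAdapter.default(), "/data/audio")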


class InstrumentDatasetBuilder(object):
    """ Instrument based filter and mapper provider. """

    def __init__(self, parent, instrument) -> None:
        """
        Default constructor.

        Parameters:
            parent:
                Parent dataset builder.
            instrument:
                Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f"{instrument}_spectrogram"
        self._min_spectrogram_key = f"min_{instrument}_spectrogram"
        self._max_spectrogram_key = f"max_{instrument}_spectrogram"
    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        return dict(
            sample,
            **self._parent._audio_adapter.load_tf_waveform(
                sample[f"{self._instrument}_path"],
                offset=sample["start"],
                duration=self._parent._chunk_duration,
                sample_rate=self._parent._sample_rate,
                waveform_name="waveform",
            ),
        )

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: compute_spectrogram_tf(
                    sample["waveform"],
                    frame_length=self._parent._frame_length,
                    frame_step=self._parent._frame_step,
                    spec_exponent=1.0,
                    window_exponent=1.0,
                )
            },
        )
    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    :, : self._parent._F, :
                ]
            },
        )
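
    # Axis convention implied by the slicing above and in time_crop below:
    # spectrograms are laid out as (time frames, frequency bins, channels),
    # so axis 0 is cropped to T frames and axis 1 to the first F bins.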
    def convert_to_uint(self, sample):
        """ Convert given sample from float to uint. """
        return dict(
            sample,
            **spectrogram_to_db_uint(
                sample[self._spectrogram_key],
                tensor_key=self._spectrogram_key,
                min_key=self._min_spectrogram_key,
                max_key=self._max_spectrogram_key,
            ),
        )

    def filter_infinity(self, sample):
        """ Filter infinity sample. """
        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))

    def convert_to_float32(self, sample):
        """ Convert given sample from uint to float32. """
        return dict(
            sample,
            **{
                self._spectrogram_key: db_uint_spectrogram_to_gain(
                    sample[self._spectrogram_key],
                    sample[self._min_spectrogram_key],
                    sample[self._max_spectrogram_key],
                )
            },
        )
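
    # Round-trip rationale (as used in DatasetBuilder.build): the conversion
    # helpers from .audio.convertor (not shown here) quantize dB-scaled
    # spectrograms to uint together with per-sample min/max values, which
    # shrinks the on-disk feature cache; convert_to_float32 restores gain
    # values afterwards. Samples whose recorded minimum is infinite (e.g.
    # silent audio) are dropped by filter_infinity.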
    def time_crop(self, sample):
        """ Extract the central T frames of the spectrogram. """

        def start(sample):
            """ mid_segment_start """
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0] / 2
                    - self._parent._T / 2,
                    0,
                ),
                tf.int32,
            )

        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    start(sample) : start(sample) + self._parent._T, :, :
                ]
            },
        )
    def filter_shape(self, sample):
        """ Filter badly shaped sample. """
        return check_tensor_shape(
            sample[self._spectrogram_key],
            (self._parent._T, self._parent._F, self._parent._n_channels),
        )

    def reshape_spectrogram(self, sample):
        """ Reshape given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: set_tensor_shape(
                    sample[self._spectrogram_key],
                    (self._parent._T, self._parent._F, self._parent._n_channels),
                )
            },
        )
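
    # filter_shape and reshape_spectrogram work as a pair: the former drops
    # samples whose runtime shape deviates from (T, F, n_channels), so the
    # latter can safely pin that shape statically for downstream batching.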


class DatasetBuilder(object):
    """
    Builds a tensorflow dataset of (mix spectrogram, isolated track
    spectrograms) pairs from a CSV listing audio sample paths.
    """

    MARGIN: float = 0.5
    """ Margin at beginning and end of songs in seconds. """

    WAIT_PERIOD: int = 60
    """ Wait period for cache (in seconds). """

    def __init__(
        self,
        audio_params: Dict,
        audio_adapter: AudioAdapter,
        audio_path: str,
        random_seed: int = 0,
        chunk_duration: float = 20.0,
    ) -> None:
        """
        Default constructor.

        NOTE: Probably need for AudioAdapter.

        Parameters:
            audio_params (Dict):
                Audio parameters to use.
            audio_adapter (AudioAdapter):
                Audio adapter to use.
            audio_path (str):
                Path of directory containing audio.
            random_seed (int):
                Seed for random operations (shuffling, random crop).
            chunk_duration (float):
                Duration in seconds of the audio chunk loaded per segment.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params["T"]
        # Number of frequency bins to be used (should
        # be at most frame_length/2 + 1)
        self._F = audio_params["F"]
        self._sample_rate = audio_params["sample_rate"]
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
        self._n_channels = audio_params["n_channels"]
        # list() so that tuple instrument lists (as in DEFAULT_AUDIO_PARAMS)
        # can be concatenated with the mix name.
        self._instruments = [self._mix_name] + list(audio_params["instrument_list"])
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed
        self.check_parameters_compatibility()
    def check_parameters_compatibility(self):
        if self._frame_length / 2 + 1 < self._F:
            raise ValueError(
                "F is too large and must be set to at most frame_length/2+1. "
                "Decrease F or increase frame_length to fix."
            )
        if (
            self._chunk_duration * self._sample_rate - self._frame_length
        ) / self._frame_step < self._T:
            raise ValueError(
                "T is too large considering STFT parameters and chunk duration. "
                "Make sure the spectrogram time dimension of chunks is larger "
                "than T (for instance reduce T or frame_step, or increase "
                "chunk duration)."
            )
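
    # With DEFAULT_AUDIO_PARAMS and chunk_duration=20.0, both checks pass:
    # frame_length / 2 + 1 = 4096 / 2 + 1 = 2049 >= F = 1024, and
    # (20.0 * 44100 - 4096) / 1024 ~= 857 >= T = 512 spectrogram frames.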
    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )
    def filter_error(self, sample):
        """ Filter out sample whose waveform loading errored. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Remove waveform from sample once spectrograms are computed. """
        return {k: v for k, v in sample.items() if k != "waveform"}
    def harmonize_spectrogram(self, sample):
        """ Ensure same time size for all instrument spectrograms. """

        def _reduce(sample):
            return tf.reduce_min(
                [
                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
                    for instrument in self._instruments
                ]
            )

        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    : _reduce(sample), :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """ Filter out too short segments. """
        return tf.reduce_any(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )
    def random_time_crop(self, sample):
        """ Random time crop of T frames (~11.88s with default parameters). """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, self._n_channels),
                    seed=self._random_seed,
                ),
            ),
        )
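
    # sync_apply presumably concatenates the per-instrument spectrograms
    # (along the frequency axis here, hence the len(instruments) * F crop
    # width) before applying the function, so the exact same random crop is
    # taken from every instrument, then splits the result back per key.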
    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )
    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)
    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Parameters:
            dataset (Any):
                Dataset to compute segments for.
            n_chunks_per_song (int):
                Number of segments per song to compute.

        Returns:
            Any:
                Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset
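
    # Worked example: for a 200s song with chunk_duration=20.0, MARGIN=0.5
    # and n_chunks_per_song=2, chunk k starts at
    # k * (200 - 20 - 2 * 0.5) / (2 - 1) + 0.5, i.e. at 0.5s and 179.5s,
    # spreading chunks evenly between the margins at both ends of the song.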
    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        Yields:
            Any:
                InstrumentDatasetBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument)
                )
        for builder in self._instrument_builders:
            yield builder
    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Optionally waits for
        the cache to be available (useful if another process is already
        computing it) if the provided wait flag is `True`.

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (str):
                Path of cache directory to be used, None if no cache.
            wait (bool):
                If caching is enabled, True if the cache should be
                waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f"{cache}.index"):
                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
            return dataset.cache(cache)
        return dataset
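
    # Note on the probe above: tf.data's file-based cache
    # (dataset.cache(filename)) materializes an index file alongside data
    # shards, so the presence of f"{cache}.index" is used here as the signal
    # that another process has already produced the cache.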
    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: int = 2,
    ) -> Any:
        """
        Builds the dataset: computes segments, loads waveforms, computes
        spectrograms, caches features, crops, optionally augments,
        and batches.

        Parameters:
            csv_path (str):
                Path of the CSV file listing audio samples.
            batch_size (int):
                Size of produced batches.
            shuffle (bool):
                If `True`, shuffle samples (pre- and post-cache).
            convert_to_uint (bool):
                If `True`, store spectrograms as uint around the cache.
            random_data_augmentation (bool):
                If `True`, apply random time stretch and pitch shift.
            random_time_crop (bool):
                If `True`, crop T frames at a random position, otherwise
                take the central segment.
            infinite_generator (bool):
                If `True`, repeat the dataset indefinitely.
            cache_directory (Optional[str]):
                Path of feature cache, `None` to disable caching.
            wait_for_cache (bool):
                If `True`, wait for the cache to be computed by another
                process.
            num_parallel_calls (int):
                Parallelism for the pre-cache map operations.
            n_chunks_per_song (int):
                Number of segments to extract per song.

        Returns:
            Any:
                Built dataset.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # reshuffling is useless once the dataset is cached:
                reshuffle_each_iteration=True,
            )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform and compute spectrogram, then filter out errored
        # samples, frequency bins above F, and the raw waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely.
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same time size for all instrument spectrograms.
        # NOTE: could be done before caching?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too short segments.
        # NOTE: could be done before caching?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of T frames (~11.88s with default parameters).
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # Take central segment (for validation).
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post-cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallelism for post-cache map operations.
        # Must be applied with the same factor on mix and isolated tracks.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Batch (done after feature selection to avoid errors from batching
        # unprocessed instrument spectrograms).
        dataset = dataset.batch(batch_size)
        return dataset
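
# Iterating the built dataset (sketch, continuing the usage example above):
# each element is an (input_, output) pair of dictionaries whose spectrogram
# tensors are batched to shape (batch_size, T, F, n_channels), e.g.
# (4, 512, 1024, 2) with the default parameters and stereo audio.
#
#   for input_, output in dataset.take(1):
#       print(input_["mix_spectrogram"].shape)
#       print({k: v.shape for k, v in output.items()})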