Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

139 řádků
3.6 KiB

před 2 lety
#!/usr/bin/env python
# coding: utf8
"""
This module provides audio data conversion functions: channel-count
conversion of waveforms, gain <-> decibel conversion, and encoding /
decoding of decibel spectrograms to and from `uint8`.
"""
  4. # pyright: reportMissingImports=false
  5. # pylint: disable=import-error
  6. import numpy as np
  7. import tensorflow as tf
  8. from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
  9. # pylint: enable=import-error
  10. __email__ = "spleeter@deezer.com"
  11. __author__ = "Deezer Research"
  12. __license__ = "MIT License"
  13. def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
  14. """
  15. Convert a waveform to n_channels by removing or duplicating channels if
  16. needed (in tensorflow).
  17. Parameters:
  18. waveform (tensorflow.Tensor):
  19. Waveform to transform.
  20. n_channels (int):
  21. Number of channel to reshape waveform in.
  22. Returns:
  23. tensorflow.Tensor:
  24. Reshaped waveform.
  25. """
  26. return tf.cond(
  27. tf.shape(waveform)[1] >= n_channels,
  28. true_fn=lambda: waveform[:, :n_channels],
  29. false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
  30. )
  31. def to_stereo(waveform: np.ndarray) -> np.ndarray:
  32. """
  33. Convert a waveform to stereo by duplicating if mono, or truncating
  34. if too many channels.
  35. Parameters:
  36. waveform (numpy.ndarray):
  37. a `(N, d)` numpy array.
  38. Returns:
  39. numpy.ndarray:
  40. A stereo waveform as a `(N, 1)` numpy array.
  41. """
  42. if waveform.shape[1] == 1:
  43. return np.repeat(waveform, 2, axis=-1)
  44. if waveform.shape[1] > 2:
  45. return waveform[:, :2]
  46. return waveform
  47. def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
  48. """
  49. Convert from gain to decibel in tensorflow.
  50. Parameters:
  51. tensor (tensorflow.Tensor):
  52. Tensor to convert
  53. epsilon (float):
  54. Operation constant.
  55. Returns:
  56. tensorflow.Tensor:
  57. Converted tensor.
  58. """
  59. return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
  60. def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
  61. """
  62. Convert from decibel to gain in tensorflow.
  63. Parameters:
  64. tensor (tensorflow.Tensor):
  65. Tensor to convert
  66. Returns:
  67. tensorflow.Tensor:
  68. Converted tensor.
  69. """
  70. return tf.pow(10.0, (tensor / 20.0))
  71. def spectrogram_to_db_uint(
  72. spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
  73. ) -> tf.Tensor:
  74. """
  75. Encodes given spectrogram into uint8 using decibel scale.
  76. Parameters:
  77. spectrogram (tensorflow.Tensor):
  78. Spectrogram to be encoded as TF float tensor.
  79. db_range (float):
  80. Range in decibel for encoding.
  81. Returns:
  82. tensorflow.Tensor:
  83. Encoded decibel spectrogram as `uint8` tensor.
  84. """
  85. db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
  86. max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
  87. db_spectrogram: tf.Tensor = tf.maximum(
  88. db_spectrogram, max_db_spectrogram - db_range
  89. )
  90. return from_float32_to_uint8(db_spectrogram, **kwargs)
  91. def db_uint_spectrogram_to_gain(
  92. db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
  93. ) -> tf.Tensor:
  94. """
  95. Decode spectrogram from uint8 decibel scale.
  96. Paramters:
  97. db_uint_spectrogram (tensorflow.Tensor):
  98. Decibel spectrogram to decode.
  99. min_db (tensorflow.Tensor):
  100. Lower bound limit for decoding.
  101. max_db (tensorflow.Tensor):
  102. Upper bound limit for decoding.
  103. Returns:
  104. tensorflow.Tensor:
  105. Decoded spectrogram as `float32` tensor.
  106. """
  107. db_spectrogram: tf.Tensor = from_uint8_to_float32(
  108. db_uint_spectrogram, min_db, max_db
  109. )
  110. return db_to_gain(db_spectrogram)