You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

200 lines
6.4 KiB

2 years ago
  1. #!/usr/bin/env python
  2. # coding: utf8
  3. """ AudioAdapter class definition. """
  4. from abc import ABC, abstractmethod
  5. from importlib import import_module
  6. from pathlib import Path
  7. from typing import Any, Dict, List, Optional, Union
  8. # pyright: reportMissingImports=false
  9. # pylint: disable=import-error
  10. import numpy as np
  11. import tensorflow as tf
  12. from spleeter.audio import Codec
  13. from .. import SpleeterError
  14. from ..types import AudioDescriptor, Signal
  15. from ..utils.logging import logger
  16. # pylint: enable=import-error
  17. __email__ = "spleeter@deezer.com"
  18. __author__ = "Deezer Research"
  19. __license__ = "MIT License"
  20. class AudioAdapter(ABC):
  21. """ An abstract class for manipulating audio signal. """
  22. _DEFAULT: "AudioAdapter" = None
  23. """ Default audio adapter singleton instance. """
  24. @abstractmethod
  25. def load(
  26. self,
  27. audio_descriptor: AudioDescriptor,
  28. offset: Optional[float] = None,
  29. duration: Optional[float] = None,
  30. sample_rate: Optional[float] = None,
  31. dtype: np.dtype = np.float32,
  32. ) -> Signal:
  33. """
  34. Loads the audio file denoted by the given audio descriptor and
  35. returns it data as a waveform. Aims to be implemented by client.
  36. Parameters:
  37. audio_descriptor (AudioDescriptor):
  38. Describe song to load, in case of file based audio adapter,
  39. such descriptor would be a file path.
  40. offset (Optional[float]):
  41. Start offset to load from in seconds.
  42. duration (Optional[float]):
  43. Duration to load in seconds.
  44. sample_rate (Optional[float]):
  45. Sample rate to load audio with.
  46. dtype (numpy.dtype):
  47. (Optional) Numpy data type to use, default to `float32`.
  48. Returns:
  49. Signal:
  50. Loaded data as (wf, sample_rate) tuple.
  51. """
  52. pass
  53. def load_tf_waveform(
  54. self,
  55. audio_descriptor,
  56. offset: float = 0.0,
  57. duration: float = 1800.0,
  58. sample_rate: int = 44100,
  59. dtype: bytes = b"float32",
  60. waveform_name: str = "waveform",
  61. ) -> Dict[str, Any]:
  62. """
  63. Load the audio and convert it to a tensorflow waveform.
  64. Parameters:
  65. audio_descriptor ():
  66. Describe song to load, in case of file based audio adapter,
  67. such descriptor would be a file path.
  68. offset (float):
  69. Start offset to load from in seconds.
  70. duration (float):
  71. Duration to load in seconds.
  72. sample_rate (float):
  73. Sample rate to load audio with.
  74. dtype (bytes):
  75. (Optional)data type to use, default to `b'float32'`.
  76. waveform_name (str):
  77. (Optional) Name of the key in output dict, default to
  78. `'waveform'`.
  79. Returns:
  80. Dict[str, Any]:
  81. TF output dict with waveform as `(T x chan numpy array)`
  82. and a boolean that tells whether there were an error while
  83. trying to load the waveform.
  84. """
  85. # Cast parameters to TF format.
  86. offset = tf.cast(offset, tf.float64)
  87. duration = tf.cast(duration, tf.float64)
  88. # Defined safe loading function.
  89. def safe_load(path, offset, duration, sample_rate, dtype):
  90. logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
  91. try:
  92. (data, _) = self.load(
  93. path.numpy(),
  94. offset.numpy(),
  95. duration.numpy(),
  96. sample_rate.numpy(),
  97. dtype=dtype.numpy(),
  98. )
  99. logger.info("Audio data loaded successfully")
  100. return (data, False)
  101. except Exception as e:
  102. logger.exception("An error occurs while loading audio", exc_info=e)
  103. return (np.float32(-1.0), True)
  104. # Execute function and format results.
  105. results = (
  106. tf.py_function(
  107. safe_load,
  108. [audio_descriptor, offset, duration, sample_rate, dtype],
  109. (tf.float32, tf.bool),
  110. ),
  111. )
  112. waveform, error = results[0]
  113. return {waveform_name: waveform, f"{waveform_name}_error": error}
  114. @abstractmethod
  115. def save(
  116. self,
  117. path: Union[Path, str],
  118. data: np.ndarray,
  119. sample_rate: float,
  120. codec: Codec = None,
  121. bitrate: str = None,
  122. ) -> None:
  123. """
  124. Save the given audio data to the file denoted by the given path.
  125. Parameters:
  126. path (Union[Path, str]):
  127. Path like of the audio file to save data in.
  128. data (numpy.ndarray):
  129. Waveform data to write.
  130. sample_rate (float):
  131. Sample rate to write file in.
  132. codec ():
  133. (Optional) Writing codec to use, default to `None`.
  134. bitrate (str):
  135. (Optional) Bitrate of the written audio file, default to
  136. `None`.
  137. """
  138. pass
  139. @classmethod
  140. def default(cls: type) -> "AudioAdapter":
  141. """
  142. Builds and returns a default audio adapter instance.
  143. Returns:
  144. AudioAdapter:
  145. Default adapter instance to use.
  146. """
  147. if cls._DEFAULT is None:
  148. from .ffmpeg import FFMPEGProcessAudioAdapter
  149. cls._DEFAULT = FFMPEGProcessAudioAdapter()
  150. return cls._DEFAULT
  151. @classmethod
  152. def get(cls: type, descriptor: str) -> "AudioAdapter":
  153. """
  154. Load dynamically an AudioAdapter from given class descriptor.
  155. Parameters:
  156. descriptor (str):
  157. Adapter class descriptor (module.Class)
  158. Returns:
  159. AudioAdapter:
  160. Created adapter instance.
  161. """
  162. if not descriptor:
  163. return cls.default()
  164. module_path: List[str] = descriptor.split(".")
  165. adapter_class_name: str = module_path[-1]
  166. module_path: str = ".".join(module_path[:-1])
  167. adapter_module = import_module(module_path)
  168. adapter_class = getattr(adapter_module, adapter_class_name)
  169. if not issubclass(adapter_class, AudioAdapter):
  170. raise SpleeterError(
  171. f"{adapter_class_name} is not a valid AudioAdapter class"
  172. )
  173. return adapter_class()