import torch
import torchaudio
import matplotlib.pyplot as plt
import musdb
import os
import numpy as np
import glob
import librosa
import soundfile


def load(path, sr=22050, mono=True, mode="numpy", offset=0.0, duration=None):
    # Load an audio file with librosa; returns a (channels, samples) array and the sample rate
    y, curr_sr = librosa.load(path, sr=sr, mono=mono, res_type='kaiser_fast', offset=offset, duration=duration)

    if len(y.shape) == 1:
        # Expand channel dimension
        y = y[np.newaxis, :]

    if mode == "pytorch":
        y = torch.tensor(y)

    return y, curr_sr


def write_wav(path, audio, sr):
    # Write (channels, samples) audio as 16-bit PCM; soundfile expects (samples, channels), hence the transpose
    soundfile.write(path, audio.T, sr, "PCM_16")


def get_musdbhq(database_path):
    '''
    Retrieve audio file paths for the MUSDB HQ dataset.
    :param database_path: MUSDB HQ root directory
    :return: list with train and test subsets, each containing a list of samples, each sample containing all audio paths
    '''
    subsets = list()

    for subset in ["train", "test"]:
        print("Loading " + subset + " set...")
        tracks = glob.glob(os.path.join(database_path, subset, "*"))
        samples = list()

        # Go through tracks
        for track_folder in sorted(tracks):
            # Collect the audio path for each stem of this track
            example = dict()
            for stem in ["mix", "bass", "drums", "other", "vocals"]:
                filename = stem if stem != "mix" else "mixture"
                audio_path = os.path.join(track_folder, filename + ".wav")
                example[stem] = audio_path

            # Add other instruments to form accompaniment
            acc_path = os.path.join(track_folder, "accompaniment.wav")
            if not os.path.exists(acc_path):
                # Skip writing if the accompaniment already exists, assuming this track is done already
                print("Writing accompaniment to " + track_folder)
                stem_audio = []
                for stem in ["bass", "drums", "other"]:
                    audio, sr = load(example[stem], sr=None, mono=False)
                    stem_audio.append(audio)
                acc_audio = np.clip(sum(stem_audio), -1.0, 1.0)
                write_wav(acc_path, acc_audio, sr)

            example["accompaniment"] = acc_path
            samples.append(example)

        subsets.append(samples)

    return subsets


path = "C:/Users/IAN/Desktop/Wave-U-Net/musdb18-hq/"
res = get_musdbhq(path)
print(res)
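
# Minimal usage sketch, assuming the MUSDB HQ folders exist at the path above:
# get_musdbhq returns [train_samples, test_samples], where each sample is a dict
# mapping stem names ("mix", "bass", "drums", "other", "vocals", "accompaniment")
# to wav paths. The snippet below loads the vocal stem of the first training track
# at its native sample rate using the load() helper defined earlier.
if len(res[0]) > 0:
    first_track = res[0][0]
    vocals, vocal_sr = load(first_track["vocals"], sr=None, mono=False)
    print("Vocals shape:", vocals.shape, "sample rate:", vocal_sr)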