import torch
|
|
import torchaudio
|
|
import matplotlib as plt
|
|
import musdb
|
|
import os
|
|
import numpy as np
|
|
import glob
|
|
import librosa
|
|
import soundfile
|
|
|
|
|
|
|
|
def load(path, sr=22050, mono=True, mode="numpy", offset=0.0, duration=None):
    '''
    Load an audio file with librosa and return it channels-first.

    :param path: audio file to read
    :param sr: sample rate passed to librosa.load (None keeps the file's native rate)
    :param mono: passed to librosa.load; True downmixes to a single channel
    :param mode: "pytorch" returns a torch tensor; any other value returns the numpy array
    :param offset: start reading this many seconds into the file
    :param duration: read at most this many seconds (None reads to the end)
    :return: tuple (audio, sample_rate); a 1-D signal returned by librosa is given
        a leading channel axis before being returned
    '''
    audio, out_sr = librosa.load(
        path,
        sr=sr,
        mono=mono,
        offset=offset,
        duration=duration,
        res_type='kaiser_fast',
    )

    # Single-channel output comes back 1-D; make it (1, samples) so callers
    # can always index the channel axis first.
    if audio.ndim == 1:
        audio = np.expand_dims(audio, axis=0)

    return (torch.tensor(audio) if mode == "pytorch" else audio), out_sr
|
|
|
|
|
|
def write_wav(path, audio, sr):
    '''
    Write a channels-first audio array to a 16-bit PCM WAV file.

    :param path: destination file path
    :param audio: array laid out as (channels, samples)
    :param sr: sample rate to store in the file header
    '''
    # soundfile wants (frames, channels), so flip the channels-first layout.
    frames_by_channels = audio.T
    soundfile.write(path, frames_by_channels, sr, subtype="PCM_16")
|
|
|
|
def get_musdbhq(database_path):
    '''
    Retrieve audio file paths for the MUSDB18-HQ dataset.

    Every sub-directory of ``<database_path>/train`` and ``<database_path>/test``
    is treated as one track. For each track a dict is built that maps the stem
    names "mix", "bass", "drums", "other", "vocals" and "accompaniment" to WAV
    paths inside the track folder ("mix" maps to ``mixture.wav``). If a track has
    no ``accompaniment.wav`` yet, one is created by summing the bass, drums and
    other stems, clipping to [-1, 1], and writing the result as 16-bit PCM.

    :param database_path: MUSDB18-HQ root directory (containing "train" and "test")
    :return: list ``[train_samples, test_samples]``; each element is a list of
        per-track dicts, ordered by sorted track folder path
    '''
    subsets = []

    for subset in ["train", "test"]:
        print("Loading " + subset + " set...")
        candidates = glob.glob(os.path.join(database_path, subset, "*"))
        # Only directories are tracks; stray files in the subset folder would
        # otherwise be treated as tracks and fail when their stems are loaded.
        track_folders = sorted(p for p in candidates if os.path.isdir(p))

        samples = []
        for track_folder in track_folders:
            example = {
                stem: os.path.join(track_folder, ("mixture" if stem == "mix" else stem) + ".wav")
                for stem in ["mix", "bass", "drums", "other", "vocals"]
            }

            # Build the accompaniment (everything except vocals) once and
            # cache it on disk; later runs reuse the existing file.
            acc_path = os.path.join(track_folder, "accompaniment.wav")
            if not os.path.exists(acc_path):
                print("Writing accompaniment to " + track_folder)
                stem_audio = []
                for stem in ["bass", "drums", "other"]:
                    audio, sr = load(example[stem], sr=None, mono=False)
                    stem_audio.append(audio)
                # Clip so the 16-bit PCM write cannot wrap around on overflow.
                acc_audio = np.clip(sum(stem_audio), -1.0, 1.0)
                write_wav(acc_path, acc_audio, sr)

            example["accompaniment"] = acc_path
            samples.append(example)

        subsets.append(samples)

    return subsets
|
|
|
|
# Default dataset root, used when no path is given on the command line.
path = "C:/Users/IAN/Desktop/Wave-U-Net/musdb18-hq/"

# Guarded so that importing this module does not scan the dataset or write
# accompaniment files as a side effect.
if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the hard-coded dataset root.
    res = get_musdbhq(sys.argv[1] if len(sys.argv) > 1 else path)
    print(res)
|