diff --git a/west/utils/audio.py b/west/utils/audio.py index c599a96..e46412b 100644 --- a/west/utils/audio.py +++ b/west/utils/audio.py @@ -60,15 +60,17 @@ def mel_spectrogram(y, # print("max value is ", torch.max(y)) global mel_basis, hann_window # noqa - if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel_key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{str(y.device)}" + win_key = f"{win_size}_{str(y.device)}" + if mel_key not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[str(fmax) + "_" + - str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device) + if win_key not in hann_window: + hann_window[win_key] = torch.hann_window(win_size).to(y.device) y = torch.nn.functional.pad(y.unsqueeze(1), (int( (n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), @@ -81,7 +83,7 @@ def mel_spectrogram(y, n_fft, hop_length=hop_size, win_length=win_size, - window=hann_window[str(y.device)], + window=hann_window[win_key], center=center, pad_mode="reflect", normalized=False, @@ -91,7 +93,7 @@ def mel_spectrogram(y, spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = torch.matmul(mel_basis[mel_key], spec) spec = spectral_normalize_torch(spec) return spec