Description
I suspect the error below occurs because the recipe runs in fp16.
When the filter is clamped at 1e-12 in fp16, the value is cast to 0, because 1e-12 is below the smallest positive number fp16 can represent (about 6e-8).
For example:
❯ python
Python 3.10.18 (main, Jun 3 2025, 18:23:41) [Clang 16.0.0 (clang-1600.0.26.6)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
Cmd click to launch VS Code Native REPL
>>> import torch
>>> x = torch.tensor([1e-12])
>>> x.dtype
torch.float32
>>> x = torch.tensor([1e-12], dtype = torch.float16)
>>> x
tensor([0.], dtype=torch.float16)

Traceback from the failing training run:
Traceback (most recent call last):
  File "/localscratch/gfdb.21204969.0/speechbrain/recipes/CommonVoice/ASR/transformer/train.py", line 440, in <module>
    asr_brain.fit(
  File "/home/gfdb/speechbrain/speechbrain/core.py", line 1211, in fit
    self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
  File "/home/gfdb/speechbrain/speechbrain/core.py", line 1036, in _fit_train
    loss = self.fit_batch(batch)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/speechbrain/speechbrain/core.py", line 835, in fit_batch
    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/localscratch/gfdb.21204969.0/speechbrain/recipes/CommonVoice/ASR/transformer/train.py", line 48, in compute_forward
    wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/envs/wav2aug/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/wav2aug/wav2aug/gpu/wav2aug.py", line 123, in __call__
    waveforms = op(waveforms, lengths)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/wav2aug/wav2aug/gpu/wav2aug.py", line 60, in <lambda>
    lambda x, lengths: freq_drop(x),
                       ^^^^^^^^^^^^
  File "/home/gfdb/envs/wav2aug/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/wav2aug/wav2aug/gpu/frequency_dropout.py", line 176, in freq_drop
    notch_kernel = _notch_filter(freq, _FILTER_LEN, width, device, dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gfdb/wav2aug/wav2aug/gpu/frequency_dropout.py", line 31, in _notch_filter
    assert 0 < notch_freq <= 1
           ^^^^^^^^^^^^^^^^^^^
AssertionError
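
If the fp16 underflow of the clamp floor is indeed the cause, one possible mitigation is to pick the floor from the tensor's dtype instead of hard-coding 1e-12. The sketch below is only an illustration of that idea, not the actual wav2aug code: `dtype_safe_floor` and `clamp_positive` are hypothetical helper names, and it assumes the clamped value is a floating-point tensor. It relies on `torch.finfo(dtype).tiny`, the smallest positive normal number of the dtype.

```python
import torch


def dtype_safe_floor(dtype: torch.dtype, requested: float = 1e-12) -> float:
    """Return a clamp floor that does not underflow to 0 in `dtype`.

    Hypothetical helper: uses the requested floor when the dtype can
    represent it, otherwise falls back to the dtype's smallest positive
    normal value (torch.finfo(dtype).tiny).
    """
    return max(requested, torch.finfo(dtype).tiny)


def clamp_positive(x: torch.Tensor, requested: float = 1e-12) -> torch.Tensor:
    """Clamp `x` from below with a floor that stays nonzero in x.dtype."""
    return x.clamp(min=dtype_safe_floor(x.dtype, requested))


if __name__ == "__main__":
    # In float32 the requested floor of 1e-12 is representable and is kept.
    print(dtype_safe_floor(torch.float32))
    # In float16 the floor is raised so that it no longer rounds to 0.
    print(torch.tensor([dtype_safe_floor(torch.float16)], dtype=torch.float16))
```

An alternative with the same intent would be to do the clamp in float32 and cast to the target dtype only afterwards; which option is preferable depends on how the clamped value is used downstream.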