diff --git a/encodec/msstftd.py b/encodec/msstftd.py index a1d3242..d916cde 100644 --- a/encodec/msstftd.py +++ b/encodec/msstftd.py @@ -67,7 +67,7 @@ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, self.convs.append( NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) ) - in_chs = min(filters_scale * self.filters, max_filters) + in_chs = self.filters for i, dilation in enumerate(dilations): out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,