diff --git a/encodec/modules/conv.py b/encodec/modules/conv.py index e83ae84..cc15459 100644 --- a/encodec/modules/conv.py +++ b/encodec/modules/conv.py @@ -28,10 +28,9 @@ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Modu return weight_norm(module) elif norm == 'spectral_norm': return spectral_norm(module) - else: - # We already check was in CONV_NORMALIZATION, so any other choice - # doesn't need reparametrization. - return module + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: @@ -47,8 +46,8 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', raise ValueError("GroupNorm doesn't support causal evaluation.") assert isinstance(module, nn.modules.conv._ConvNd) return nn.GroupNorm(1, module.out_channels, **norm_kwargs) - else: - return nn.Identity() + + return nn.Identity() def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, @@ -83,18 +82,21 @@ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', val length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: + + # Constant padding + if mode != 'reflect': return F.pad(x, paddings, mode, value) + # Reflect padding + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): """Remove padding from x, handling properly zero padding. Only for 1d!""" @@ -193,7 +195,6 @@ def __init__(self, in_channels: int, out_channels: int, self.pad_mode = pad_mode def forward(self, x): - B, C, T = x.shape kernel_size = self.conv.conv.kernel_size[0] stride = self.conv.conv.stride[0] dilation = self.conv.conv.dilation[0] @@ -244,10 +245,10 @@ def forward(self, x): # if trim_right_ratio = 1.0, trim everything from right padding_right = math.ceil(padding_total * self.trim_right_ratio) padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) + + y = unpad1d(y, (padding_left, padding_right)) return y