Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions encodec/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Modu
return weight_norm(module)
elif norm == 'spectral_norm':
return spectral_norm(module)
else:
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
Expand All @@ -47,8 +46,8 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none',
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()

return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
Expand Down Expand Up @@ -83,18 +82,21 @@ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', val
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:

# Constant padding
if mode != 'reflect':
return F.pad(x, paddings, mode, value)

# Reflect padding
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
Expand Down Expand Up @@ -193,7 +195,6 @@ def __init__(self, in_channels: int, out_channels: int,
self.pad_mode = pad_mode

def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
Expand Down Expand Up @@ -244,10 +245,10 @@ def forward(self, x):
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(padding_total * self.trim_right_ratio)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))

y = unpad1d(y, (padding_left, padding_right))
return y