3 changes: 3 additions & 0 deletions .gitattributes
@@ -1,3 +1,6 @@
tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/SPECTRUM_input_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/SPECTRUM_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/SPECTRUM_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text
172 changes: 172 additions & 0 deletions aion/codecs/modules/convnext.py
@@ -0,0 +1,172 @@
import torch

from aion.codecs.modules.utils import LayerNorm, GRN


class ConvNextBlock1d(torch.nn.Module):
"""ConvNeXtV2 Block.
Modified to 1D from the original 2D implementation from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py

Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
"""

def __init__(self, dim: int):
super().__init__()
self.dwconv = torch.nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = torch.nn.Linear(
dim, 4 * dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = torch.nn.GELU()
self.grn = GRN(4 * dim)
self.pwconv2 = torch.nn.Linear(4 * dim, dim)

def forward(self, x):
y = self.dwconv(x)
y = y.permute(0, 2, 1) # (B, C, N) -> (B, N, C)
y = self.norm(y)
y = self.pwconv1(y)
y = self.act(y)
y = self.grn(y)
y = self.pwconv2(y)
y = y.permute(0, 2, 1) # (B, N, C) -> (B, C, N)

y = x + y
return y
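
A quick shape check for the block (a sketch; the batch size, dim, and length below are arbitrary):

import torch
from aion.codecs.modules.convnext import ConvNextBlock1d

block = ConvNextBlock1d(dim=96)
x = torch.randn(2, 96, 128)         # (batch, channels, length)
assert block(x).shape == x.shape    # depthwise conv preserves length; the residual preserves channels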


class ConvNextEncoder1d(torch.nn.Module):
r"""ConvNeXt encoder.

Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

Args:
        in_chans : Number of input channels. Default: 2
        depths : Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims : Feature dimension at each stage. Default: (96, 192, 384, 768)
    """

def __init__(
self,
in_chans: int = 2,
depths: tuple[int, ...] = (3, 3, 9, 3),
dims: tuple[int, ...] = (96, 192, 384, 768),
):
super().__init__()
assert len(depths) == len(dims), "depths and dims should have the same length"
num_layers = len(depths)

        self.downsample_layers = (
            torch.nn.ModuleList()
        )  # stem and num_layers - 1 intermediate downsampling conv layers
stem = torch.nn.Sequential(
torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.downsample_layers.append(stem)
for i in range(num_layers - 1):
downsample_layer = torch.nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)

self.stages = torch.nn.ModuleList()
for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
self.stages.append(stage)

self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
torch.nn.init.trunc_normal_(m.weight, std=0.02)
torch.nn.init.constant_(m.bias, 0)

def forward(self, x):
for ds, st in zip(self.downsample_layers, self.stages):
x = ds(x)
x = st(x)
return self.norm(x)
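
With the default configuration the encoder downsamples by a factor of 32 (a stride-4 stem followed by three stride-2 stages). A minimal shape sketch, assuming an input length divisible by 32:

import torch
from aion.codecs.modules.convnext import ConvNextEncoder1d

enc = ConvNextEncoder1d()           # in_chans=2, dims=(96, 192, 384, 768)
z = enc(torch.randn(2, 2, 1024))    # (batch, channels, length)
assert z.shape == (2, 768, 32)      # 1024 / (4 * 2 * 2 * 2) = 32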


class ConvNextDecoder1d(torch.nn.Module):
r"""ConvNeXt decoder. Essentially a mirrored version of the encoder.

Args:
        in_chans (int): Number of input channels. Default: 768
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (384, 192, 96, 2)
    """

    def __init__(
        self,
        in_chans: int = 768,
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (384, 192, 96, 2),
    ):
super().__init__()
assert len(depths) == len(dims), "depths and dims should have the same length"
num_layers = len(depths)

self.upsample_layers = torch.nn.ModuleList()

stem = torch.nn.Sequential(
torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.upsample_layers.append(stem)

for i in range(num_layers - 1):
upsample_layer = torch.nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
torch.nn.ConvTranspose1d(
dims[i],
dims[i + 1],
kernel_size=2 if i < (num_layers - 2) else 4,
stride=2 if i < (num_layers - 2) else 4,
),
)
self.upsample_layers.append(upsample_layer)

self.stages = torch.nn.ModuleList()
for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
self.stages.append(stage)

self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
torch.nn.init.trunc_normal_(m.weight, std=0.02)
torch.nn.init.constant_(m.bias, 0)

def forward(self, x):
for us, st in zip(self.upsample_layers, self.stages):
x = us(x)
x = st(x)
return x
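
The decoder mirrors that factor of 32 (a stride-2 transposed-conv stem, two stride-2 stages, and a final stride-4 stage). A round-trip shape sketch under the default configurations of both modules:

import torch
from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d

enc = ConvNextEncoder1d()           # (B, 2, L) -> (B, 768, L / 32)
dec = ConvNextDecoder1d()           # (B, 768, L / 32) -> (B, 2, L)
x = torch.randn(2, 2, 1024)
assert dec(enc(x)).shape == x.shape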
110 changes: 110 additions & 0 deletions aion/codecs/modules/spectrum.py
@@ -0,0 +1,110 @@
import torch
from jaxtyping import Float


def interp1d(
x: Float[torch.Tensor, " b n"],
y: Float[torch.Tensor, " b n"],
xnew: Float[torch.Tensor, " b m"],
    mask_value: float = 0.0,
) -> Float[torch.Tensor, " b m"]:
"""Linear interpolation of a 1-D tensor using torch.searchsorted.
Assumes that x and xnew are sorted in increasing order.

Args:
x: The x-coordinates of the data points, shape [batch, N].
y: The y-coordinates of the data points, shape [batch, N].
xnew: The x-coordinates of the interpolated points, shape [batch, M].
mask_value: The value to use for xnew outside the range of x.
Returns:
The y-coordinates of the interpolated points, shape [batch, M].
"""
    # Find the indices where xnew should be inserted in the sorted x.
    # For each xnew[i], searchsorted returns the first index j with
    # x[j] >= xnew[i]; subtracting 1 yields the largest j with x[j] < xnew[i].
    indices = torch.searchsorted(x, xnew) - 1

    # We can define a local linear approximation of the gradient in each
    # interval between two points of x, and use it to interpolate y at the
    # points of xnew that lie inside the range of x; interpolated values for
    # points of xnew outside the range of x are masked below.
    # There are len(x) - 1 such intervals, with indices ranging from 0 to
    # len(x) - 2. Points with xnew < min(x) get index -1 and points with
    # xnew > max(x) get index len(x) - 1. These are not valid segment indices,
    # but we clamp them to 0 and len(x) - 2 respectively to avoid breaking the
    # slope calculation. The nonsense values obtained outside the range of x
    # are discarded when masking.
    indices = torch.clamp(indices, 0, x.shape[1] - 2)

slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:])

# Interpolate the y-coordinates
ynew = torch.gather(y, 1, indices) + (
xnew - torch.gather(x, 1, indices)
) * torch.gather(slopes, 1, indices)

# Mask out the values that are outside the valid range
mask = (xnew < x[..., 0].reshape(-1, 1)) | (xnew > x[..., -1].reshape(-1, 1))
ynew[mask] = mask_value

return ynew
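
A worked example (hand-computed; the last query point lies above max(x), so it is masked to the default mask_value of 0):

import torch
from aion.codecs.modules.spectrum import interp1d

x = torch.tensor([[0.0, 1.0, 2.0]])
y = torch.tensor([[0.0, 10.0, 20.0]])
xnew = torch.tensor([[0.5, 1.5, 3.0]])
print(interp1d(x, y, xnew))  # [[5.0, 15.0, 0.0]]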


class LatentSpectralGrid(torch.nn.Module):
def __init__(self, lambda_min: float, resolution: float, num_pixels: int):
"""
Initialize a latent grid to represent spectra from multiple resolutions.

Args:
lambda_min: The minimum wavelength value, in Angstrom.
resolution: The resolution of the spectra, in Angstrom per pixel.
num_pixels: The number of pixels in the spectra.

"""
super().__init__()
self.register_buffer("lambda_min", torch.tensor(lambda_min))
self.register_buffer("resolution", torch.tensor(resolution))
self.register_buffer("length", torch.tensor(num_pixels))
self.register_buffer(
"_wavelength",
(torch.arange(0, num_pixels) * resolution + lambda_min).reshape(
1, num_pixels
),
)

@property
def wavelength(self) -> Float[torch.Tensor, " n"]:
return self._wavelength.squeeze()

def to_observed(
self,
x_latent: Float[torch.Tensor, " b n"],
wavelength: Float[torch.Tensor, " b m"],
) -> Float[torch.Tensor, " b m"]:
"""Transforms the latent representation to the observed wavelength grid.

Args:
            x_latent: The latent representation, [batch, num_pixels].
wavelength: The observed wavelength grid, [batch, M].

Returns:
The transformed representation on the observed wavelength grid.
"""
b = x_latent.shape[0]
return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength)

def to_latent(
self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"]
) -> Float[torch.Tensor, "b n"]:
"""Transforms the observed representation to the latent wavelength grid.

Args:
            x_obs: The observed representation, [batch, M].
            wavelength: The observed wavelength grid, [batch, M].

        Returns:
            The transformed representation on the latent grid, [batch, num_pixels].
"""
b = x_obs.shape[0]
return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1]))
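
A usage sketch for the grid; lambda_min, resolution, num_pixels, and the observed grid below are illustrative values, not taken from this PR:

import torch
from aion.codecs.modules.spectrum import LatentSpectralGrid

grid = LatentSpectralGrid(lambda_min=3600.0, resolution=0.8, num_pixels=4096)
wavelength = torch.linspace(4000.0, 6000.0, 2000).reshape(1, -1)  # sorted, inside the latent range
flux = torch.randn(1, 2000)
x_latent = grid.to_latent(flux, wavelength)     # (1, 4096); zero outside [4000, 6000]
x_obs = grid.to_observed(x_latent, wavelength)  # (1, 2000), back on the observed grid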
45 changes: 45 additions & 0 deletions aion/codecs/modules/utils.py
@@ -0,0 +1,45 @@
import torch
import torch.nn.functional as F
from einops import rearrange


class LayerNorm(torch.nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""

def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data_format: {data_format}")
self.normalized_shape = (normalized_shape,)

def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, self.weight, self.bias, self.eps
)
elif self.data_format == "channels_first":
x = rearrange(x, "b c ... -> b ... c")
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return rearrange(x, "b ... c -> b c ...")
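
A minimal sketch of the two data formats on 1D inputs (shapes are arbitrary):

import torch
from aion.codecs.modules.utils import LayerNorm

ln_last = LayerNorm(96)                                  # input (batch, ..., channels)
ln_first = LayerNorm(96, data_format="channels_first")   # input (batch, channels, ...)
assert ln_last(torch.randn(2, 128, 96)).shape == (2, 128, 96)
assert ln_first(torch.randn(2, 96, 128)).shape == (2, 96, 128)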


class GRN(torch.nn.Module):
    """GRN (Global Response Normalization) layer from ConvNeXtV2.

    Expects channels-last input of shape (batch, length, channels).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # Global feature aggregation: L2 norm over the length dimension
        Gx = torch.norm(x, p=2, dim=(1,), keepdim=True)
        # Divisive feature normalization across channels
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        # Residual form; gamma and beta start at zero
        return self.gamma * (x * Nx) + self.beta + x
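
Since gamma and beta are initialized to zero, GRN acts as the identity at initialization; a quick sketch:

import torch
from aion.codecs.modules.utils import GRN

grn = GRN(dim=384)
x = torch.randn(2, 128, 384)        # channels-last: (batch, length, channels)
assert torch.allclose(grn(x), x)    # identity at init: 0 * (x * Nx) + 0 + x == x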