From 2ce99b6e2700fb60a00976d82ab4eb943c60675e Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Thu, 10 Apr 2025 13:22:02 +0200
Subject: [PATCH 01/12] Add spectrum tokenizer

---
 aion/codecs/modules/convnext.py    | 172 +++++++++++++++++++++++
 aion/codecs/modules/spectrum.py    | 110 +++++++++++++++
 aion/codecs/modules/utils.py       |  45 ++++++
 aion/codecs/quantizers/__init__.py | 178 ++++++++++++++++++++++++
 aion/codecs/tokenizers/spectrum.py | 212 +++++++++++++++++++++++++++++
 5 files changed, 717 insertions(+)
 create mode 100644 aion/codecs/modules/convnext.py
 create mode 100644 aion/codecs/modules/spectrum.py
 create mode 100644 aion/codecs/modules/utils.py
 create mode 100644 aion/codecs/tokenizers/spectrum.py

diff --git a/aion/codecs/modules/convnext.py b/aion/codecs/modules/convnext.py
new file mode 100644
index 0000000..80cc7d0
--- /dev/null
+++ b/aion/codecs/modules/convnext.py
@@ -0,0 +1,172 @@
+import torch
+
+from aion.codecs.modules.utils import LayerNorm, GRN
+
+
+class ConvNextBlock1d(torch.nn.Module):
+    """ConvNeXtV2 block.
+    Modified to 1D from the original 2D implementation from
+    https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
+
+    Args:
+        dim (int): Number of input channels.
+    """
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.dwconv = torch.nn.Conv1d(
+            dim, dim, kernel_size=7, padding=3, groups=dim
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = torch.nn.Linear(
+            dim, 4 * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = torch.nn.GELU()
+        self.grn = GRN(4 * dim)
+        self.pwconv2 = torch.nn.Linear(4 * dim, dim)
+
+    def forward(self, x):
+        y = self.dwconv(x)
+        y = y.permute(0, 2, 1)  # (B, C, N) -> (B, N, C)
+        y = self.norm(y)
+        y = self.pwconv1(y)
+        y = self.act(y)
+        y = self.grn(y)
+        y = self.pwconv2(y)
+        y = y.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)
+
+        y = x + y
+        return y
+
+
+class ConvNextEncoder1d(torch.nn.Module):
+    r"""ConvNeXt encoder.
+
+    Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+
+    Args:
+        in_chans : Number of input channels. Default: 2
+        depths : Number of blocks at each stage. Default: (3, 3, 9, 3)
+        dims : Feature dimension at each stage. Default: (96, 192, 384, 768)
+    """
+
+    def __init__(
+        self,
+        in_chans: int = 2,
+        depths: tuple[int, ...] = (3, 3, 9, 3),
+        dims: tuple[int, ...]
= (96, 192, 384, 768),
+    ):
+        super().__init__()
+        assert len(depths) == len(dims), "depths and dims should have the same length"
+        num_layers = len(depths)
+
+        self.downsample_layers = (
+            torch.nn.ModuleList()
+        )  # stem and 3 intermediate downsampling conv layers
+        stem = torch.nn.Sequential(
+            torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.downsample_layers.append(stem)
+        for i in range(num_layers - 1):
+            downsample_layer = torch.nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = torch.nn.ModuleList()
+        for i in range(num_layers):
+            stage = torch.nn.Sequential(
+                *[
+                    ConvNextBlock1d(
+                        dim=dims[i],
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+
+        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
+            torch.nn.init.trunc_normal_(m.weight, std=0.02)
+            torch.nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        for ds, st in zip(self.downsample_layers, self.stages):
+            x = ds(x)
+            x = st(x)
+        return self.norm(x)
+
+
+class ConvNextDecoder1d(torch.nn.Module):
+    r"""ConvNeXt decoder. Essentially a mirrored version of the encoder.
+
+    Args:
+        in_chans (int): Number of input channels. Default: 768
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (tuple(int)): Feature dimension at each stage. Default: [384, 192, 96, 2]
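+    The final upsampling stage uses kernel/stride 4 rather than 2, mirroring
+    the encoder's stride-4 stem so that the overall sequence length is restored.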
+ """ + + def __init__( + self, + in_chans=768, + depths=[3, 3, 9, 3], + dims=[384, 192, 96, 2], + ): + super().__init__() + assert len(depths) == len(dims), "depths and dims should have the same length" + num_layers = len(depths) + + self.upsample_layers = torch.nn.ModuleList() + + stem = torch.nn.Sequential( + torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.upsample_layers.append(stem) + + for i in range(num_layers - 1): + upsample_layer = torch.nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + torch.nn.ConvTranspose1d( + dims[i], + dims[i + 1], + kernel_size=2 if i < (num_layers - 2) else 4, + stride=2 if i < (num_layers - 2) else 4, + ), + ) + self.upsample_layers.append(upsample_layer) + + self.stages = torch.nn.ModuleList() + for i in range(num_layers): + stage = torch.nn.Sequential( + *[ + ConvNextBlock1d( + dim=dims[i], + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + torch.nn.init.constant_(m.bias, 0) + + def forward(self, x): + for us, st in zip(self.upsample_layers, self.stages): + x = us(x) + x = st(x) + return x diff --git a/aion/codecs/modules/spectrum.py b/aion/codecs/modules/spectrum.py new file mode 100644 index 0000000..3df6a4a --- /dev/null +++ b/aion/codecs/modules/spectrum.py @@ -0,0 +1,110 @@ +import torch +from jaxtyping import Float + + +def interp1d( + x: Float[torch.Tensor, " b n"], + y: Float[torch.Tensor, " b n"], + xnew: Float[torch.Tensor, " b m"], + mask_value: float | None = 0.0, +) -> Float[torch.Tensor, " b m"]: + """Linear interpolation of a 1-D tensor using torch.searchsorted. + Assumes that x and xnew are sorted in increasing order. + + Args: + x: The x-coordinates of the data points, shape [batch, N]. + y: The y-coordinates of the data points, shape [batch, N]. + xnew: The x-coordinates of the interpolated points, shape [batch, M]. + mask_value: The value to use for xnew outside the range of x. + Returns: + The y-coordinates of the interpolated points, shape [batch, M]. + """ + # Find the indices where xnew should be inserted in sorted_x + # Given a point xnew[i] in xnew, return j where x[j] is the nearest point in x such that + # x[j] < xnew[i], except if the nearest point in x has x[j] = xnew[i] then return j - 1. + indices = torch.searchsorted(x, xnew) - 1 + + # We can define a local linear approx of the grad in each interval + # between two points in x, and we would like to use this to interpolate + # y at those points in xnew which lie inside the range of x, otherwise + # interpolated_y is masked for points in xnew outside the range of x. + # There are len(x) - 1 such intervals between points in x, having indices + # ranging between 0 and len(x) - 2. Points with xnew < min(x) will be + # assigned indices of -1 and points with xnew > max(x) will be assigned + # indices equal to len(x). These are not valid segment indices, but we can + # clamp them to 0 and len(x) - 2 respectively to avoid breaking the + # calculation of the slope variable. The nonsense values we obtain outside + # the range of x will be discarded when masking. 
+    indices = torch.clamp(indices, 0, x.shape[1] - 2)
+
+    slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:])
+
+    # Interpolate the y-coordinates
+    ynew = torch.gather(y, 1, indices) + (
+        xnew - torch.gather(x, 1, indices)
+    ) * torch.gather(slopes, 1, indices)
+
+    # Mask out the values that are outside the valid range
+    mask = (xnew < x[..., 0].reshape(-1, 1)) | (xnew > x[..., -1].reshape(-1, 1))
+    ynew[mask] = mask_value
+
+    return ynew
+
+
+class LatentSpectralGrid(torch.nn.Module):
+    def __init__(self, lambda_min: float, resolution: float, num_pixels: int):
+        """
+        Initialize a latent grid to represent spectra from multiple resolutions.
+
+        Args:
+            lambda_min: The minimum wavelength value, in Angstrom.
+            resolution: The resolution of the spectra, in Angstrom per pixel.
+            num_pixels: The number of pixels in the spectra.
+
+        """
+        super().__init__()
+        self.register_buffer("lambda_min", torch.tensor(lambda_min))
+        self.register_buffer("resolution", torch.tensor(resolution))
+        self.register_buffer("length", torch.tensor(num_pixels))
+        self.register_buffer(
+            "_wavelength",
+            (torch.arange(0, num_pixels) * resolution + lambda_min).reshape(
+                1, num_pixels
+            ),
+        )
+
+    @property
+    def wavelength(self) -> Float[torch.Tensor, " n"]:
+        return self._wavelength.squeeze()
+
+    def to_observed(
+        self,
+        x_latent: Float[torch.Tensor, " b n"],
+        wavelength: Float[torch.Tensor, " b m"],
+    ) -> Float[torch.Tensor, " b m"]:
+        """Transforms the latent representation to the observed wavelength grid.
+
+        Args:
+            x_latent: The latent representation, [batch, self.num_pixels].
+            wavelength: The observed wavelength grid, [batch, M].
+
+        Returns:
+            The transformed representation on the observed wavelength grid.
+        """
+        b = x_latent.shape[0]
+        return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength)
+
+    def to_latent(
+        self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"]
+    ) -> Float[torch.Tensor, "b n"]:
+        """Transforms the observed representation to the latent wavelength grid.
+
+        Args:
+            x_obs: The observed representation, [batch, M].
+            wavelength: The observed wavelength grid, [batch, M].
+
+        Returns:
+            The transformed representation on the latent wavelength grid.
+        """
+        b = x_obs.shape[0]
+        return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1]))
diff --git a/aion/codecs/modules/utils.py b/aion/codecs/modules/utils.py
new file mode 100644
index 0000000..6b2907a
--- /dev/null
+++ b/aion/codecs/modules/utils.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class LayerNorm(torch.nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
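+    For the 1D inputs used in this codebase these become (batch_size, length, channels)
+    and (batch_size, channels, length); the implementation handles any number of
+    trailing spatial dimensions.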
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) + self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + x = rearrange(x, "b c ... -> b ... c") + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return rearrange(x, "b ... c -> b c ...") + + +class GRN(torch.nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1,), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py index 18b6102..aae4649 100644 --- a/aion/codecs/quantizers/__init__.py +++ b/aion/codecs/quantizers/__init__.py @@ -3,6 +3,7 @@ import torch from jaxtyping import Float, Integer +from vector_quantize_pytorch import LFQ class Quantizer(torch.nn.Module, ABC): @@ -186,3 +187,180 @@ def forward( codes = self.encode(z_e) codebook_usage = len(torch.unique(codes)) / self.codebook_size return z_q, torch.zeros([]), codebook_usage + + +class LucidrainsLFQ(Quantizer): + def __init__( + self, + dim: int | None = None, + codebook_size: int | None = None, + inv_temperature: float = 100.0, + entropy_loss_weight: float = 0.1, + commitment_loss_weight: float = 0.25, + diversity_gamma: float = 1.0, + num_codebooks: int = 1, + keep_num_codebooks_dim: bool | None = None, + codebook_scale: float = 1.0, + frac_per_sample_entropy: float = 1.0, + use_code_agnostic_commit_loss: bool = False, + projection_has_bias: bool = True, + soft_clamp_input_value: bool | None = None, + cosine_sim_project_in: bool = False, + cosine_sim_project_in_scale: float | None = None, + ): + """Lookup Free Quantizer (LFQ) from the MagVITv2 paper + https://arxiv.org/abs/2310.05737 + + Following the implementation from vector-quantize-pytorch + """ + super().__init__() + self._inverse_temperature = inv_temperature + self._quantizer = LFQ( + dim=dim, + codebook_size=codebook_size, + entropy_loss_weight=entropy_loss_weight, + commitment_loss_weight=commitment_loss_weight, + diversity_gamma=diversity_gamma, + num_codebooks=num_codebooks, + keep_num_codebooks_dim=keep_num_codebooks_dim, + codebook_scale=codebook_scale, + frac_per_sample_entropy=frac_per_sample_entropy, + use_code_agnostic_commit_loss=use_code_agnostic_commit_loss, + projection_has_bias=projection_has_bias, + soft_clamp_input_value=soft_clamp_input_value, + cosine_sim_project_in=cosine_sim_project_in, + cosine_sim_project_in_scale=cosine_sim_project_in_scale, + ) + + def forward( + self, z_e: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs a forward pass through the vector quantizer. + Args: + z_e: Tensor (B, C, ...) + The input tensor to be quantized. + Returns: + z_q: Tensor + The quantized tensor. + loss: Tensor + The embedding loss for the quantization. 
+ codebook_usage: Tensor + The fraction of codes used in the codebook. + """ + # In cases where we only have a sequence, we need to move the sequence dimension to the last dimension + # For compatibility with the upstream quantizer + if len(z_e.shape) == 3: + z_e = z_e.movedim(1, -1) + z_q, indices, aux_loss = self._quantizer( + z_e, inv_temperature=self._inverse_temperature + ) + codebook_usage = indices.unique().numel() / self.codebook_size + if len(z_q.shape) == 3: + z_q = z_q.movedim(-1, 1) + return z_q, aux_loss, torch.tensor(codebook_usage) + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes the input tensor z, returns the corresponding + codebook index. + """ + return self.encode(z) + + def encode(self, z: torch.Tensor) -> torch.Tensor: + """Encodes the input tensor z, returns the corresponding + codebook index. + """ + # In cases where we only have a sequence, we need to move the sequence dimension to the last dimension + # For compatibility with the upstream quantizer + if len(z.shape) == 3: + z = z.movedim(1, -1) + z_q, indices, aux_loss = self._quantizer( + z, inv_temperature=self._inverse_temperature + ) + return indices + + def reconstruct(self, codes: torch.Tensor) -> torch.Tensor: + """Decodes the input code index into corresponding codebook entry of + dimension (embedding_dim). + """ + z = self._quantizer.indices_to_codes(codes) + # For compatibility with the upstream quantizer, we need to move the last dimension to the sequence dimension + if len(z.shape) == 3: + z = z.movedim(-1, 1) + return z + + @property + def codebook_size(self) -> int: + """Returns the size of the codebook.""" + return len(self._quantizer.codebook) + + @property + def embedding_dim(self) -> int: + """Returns the dimension of the codebook entries.""" + return self._quantizer.codebook_dim + + +class ScalarLinearQuantizer(Quantizer): + """A simple non-adaptive quantizer which will encode scalars by binning + on fixed histogram in the specified range. + """ + + def __init__(self, codebook_size: int, range: tuple[float, float]): + super().__init__() + self.register_buffer( + "buckets", torch.linspace(range[0], range[1], codebook_size) + ) + + def forward( + self, z_e: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs a forward pass through the vector quantizer. + Args: + z_e: Tensor (B, C, ...) + The input tensor to be quantized. + Returns: + z_q: Tensor + The quantized tensor. + loss: Tensor + The embedding loss for the quantization. + codebook_usage: Tensor + The fraction of codes used in the codebook. + """ + indices = self.encode(z_e) + z_q = self.decode(indices) + codebook_usage = indices.unique().numel() / self.codebook_size + return z_q, torch.tensor(0), torch.tensor(codebook_usage) + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes the input tensor z, returns the corresponding + codebook index. + """ + return self.reconstruct(self.encode(z)) + + def encode(self, z: torch.Tensor) -> torch.Tensor: + """Encodes the input tensor z, returns the corresponding + codebook index. + """ + return torch.clamp( + torch.bucketize(z, self.buckets, out_int32=True), 0, self.codebook_size - 1 + ) + + def reconstruct(self, codes: torch.Tensor) -> torch.Tensor: + """Decodes the input code index into corresponding codebook entry of + dimension (embedding_dim). 
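+        For example, with codebook_size=3 and range=(0, 1) the buckets are
+        [0.0, 0.5, 1.0]: a value of 0.3 encodes to index 1 and decodes back to 0.5.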
+ """ + return self.buckets[codes] + + @property + def codebook_size(self) -> int: + """Returns the size of the codebook.""" + return len(self.buckets) + + @property + def codebook(self) -> torch.Tensor: + """Returns the codebook.""" + return self.decode(torch.arange(self.codebook_size)) + + @property + def embedding_dim(self) -> int: + """Returns the dimension of the codebook entries.""" + return 1 diff --git a/aion/codecs/tokenizers/spectrum.py b/aion/codecs/tokenizers/spectrum.py new file mode 100644 index 0000000..1633bf9 --- /dev/null +++ b/aion/codecs/tokenizers/spectrum.py @@ -0,0 +1,212 @@ +import torch +from jaxtyping import Float, Real + +from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d +from aion.codecs.modules.spectrum import LatentSpectralGrid +from aion.codecs.quantizers import Quantizer +from aion.codecs.tokenizers.base import QuantizedCodec + + +class AutoencoderSpectrumCodec(QuantizedCodec): + """Meta-class for autoencoder codecs for spectra, does not actually contains a network.""" + + def __init__( + self, + quantizer: Quantizer, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + normalization_quantizer: Quantizer, + lambda_min: float = 3500.0, + resolution: float = 0.8, + num_pixels: int = 8704, + latent_channels: int = 512, + embedding_dim: int = 4, + clip_ivar: float = 100, + clip_flux: float | None = None, + input_scaling: float = 0.2, + ): + super().__init__(quantizer) + self.encoder = encoder + self.decoder = decoder + self.normalization_quantizer = normalization_quantizer + self.latent_grid = LatentSpectralGrid( + lambda_min=lambda_min, resolution=resolution, num_pixels=num_pixels + ) + self.embedding_dim = embedding_dim + self.clip_ivar = clip_ivar + self.clip_flux = clip_flux + self.input_scaling = input_scaling + self.pre_quant_norm = torch.nn.LayerNorm(latent_channels) + self.quant_conv = torch.nn.Conv1d(latent_channels, embedding_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(embedding_dim, latent_channels, 1) + + @property + def modality(self): + return "spectrum" + + def _encode( + self, + flux: Float[torch.Tensor, " b t"], + ivar: Float[torch.Tensor, " b t"], + mask: Float[torch.Tensor, " b t"], + wavelength: Float[torch.Tensor, " b t"], + ) -> tuple[Float[torch.Tensor, " b c t"], Float[torch.Tensor, " b"]]: + # Robustify the model against NaN values in the input + # And add optional cliping of extreme values + spectrum = torch.nan_to_num(flux) + if self.clip_flux is not None: + spectrum = torch.clamp(spectrum, -self.clip_flux, self.clip_flux) + ivar = torch.nan_to_num(ivar) + if self.clip_ivar is not None: + ivar = torch.clamp(ivar, 0, self.clip_ivar) + istd = torch.sqrt(ivar) + + # Normalize input spectrum + normalization = (spectrum * (1.0 - mask.float())).sum(dim=-1) / ( + torch.count_nonzero(~mask, dim=-1) + 1.0 + ) + + normalization = torch.clamp(normalization, 0.1) + + # Compressing the range of this normalization factor + normalization = torch.log10(normalization + 1.0) + + # Apply quantization to normalization factor + normalization = self.normalization_quantizer.quantize(normalization) + + # Normalize the spectrum + n = torch.clamp((10 ** normalization[..., None] - 1.0), 0.1) + spectrum = (spectrum / n - 1.0) * self.input_scaling + istd = (istd / n) * self.input_scaling + + # Project spectra on the latent grid + spectrum = self.latent_grid.to_latent(spectrum, wavelength) + istd = self.latent_grid.to_latent(istd, wavelength) + + # Apply additional range compression for good measure + x = 
torch.arcsinh(torch.stack([spectrum, istd], dim=1)) + h = self.encoder(x) + h = self.pre_quant_norm(h.moveaxis(1, -1)).moveaxis(-1, 1) + h = self.quant_conv(h) + return h, normalization + + def encode( + self, + flux: Float[torch.Tensor, " b t"], + ivar: Float[torch.Tensor, " b t"], + mask: Float[torch.Tensor, " b t"], + wavelength: Float[torch.Tensor, " b t"], + ) -> Real[torch.Tensor, " b code"]: + embedding, normalization = self._encode(flux, ivar, mask, wavelength) + + embedding = self.quantizer.encode(embedding) + + normalization = self.normalization_quantizer.encode(normalization) + embedding = torch.cat([normalization[..., None], embedding], dim=-1) + + return embedding + + def decode( + self, z, wavelength: Float[torch.Tensor, " b t"] | None = None + ) -> Float[torch.Tensor, " b t"]: + # Extract the normalization token from the sequence + norm_token, z = z[..., 0], z[..., 1:] + + normalization = self.normalization_quantizer.reconstruct(norm_token) + + z = self.quantizer.reconstruct(z) + + # The wavelength grid to decode the spectrum + return self._decode(z, wavelength=wavelength, normalization=normalization) + + def _decode( + self, + z: Float[torch.Tensor, " b c l"], + wavelength: Float[torch.Tensor, " b t"] | None = None, + normalization: Float[torch.Tensor, " b"] | None = None, + sigmoid_and_round_mask: bool = True, + ) -> tuple[ + Float[torch.Tensor, " b t"], + Float[torch.Tensor, " b t"], + Float[torch.Tensor, " b t"], + ]: + h = self.post_quant_conv(z) + spectra = self.decoder(h) + + if spectra.shape[1] == 1: # just flux + spectra = spectra.squeeze(1) + mask = torch.ones_like(spectra) * -torch.inf + elif spectra.shape[1] == 2: # flux and mask + spectra, mask = spectra.chunk(2, dim=1) + spectra, mask = spectra.squeeze(1), mask.squeeze(1) + else: + raise ValueError("Invalid number of output channels, must be 1 or 2") + + # If the wavelength are provided, interpolate the spectrum on the observed grid + if wavelength is not None: + spectra = self.latent_grid.to_observed(spectra, wavelength) + mask = self.latent_grid.to_observed(mask, wavelength) + else: + b = spectra.shape[0] + wavelength = self.latent_grid.wavelength.reshape(1, -1).repeat(b, 1) + + # Decode the spectrum on the latent grid and apply normalization + if normalization is not None: + spectra = (spectra + 1.0) * torch.clamp( + 10 ** normalization[..., None] - 1.0, 0.1 + ) + + if sigmoid_and_round_mask: + mask = torch.round(torch.sigmoid(mask)).detach() + + return spectra, wavelength, mask + + +class ConvNextAESpectrumCodec(AutoencoderSpectrumCodec): + """Spectrum codec based on convnext blocks.""" + + def __init__( + self, + quantizer: Quantizer, + normalization_quantizer: Quantizer, + encoder_depths: tuple[int, ...] = (3, 3, 9, 3), + encoder_dims: tuple[int, ...] = (96, 192, 384, 768), + decoder_depths: tuple[int, ...] = (3, 3, 9, 3), + decoder_dims: tuple[int, ...] 
= (384, 192, 96, 1), + lambda_min: float = 3500.0, + resolution: float = 0.8, + num_pixels: int = 8704, + latent_channels: int = 512, + embedding_dim: int = 4, + clip_ivar: float = 100, + clip_flux: float | None = None, + input_scaling: float = 0.2, + ): + assert encoder_dims[-1] == latent_channels, ( + "Last encoder dim must match latent_channels" + ) + self.encoder = ConvNextEncoder1d( + in_chans=2, + depths=encoder_depths, + dims=encoder_dims, + ) + + self.decoder = ConvNextDecoder1d( + in_chans=latent_channels, + depths=decoder_depths, + dims=decoder_dims, + ) + super().__init__( + quantizer=quantizer, + encoder=self.encoder, + decoder=self.decoder, + normalization_quantizer=normalization_quantizer, + lambda_min=lambda_min, + resolution=resolution, + num_pixels=num_pixels, + latent_channels=latent_channels, + embedding_dim=embedding_dim, + clip_ivar=clip_ivar, + clip_flux=clip_flux, + input_scaling=input_scaling, + ) From 9cd08ba9751a8cef5ea5dda3509fe355a4c48e11 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 12 May 2025 14:10:38 -0400 Subject: [PATCH 02/12] Rename FiniteScaleQuantizer->FiniteScalarQuantizer --- aion/codecs/quantizers/__init__.py | 2 +- tests/tokenizers/test_image_tokenizer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py index 18b6102..6fb3f64 100644 --- a/aion/codecs/quantizers/__init__.py +++ b/aion/codecs/quantizers/__init__.py @@ -41,7 +41,7 @@ def forward( raise NotImplementedError -class FiniteScaleQuantizer(Quantizer): +class FiniteScalarQuantizer(Quantizer): def __init__( self, levels: list[int], diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py index 8b8d509..55243dd 100644 --- a/tests/tokenizers/test_image_tokenizer.py +++ b/tests/tokenizers/test_image_tokenizer.py @@ -2,7 +2,7 @@ import torch from aion.codecs.tokenizers.image import MagViTAEImageCodec -from aion.codecs.quantizers import FiniteScaleQuantizer +from aion.codecs.quantizers import FiniteScalarQuantizer @pytest.mark.parametrize("n_bands", [3, 10]) @@ -14,7 +14,7 @@ def test_magvit_image_tokenizer( ): tokenizer = MagViTAEImageCodec( n_bands=n_bands, - quantizer=FiniteScaleQuantizer(levels=[1] * embedding_dim), + quantizer=FiniteScalarQuantizer(levels=[1] * embedding_dim), hidden_dims=hidden_dims, multisurvey_projection_dims=multisurvey_projection_dims, n_compressions=2, From d37fb30914cb0f0ce231152a4e15a3a56487e52b Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Tue, 13 May 2025 17:33:12 -0400 Subject: [PATCH 03/12] Add channel mask as input --- aion/codecs/tokenizers/base.py | 14 +++++++++----- aion/codecs/tokenizers/image.py | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/aion/codecs/tokenizers/base.py b/aion/codecs/tokenizers/base.py index a3194fa..bf041a6 100644 --- a/aion/codecs/tokenizers/base.py +++ b/aion/codecs/tokenizers/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import torch -from jaxtyping import Float +from jaxtyping import Float, Bool from aion.codecs.quantizers import Quantizer @@ -40,10 +40,12 @@ def _decode( raise NotImplementedError def encode( - self, x: Float[torch.Tensor, " b c *input_shape"] + self, + x: Float[torch.Tensor, " b c *input_shape"], + channel_mask: Bool[torch.Tensor, " b c"], ) -> Float[torch.Tensor, " b c1 *code_shape"]: """Encodes a given batch of samples into latent space.""" - return self._encode(x) + return self._encode(x, channel_mask) def decode( self, z: 
Float[torch.Tensor, " b c1 *code_shape"] @@ -64,7 +66,9 @@ def decode( return super().decode(z) def encode( - self, x: Float[torch.Tensor, " b c *input_shape"] + self, + x: Float[torch.Tensor, " b c *input_shape"], + channel_mask: Bool[torch.Tensor, " b c"], ) -> Float[torch.Tensor, " b c1 *code_shape"]: - embedding = super().encode(x) + embedding = super().encode(x, channel_mask) return self.quantizer.encode(embedding) diff --git a/aion/codecs/tokenizers/image.py b/aion/codecs/tokenizers/image.py index abae50b..93963d9 100644 --- a/aion/codecs/tokenizers/image.py +++ b/aion/codecs/tokenizers/image.py @@ -1,5 +1,5 @@ import torch -from jaxtyping import Float +from jaxtyping import Float, Bool from aion.codecs.modules.magvit import MagVitAE from aion.codecs.modules.subsampler import SubsampledLinear @@ -61,11 +61,11 @@ def _postprocess_sample(self, x): return x def _encode( - self, x: Float[torch.Tensor, " b {self.n_bands} w h"] + self, + x: Float[torch.Tensor, " b {self.n_bands} w h"], + channel_mask: Bool[torch.Tensor, " b {self.n_bands}"], ) -> Float[torch.Tensor, " b c1 w1 h1"]: x = self._preprocess_sample(x) - batch_size = x.shape[0] - channel_mask = torch.zeros((batch_size, self.n_bands), device=x.device) x = self.subsample_in(x, channel_mask) h = self.encoder(x) h = self.pre_quant_proj(h) From 1cf702b44f5ec19c9fae05293404f55ce2c18863 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Tue, 13 May 2025 18:00:39 -0400 Subject: [PATCH 04/12] Add test to ensure previous results consistency --- tests/tokenizers/test_image_tokenizer.py | 36 +++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py index 55243dd..35f2472 100644 --- a/tests/tokenizers/test_image_tokenizer.py +++ b/tests/tokenizers/test_image_tokenizer.py @@ -25,7 +25,41 @@ def test_magvit_image_tokenizer( ) batch_size = 4 random_input = torch.randn(batch_size, n_bands, 96, 96) - encoded = tokenizer.encode(random_input) + channel_mask = torch.ones(batch_size, n_bands) + encoded = tokenizer.encode(random_input, channel_mask) assert encoded.shape == (batch_size, 24, 24) decoded = tokenizer.decode(encoded) assert decoded.shape == random_input.shape + + +def test_previous_predictions(): + quantizer = FiniteScalarQuantizer(levels=[7, 5, 5, 5, 5]) + codec = MagViTAEImageCodec( + embedding_dim=5, + hidden_dims=512, + mult_factor=10.0, + multisurvey_projection_dims=54, + n_bands=9, + n_compressions=2, + num_consecutive=4, + range_compression_factor=0.01, + quantizer=quantizer, + ) + + codec.model.load_state_dict(torch.load("new_model.pt")) + subsample_layers_checkpoints = torch.load("image_tokenizer_checkpoints.pt") + codec.subsample_in.load_state_dict(subsample_layers_checkpoints["subsample_in"]) + codec.subsample_out.load_state_dict(subsample_layers_checkpoints["subsample_out"]) + codec.pre_quant_proj.load_state_dict(subsample_layers_checkpoints["pre_quant_proj"]) + codec.post_quant_proj.load_state_dict( + subsample_layers_checkpoints["post_quant_proj"] + ) + + input_batch = torch.load("image_codec_test_input.pt") + reference_output = torch.load("image_codec_test_output.pt") + + with torch.no_grad(): + output = codec.encode( + input_batch["image"]["array"], input_batch["image"]["channel_mask"] + ) + assert torch.allclose(output, reference_output) From 8c37cc4f16e287e514b598100aaf37447e9e729a Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Tue, 13 May 2025 22:11:44 -0400 Subject: [PATCH 05/12] Make the tokenizer a pytorch 
module Ease model weight loading. --- aion/codecs/tokenizers/base.py | 9 ++++++++- aion/codecs/tokenizers/image.py | 7 ++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/aion/codecs/tokenizers/base.py b/aion/codecs/tokenizers/base.py index bf041a6..c354b5a 100644 --- a/aion/codecs/tokenizers/base.py +++ b/aion/codecs/tokenizers/base.py @@ -6,7 +6,7 @@ from aion.codecs.quantizers import Quantizer -class Codec(ABC): +class Codec(ABC, torch.nn.Module): """Abstract definition of a Codec. A codec embeds a specific type of data into a sequence of either @@ -53,6 +53,13 @@ def decode( """Encodes a given batch of samples into latent space.""" return self._decode(z) + def forward( + self, + x: Float[torch.Tensor, " b c *input_shape"], + channel_mask: Bool[torch.Tensor, " b c"], + ) -> Float[torch.Tensor, " b c1 *code_shape"]: + return self.encode(x, channel_mask) + class QuantizedCodec(Codec): def __init__(self, quantizer: Quantizer): diff --git a/aion/codecs/tokenizers/image.py b/aion/codecs/tokenizers/image.py index 93963d9..729877d 100644 --- a/aion/codecs/tokenizers/image.py +++ b/aion/codecs/tokenizers/image.py @@ -114,7 +114,7 @@ def __init__( mult_factor: Multiplication factor. """ # Get MagViT architecture - self.model = MagVitAE( + model = MagVitAE( n_bands=multisurvey_projection_dims, hidden_dims=hidden_dims, n_compressions=n_compressions, @@ -123,11 +123,12 @@ def __init__( super().__init__( n_bands, quantizer, - self.model.encode, - self.model.decode, + model.encode, + model.decode, hidden_dims, embedding_dim, multisurvey_projection_dims, range_compression_factor, mult_factor, ) + self.model = model From aa71c1e133f549bb188a1b142697cac48186055d Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Tue, 13 May 2025 22:18:34 -0400 Subject: [PATCH 06/12] Update test to load only one model checkpoint --- tests/tokenizers/test_image_tokenizer.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py index 35f2472..fded6cd 100644 --- a/tests/tokenizers/test_image_tokenizer.py +++ b/tests/tokenizers/test_image_tokenizer.py @@ -46,15 +46,7 @@ def test_previous_predictions(): quantizer=quantizer, ) - codec.model.load_state_dict(torch.load("new_model.pt")) - subsample_layers_checkpoints = torch.load("image_tokenizer_checkpoints.pt") - codec.subsample_in.load_state_dict(subsample_layers_checkpoints["subsample_in"]) - codec.subsample_out.load_state_dict(subsample_layers_checkpoints["subsample_out"]) - codec.pre_quant_proj.load_state_dict(subsample_layers_checkpoints["pre_quant_proj"]) - codec.post_quant_proj.load_state_dict( - subsample_layers_checkpoints["post_quant_proj"] - ) - + codec.load_state_dict(torch.load("image_codec.pt")) input_batch = torch.load("image_codec_test_input.pt") reference_output = torch.load("image_codec_test_output.pt") From a6c157480fa5091bc21c3ee9da5ac88145809ffc Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Fri, 23 May 2025 16:58:56 +0200 Subject: [PATCH 07/12] Update SpectrumCodec Add tests for SpectrumCodec --- .gitattributes | 3 ++ aion/codecs/quantizers/__init__.py | 6 +-- aion/codecs/tokenizers/spectrum.py | 26 +++++++----- tests/test_data/SPECTRUM_decoded_batch.pt | 3 ++ tests/test_data/SPECTRUM_encoded_batch.pt | 3 ++ tests/test_data/SPECTRUM_input_batch.pt | 3 ++ tests/tokenizers/test_spectrum_tokenizer.py | 44 +++++++++++++++++++++ 7 files changed, 75 insertions(+), 13 deletions(-) create mode 100644 tests/test_data/SPECTRUM_decoded_batch.pt 
create mode 100644 tests/test_data/SPECTRUM_encoded_batch.pt create mode 100644 tests/test_data/SPECTRUM_input_batch.pt create mode 100644 tests/tokenizers/test_spectrum_tokenizer.py diff --git a/.gitattributes b/.gitattributes index 0f3e9c0..6cb2653 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,6 @@ tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text +tests/test_data/SPECTRUM_input_batch.pt filter=lfs diff=lfs merge=lfs -text +tests/test_data/SPECTRUM_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text +tests/test_data/SPECTRUM_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py index 1f264cc..d52c7b1 100644 --- a/aion/codecs/quantizers/__init__.py +++ b/aion/codecs/quantizers/__init__.py @@ -278,7 +278,7 @@ def encode(self, z: torch.Tensor) -> torch.Tensor: ) return indices - def reconstruct(self, codes: torch.Tensor) -> torch.Tensor: + def decode(self, codes: torch.Tensor) -> torch.Tensor: """Decodes the input code index into corresponding codebook entry of dimension (embedding_dim). """ @@ -334,7 +334,7 @@ def quantize(self, z: torch.Tensor) -> torch.Tensor: """Quantizes the input tensor z, returns the corresponding codebook index. """ - return self.reconstruct(self.encode(z)) + return self.decode(self.encode(z)) def encode(self, z: torch.Tensor) -> torch.Tensor: """Encodes the input tensor z, returns the corresponding @@ -344,7 +344,7 @@ def encode(self, z: torch.Tensor) -> torch.Tensor: torch.bucketize(z, self.buckets, out_int32=True), 0, self.codebook_size - 1 ) - def reconstruct(self, codes: torch.Tensor) -> torch.Tensor: + def decode(self, codes: torch.Tensor) -> torch.Tensor: """Decodes the input code index into corresponding codebook entry of dimension (embedding_dim). """ diff --git a/aion/codecs/tokenizers/spectrum.py b/aion/codecs/tokenizers/spectrum.py index 1633bf9..c2356dc 100644 --- a/aion/codecs/tokenizers/spectrum.py +++ b/aion/codecs/tokenizers/spectrum.py @@ -1,9 +1,10 @@ import torch +from huggingface_hub import PyTorchModelHubMixin from jaxtyping import Float, Real from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d from aion.codecs.modules.spectrum import LatentSpectralGrid -from aion.codecs.quantizers import Quantizer +from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer from aion.codecs.tokenizers.base import QuantizedCodec @@ -112,9 +113,9 @@ def decode( # Extract the normalization token from the sequence norm_token, z = z[..., 0], z[..., 1:] - normalization = self.normalization_quantizer.reconstruct(norm_token) + normalization = self.normalization_quantizer.decode(norm_token) - z = self.quantizer.reconstruct(z) + z = self.quantizer.decode(z) # The wavelength grid to decode the spectrum return self._decode(z, wavelength=wavelength, normalization=normalization) @@ -162,13 +163,11 @@ def _decode( return spectra, wavelength, mask -class ConvNextAESpectrumCodec(AutoencoderSpectrumCodec): +class SpectrumCodec(AutoencoderSpectrumCodec, PyTorchModelHubMixin): """Spectrum codec based on convnext blocks.""" def __init__( self, - quantizer: Quantizer, - normalization_quantizer: Quantizer, encoder_depths: tuple[int, ...] = (3, 3, 9, 3), encoder_dims: tuple[int, ...] = (96, 192, 384, 768), decoder_depths: tuple[int, ...] 
= (3, 3, 9, 3), @@ -181,25 +180,32 @@ def __init__( clip_ivar: float = 100, clip_flux: float | None = None, input_scaling: float = 0.2, + normalization_range: tuple[float, float] = (-1, 5), + codebook_size: int = 1024, + dim: int = 10, ): assert encoder_dims[-1] == latent_channels, ( "Last encoder dim must match latent_channels" ) - self.encoder = ConvNextEncoder1d( + quantizer = LucidrainsLFQ(dim=dim, codebook_size=codebook_size) + normalization_quantizer = ScalarLinearQuantizer( + codebook_size=codebook_size, range=normalization_range + ) + encoder = ConvNextEncoder1d( in_chans=2, depths=encoder_depths, dims=encoder_dims, ) - self.decoder = ConvNextDecoder1d( + decoder = ConvNextDecoder1d( in_chans=latent_channels, depths=decoder_depths, dims=decoder_dims, ) super().__init__( quantizer=quantizer, - encoder=self.encoder, - decoder=self.decoder, + encoder=encoder, + decoder=decoder, normalization_quantizer=normalization_quantizer, lambda_min=lambda_min, resolution=resolution, diff --git a/tests/test_data/SPECTRUM_decoded_batch.pt b/tests/test_data/SPECTRUM_decoded_batch.pt new file mode 100644 index 0000000..57e48d4 --- /dev/null +++ b/tests/test_data/SPECTRUM_decoded_batch.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f01193b6a2e5284e419451c614157955559d4500ae4614da0ab4405e70d85cb +size 13371261 diff --git a/tests/test_data/SPECTRUM_encoded_batch.pt b/tests/test_data/SPECTRUM_encoded_batch.pt new file mode 100644 index 0000000..2e113e1 --- /dev/null +++ b/tests/test_data/SPECTRUM_encoded_batch.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67b58211ca19f193da68a2a281ff01f3ab50bddb950bf7a20b89cd638ad41d5d +size 280871 diff --git a/tests/test_data/SPECTRUM_input_batch.pt b/tests/test_data/SPECTRUM_input_batch.pt new file mode 100644 index 0000000..f960fac --- /dev/null +++ b/tests/test_data/SPECTRUM_input_batch.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:524953206a9e1f439984d31c56b49d9fe1af40bcdcafed314763f91a88426c85 +size 16975169 diff --git a/tests/tokenizers/test_spectrum_tokenizer.py b/tests/tokenizers/test_spectrum_tokenizer.py new file mode 100644 index 0000000..4450502 --- /dev/null +++ b/tests/tokenizers/test_spectrum_tokenizer.py @@ -0,0 +1,44 @@ +import torch + +from aion.codecs.tokenizers.spectrum import SpectrumCodec + + +def test_hf_previous_predictions(data_dir): + codec = SpectrumCodec.from_pretrained("polymathic-ai/aion-spectrum-codec") + + input_batch = torch.load(data_dir / "SPECTRUM_input_batch.pt", weights_only=False)[ + "spectrum" + ] + reference_encoded_output = torch.load( + data_dir / "SPECTRUM_encoded_batch.pt", weights_only=False + ) + reference_decoded_output = torch.load( + data_dir / "SPECTRUM_decoded_batch.pt", weights_only=False + ) + + with torch.no_grad(): + encoded_output = codec.encode( + input_batch["flux"], + input_batch["ivar"], + input_batch["mask"], + input_batch["lambda"], + ) + assert encoded_output.shape == reference_encoded_output.shape + assert torch.allclose(encoded_output, reference_encoded_output) + + flux, wavelength, mask = codec.decode(encoded_output) + assert flux.shape == reference_decoded_output["spectrum"]["flux"].shape + assert torch.allclose( + flux, reference_decoded_output["spectrum"]["flux"], rtol=1e-3, atol=1e-4 + ) + assert wavelength.shape == reference_decoded_output["spectrum"]["lambda"].shape + assert torch.allclose( + wavelength, + reference_decoded_output["spectrum"]["lambda"], + rtol=1e-3, + atol=1e-4, + ) + assert mask.shape == 
reference_decoded_output["spectrum"]["mask"].shape
+    assert torch.allclose(
+        mask, reference_decoded_output["spectrum"]["mask"], rtol=1e-3, atol=1e-4
+    )

From b8549c128b9c0d36e7147ada7f8df49f7cc27356 Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Fri, 23 May 2025 17:02:54 +0200
Subject: [PATCH 08/12] Add vector_quantize_pytorch to dependencies

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index a59cbdc..e25e76e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
     "numpy",
     "tokenizers>=0.15.2",
     "torch>=2.4.0",
+    "vector_quantize_pytorch",
 ]
 
 [project.optional-dependencies]

From 213fad9f780777331d6bc9c77af2edc6ace090f5 Mon Sep 17 00:00:00 2001
From: Lucas Meyer
Date: Fri, 23 May 2025 17:08:32 +0200
Subject: [PATCH 09/12] Specify version of vector_quantize_pytorch

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index e25e76e..cf9e505 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     "numpy",
     "tokenizers>=0.15.2",
     "torch>=2.4.0",
-    "vector_quantize_pytorch",
+    "vector_quantize_pytorch==1.14.30",
 ]
 
 [project.optional-dependencies]

From 5bba1fbd1c4fa52fe8fb0cd11941d4991395944b Mon Sep 17 00:00:00 2001
From: Francois Lanusse
Date: Fri, 23 May 2025 20:23:32 +0200
Subject: [PATCH 10/12] update spectrum

---
 aion/codecs/tokenizers/spectrum.py          | 84 ++++++++++++---------
 aion/modalities.py                          | 26 ++++++-
 tests/tokenizers/test_spectrum_tokenizer.py | 47 +++++++++---
 3 files changed, 110 insertions(+), 47 deletions(-)

diff --git a/aion/codecs/tokenizers/spectrum.py b/aion/codecs/tokenizers/spectrum.py
index c2356dc..2bb7461 100644
--- a/aion/codecs/tokenizers/spectrum.py
+++ b/aion/codecs/tokenizers/spectrum.py
@@ -1,14 +1,16 @@
 import torch
 from huggingface_hub import PyTorchModelHubMixin
 from jaxtyping import Float, Real
+from typing import Type
 
+from aion.modalities import Spectrum
 from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d
 from aion.codecs.modules.spectrum import LatentSpectralGrid
 from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer
-from aion.codecs.tokenizers.base import QuantizedCodec
+from aion.codecs.tokenizers.base import Codec
 
 
-class AutoencoderSpectrumCodec(QuantizedCodec):
+class AutoencoderSpectrumCodec(Codec):
     """Meta-class for autoencoder codecs for spectra, does not actually contain a network."""
 
     def __init__(
@@ -26,7 +28,8 @@ def __init__(
         clip_flux: float | None = None,
         input_scaling: float = 0.2,
     ):
-        super().__init__(quantizer)
+        super().__init__()
+        self._quantizer = quantizer
         self.encoder = encoder
         self.decoder = decoder
         self.normalization_quantizer = normalization_quantizer
@@ -42,16 +45,20 @@ def __init__(
         self.post_quant_conv = torch.nn.Conv1d(embedding_dim, latent_channels, 1)
 
     @property
-    def modality(self):
-        return "spectrum"
+    def modality(self) -> Type[Spectrum]:
+        return Spectrum
+
+    @property
+    def quantizer(self) -> Quantizer:
+        return self._quantizer
+
+    def _encode(self, x: Spectrum) -> Float[torch.Tensor, "b c t"]:
+        # Extract fields from Spectrum instance
+        flux = x.flux
+        ivar = x.ivar
+        mask = x.mask
+        wavelength = x.wavelength
 
-    def _encode(
-        self,
-        flux: Float[torch.Tensor, " b t"],
-        ivar: Float[torch.Tensor, " b t"],
-        mask: Float[torch.Tensor, " b t"],
-        wavelength: Float[torch.Tensor, " b t"],
-    ) -> tuple[Float[torch.Tensor, " b c t"], Float[torch.Tensor, " b"]]:
         # Robustify the model against NaN values in the input
         # And add optional clipping of extreme values
         spectrum = torch.nan_to_num(flux)
@@ -91,25 +98,34 @@ def _encode(
         h = self.quant_conv(h)
         return h, normalization
 
-    def encode(
-        self,
-        flux: Float[torch.Tensor, " b t"],
-        ivar: Float[torch.Tensor, " b t"],
-        mask: Float[torch.Tensor, " b t"],
-        wavelength: Float[torch.Tensor, " b t"],
-    ) -> Real[torch.Tensor, " b code"]:
-        embedding, normalization = self._encode(flux, ivar, mask, wavelength)
+    def encode(self, x: Spectrum) -> Real[torch.Tensor, " b code"]:
+        # Override to handle normalization token
+        # First verify input type
+        if not isinstance(x, self.modality):
+            raise ValueError(
+                f"Input type {type(x).__name__} does not match the modality of the codec {self.modality.__name__}"
+            )
 
+        # Get embedding using _encode
+        embedding, normalization = self._encode(x)
+
+        # Quantize embedding
         embedding = self.quantizer.encode(embedding)
 
+        # Quantize normalization
         normalization = self.normalization_quantizer.encode(normalization)
+
+        # Concatenate normalization token with embedding
         embedding = torch.cat([normalization[..., None], embedding], dim=-1)
 
         return embedding
 
     def decode(
-        self, z, wavelength: Float[torch.Tensor, " b t"] | None = None
-    ) -> Float[torch.Tensor, " b t"]:
+        self,
+        z: Real[torch.Tensor, " b code"],
+        wavelength: Float[torch.Tensor, " b t"] | None = None,
+    ) -> Spectrum:
+        # Override to handle normalization token extraction
         # Extract the normalization token from the sequence
         norm_token, z = z[..., 0], z[..., 1:]
 
@@ -117,20 +133,14 @@ def decode(
 
         z = self.quantizer.decode(z)
 
-        # The wavelength grid to decode the spectrum
-        return self._decode(z, wavelength=wavelength, normalization=normalization)
+        return self._decode(z, normalization=normalization, wavelength=wavelength)
 
     def _decode(
         self,
         z: Float[torch.Tensor, " b c l"],
+        normalization: Float[torch.Tensor, " b"],
         wavelength: Float[torch.Tensor, " b t"] | None = None,
-        normalization: Float[torch.Tensor, " b"] | None = None,
-        sigmoid_and_round_mask: bool = True,
-    ) -> tuple[
-        Float[torch.Tensor, " b t"],
-        Float[torch.Tensor, " b t"],
-        Float[torch.Tensor, " b t"],
-    ]:
+    ) -> Spectrum:
         h = self.post_quant_conv(z)
         spectra = self.decoder(h)
 
@@ -157,10 +167,16 @@ def _decode(
                 10 ** normalization[..., None] - 1.0, 0.1
             )
 
-        if sigmoid_and_round_mask:
-            mask = torch.round(torch.sigmoid(mask)).detach()
+        # Round mask
+        mask = torch.round(torch.sigmoid(mask)).detach()
 
-        return spectra, wavelength, mask
+        # Return Spectrum instance
+        return Spectrum(
+            flux=spectra,
+            ivar=torch.ones_like(spectra),  # We don't decode ivar, so set to ones
+            mask=mask,
+            wavelength=wavelength,
+        )
diff --git a/aion/modalities.py b/aion/modalities.py
index 3e16ddd..a7491d1 100644
--- a/aion/modalities.py
+++ b/aion/modalities.py
@@ -30,5 +30,29 @@ def __repr__(self) -> str:
         return repr_str
 
 
+class Spectrum(Modality):
+    """Spectrum modality data.
+
+    Represents astronomical spectra with flux measurements, inverse variance, mask, and wavelength information.
+    """
+
+    flux: Float[Tensor, "batch length"] = Field(
+        description="Array of flux measurements of the spectrum."
+    )
+    ivar: Float[Tensor, "batch length"] = Field(
+        description="Array of inverse variance values for the spectrum."
+    )
+    mask: Float[Tensor, "batch length"] = Field(
+        description="Mask array indicating valid/invalid values in the spectrum."
+    )
+    wavelength: Float[Tensor, "batch length"] = Field(
+        description="Array of wavelength values in Angstroms."
+ ) + + def __repr__(self) -> str: + repr_str = f"Spectrum(flux_shape={list(self.flux.shape)}, wavelength_range=[{self.wavelength.min().item():.1f}, {self.wavelength.max().item():.1f}])" + return repr_str + + # Convenience type for any modality data -ModalityType = Union[Image,] +ModalityType = Union[Image, Spectrum] diff --git a/tests/tokenizers/test_spectrum_tokenizer.py b/tests/tokenizers/test_spectrum_tokenizer.py index 4450502..6478e65 100644 --- a/tests/tokenizers/test_spectrum_tokenizer.py +++ b/tests/tokenizers/test_spectrum_tokenizer.py @@ -1,5 +1,6 @@ import torch +from aion.modalities import Spectrum from aion.codecs.tokenizers.spectrum import SpectrumCodec @@ -17,28 +18,50 @@ def test_hf_previous_predictions(data_dir): ) with torch.no_grad(): - encoded_output = codec.encode( - input_batch["flux"], - input_batch["ivar"], - input_batch["mask"], - input_batch["lambda"], + # Create Spectrum modality instance + spectrum_input = Spectrum( + flux=input_batch["flux"], + ivar=input_batch["ivar"], + mask=input_batch["mask"], + wavelength=input_batch["lambda"], ) + + encoded_output = codec.encode(spectrum_input) assert encoded_output.shape == reference_encoded_output.shape assert torch.allclose(encoded_output, reference_encoded_output) - flux, wavelength, mask = codec.decode(encoded_output) - assert flux.shape == reference_decoded_output["spectrum"]["flux"].shape + # Decode - the custom decode method handles the wavelength internally + decoded_spectrum = codec.decode( + encoded_output, wavelength=input_batch["lambda"] + ) + + assert ( + decoded_spectrum.flux.shape + == reference_decoded_output["spectrum"]["flux"].shape + ) assert torch.allclose( - flux, reference_decoded_output["spectrum"]["flux"], rtol=1e-3, atol=1e-4 + decoded_spectrum.flux, + reference_decoded_output["spectrum"]["flux"], + rtol=1e-3, + atol=1e-4, + ) + assert ( + decoded_spectrum.wavelength.shape + == reference_decoded_output["spectrum"]["lambda"].shape ) - assert wavelength.shape == reference_decoded_output["spectrum"]["lambda"].shape assert torch.allclose( - wavelength, + decoded_spectrum.wavelength, reference_decoded_output["spectrum"]["lambda"], rtol=1e-3, atol=1e-4, ) - assert mask.shape == reference_decoded_output["spectrum"]["mask"].shape + assert ( + decoded_spectrum.mask.shape + == reference_decoded_output["spectrum"]["mask"].shape + ) assert torch.allclose( - mask, reference_decoded_output["spectrum"]["mask"], rtol=1e-3, atol=1e-4 + decoded_spectrum.mask, + reference_decoded_output["spectrum"]["mask"], + rtol=1e-3, + atol=1e-4, ) From 740b3fcce9a39fac583eab4762fac188e27f387b Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 23 May 2025 14:53:04 -0400 Subject: [PATCH 11/12] fixed issues --- aion/codecs/tokenizers/spectrum.py | 2 +- aion/modalities.py | 4 ++-- tests/tokenizers/test_spectrum_tokenizer.py | 7 ++----- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/aion/codecs/tokenizers/spectrum.py b/aion/codecs/tokenizers/spectrum.py index 2bb7461..40a0e86 100644 --- a/aion/codecs/tokenizers/spectrum.py +++ b/aion/codecs/tokenizers/spectrum.py @@ -168,7 +168,7 @@ def _decode( ) # Round mask - mask = torch.round(torch.sigmoid(mask)).detach() + mask = torch.round(torch.sigmoid(mask)).bool().detach() # Return Spectrum instance return Spectrum( diff --git a/aion/modalities.py b/aion/modalities.py index a7491d1..14d2f26 100644 --- a/aion/modalities.py +++ b/aion/modalities.py @@ -2,7 +2,7 @@ from typing import List, Union from pydantic import BaseModel, Field, ConfigDict -from jaxtyping import Float +from 
jaxtyping import Float, Bool from torch import Tensor @@ -42,7 +42,7 @@ class Spectrum(Modality): ivar: Float[Tensor, "batch length"] = Field( description="Array of inverse variance values for the spectrum." ) - mask: Float[Tensor, "batch length"] = Field( + mask: Bool[Tensor, "batch length"] = Field( description="Mask array indicating valid/invalid values in the spectrum." ) wavelength: Float[Tensor, "batch length"] = Field( diff --git a/tests/tokenizers/test_spectrum_tokenizer.py b/tests/tokenizers/test_spectrum_tokenizer.py index 6478e65..a0887c0 100644 --- a/tests/tokenizers/test_spectrum_tokenizer.py +++ b/tests/tokenizers/test_spectrum_tokenizer.py @@ -30,9 +30,8 @@ def test_hf_previous_predictions(data_dir): assert encoded_output.shape == reference_encoded_output.shape assert torch.allclose(encoded_output, reference_encoded_output) - # Decode - the custom decode method handles the wavelength internally decoded_spectrum = codec.decode( - encoded_output, wavelength=input_batch["lambda"] + encoded_output ) assert ( @@ -61,7 +60,5 @@ def test_hf_previous_predictions(data_dir): ) assert torch.allclose( decoded_spectrum.mask, - reference_decoded_output["spectrum"]["mask"], - rtol=1e-3, - atol=1e-4, + reference_decoded_output["spectrum"]["mask"].bool() ) From 1730a0160c667bc831e02994d7a1815f57f62e63 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 23 May 2025 20:54:35 +0200 Subject: [PATCH 12/12] fixing formatting --- tests/tokenizers/test_spectrum_tokenizer.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/tokenizers/test_spectrum_tokenizer.py b/tests/tokenizers/test_spectrum_tokenizer.py index a0887c0..07dd6f6 100644 --- a/tests/tokenizers/test_spectrum_tokenizer.py +++ b/tests/tokenizers/test_spectrum_tokenizer.py @@ -30,9 +30,7 @@ def test_hf_previous_predictions(data_dir): assert encoded_output.shape == reference_encoded_output.shape assert torch.allclose(encoded_output, reference_encoded_output) - decoded_spectrum = codec.decode( - encoded_output - ) + decoded_spectrum = codec.decode(encoded_output) assert ( decoded_spectrum.flux.shape @@ -59,6 +57,5 @@ def test_hf_previous_predictions(data_dir): == reference_decoded_output["spectrum"]["mask"].shape ) assert torch.allclose( - decoded_spectrum.mask, - reference_decoded_output["spectrum"]["mask"].bool() + decoded_spectrum.mask, reference_decoded_output["spectrum"]["mask"].bool() )
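
Usage sketch for the spectrum codec introduced in this series. This assumes the
pretrained "polymathic-ai/aion-spectrum-codec" checkpoint referenced in the tests
above; the wavelength grid, random fluxes, and batch shape below are purely
illustrative placeholders:

    import torch

    from aion.codecs.tokenizers.spectrum import SpectrumCodec
    from aion.modalities import Spectrum

    codec = SpectrumCodec.from_pretrained("polymathic-ai/aion-spectrum-codec")

    # Two toy spectra on an arbitrary observed wavelength grid (Angstrom).
    wavelength = torch.linspace(3600.0, 10300.0, 4000).repeat(2, 1)
    spectrum = Spectrum(
        flux=torch.rand(2, 4000),
        ivar=torch.ones(2, 4000),
        mask=torch.zeros(2, 4000, dtype=torch.bool),  # False = valid pixel
        wavelength=wavelength,
    )

    with torch.no_grad():
        # First token is the quantized normalization factor, the rest are LFQ codes.
        tokens = codec.encode(spectrum)
        # Decode onto the codec's internal latent grid...
        recon = codec.decode(tokens)
        # ...or resample the reconstruction back to the observed grid.
        recon_obs = codec.decode(tokens, wavelength=wavelength)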