diff --git a/.gitattributes b/.gitattributes
index 0f3e9c0..6cb2653 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,3 +1,6 @@
 tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text
 tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text
 tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text
+tests/test_data/SPECTRUM_input_batch.pt filter=lfs diff=lfs merge=lfs -text
+tests/test_data/SPECTRUM_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text
+tests/test_data/SPECTRUM_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text
diff --git a/aion/codecs/modules/convnext.py b/aion/codecs/modules/convnext.py
new file mode 100644
index 0000000..80cc7d0
--- /dev/null
+++ b/aion/codecs/modules/convnext.py
@@ -0,0 +1,172 @@
+import torch
+
+from aion.codecs.modules.utils import LayerNorm, GRN
+
+
+class ConvNextBlock1d(torch.nn.Module):
+    """ConvNeXtV2 block, adapted to 1D from the original 2D implementation at
+    https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
+
+    Args:
+        dim (int): Number of input channels.
+    """
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.dwconv = torch.nn.Conv1d(
+            dim, dim, kernel_size=7, padding=3, groups=dim
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = torch.nn.Linear(
+            dim, 4 * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = torch.nn.GELU()
+        self.grn = GRN(4 * dim)
+        self.pwconv2 = torch.nn.Linear(4 * dim, dim)
+
+    def forward(self, x):
+        y = self.dwconv(x)
+        y = y.permute(0, 2, 1)  # (B, C, N) -> (B, N, C)
+        y = self.norm(y)
+        y = self.pwconv1(y)
+        y = self.act(y)
+        y = self.grn(y)
+        y = self.pwconv2(y)
+        y = y.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)
+
+        y = x + y  # residual connection
+        return y
+
+
+class ConvNextEncoder1d(torch.nn.Module):
+    r"""ConvNeXt encoder.
+
+    Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+
+    Args:
+        in_chans: Number of input channels. Default: 2
+        depths: Number of blocks at each stage. Default: (3, 3, 9, 3)
+        dims: Feature dimension at each stage. Default: (96, 192, 384, 768)
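+
+    Example: a minimal shape sketch, assuming the default configuration::
+
+        encoder = ConvNextEncoder1d(in_chans=2)
+        x = torch.randn(4, 2, 8704)  # (batch, channels, length)
+        h = encoder(x)  # (4, 768, 272): a stride-4 stem, then three stride-2 stages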
+    """
+
+    def __init__(
+        self,
+        in_chans: int = 2,
+        depths: tuple[int, ...] = (3, 3, 9, 3),
+        dims: tuple[int, ...] = (96, 192, 384, 768),
+    ):
+        super().__init__()
+        assert len(depths) == len(dims), "depths and dims should have the same length"
+        num_layers = len(depths)
+
+        # Stem and 3 intermediate downsampling conv layers
+        self.downsample_layers = torch.nn.ModuleList()
+        stem = torch.nn.Sequential(
+            torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.downsample_layers.append(stem)
+        for i in range(num_layers - 1):
+            downsample_layer = torch.nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = torch.nn.ModuleList()
+        for i in range(num_layers):
+            stage = torch.nn.Sequential(
+                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
+            )
+            self.stages.append(stage)
+
+        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
+            torch.nn.init.trunc_normal_(m.weight, std=0.02)
+            torch.nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        for ds, st in zip(self.downsample_layers, self.stages):
+            x = ds(x)
+            x = st(x)
+        return self.norm(x)
+
+
+class ConvNextDecoder1d(torch.nn.Module):
+    r"""ConvNeXt decoder. Essentially a mirrored version of the encoder.
+
+    Args:
+        in_chans: Number of input channels. Default: 768
+        depths: Number of blocks at each stage. Default: (3, 3, 9, 3)
+        dims: Feature dimension at each stage. Default: (384, 192, 96, 2)
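+
+    Example: a minimal shape sketch, assuming the default configuration::
+
+        decoder = ConvNextDecoder1d(in_chans=768)
+        h = torch.randn(4, 768, 272)  # (batch, channels, length)
+        y = decoder(h)  # (4, 2, 8704): total upsampling factor 2 * 2 * 2 * 4 = 32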
+ """ + + def __init__( + self, + in_chans=768, + depths=[3, 3, 9, 3], + dims=[384, 192, 96, 2], + ): + super().__init__() + assert len(depths) == len(dims), "depths and dims should have the same length" + num_layers = len(depths) + + self.upsample_layers = torch.nn.ModuleList() + + stem = torch.nn.Sequential( + torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.upsample_layers.append(stem) + + for i in range(num_layers - 1): + upsample_layer = torch.nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + torch.nn.ConvTranspose1d( + dims[i], + dims[i + 1], + kernel_size=2 if i < (num_layers - 2) else 4, + stride=2 if i < (num_layers - 2) else 4, + ), + ) + self.upsample_layers.append(upsample_layer) + + self.stages = torch.nn.ModuleList() + for i in range(num_layers): + stage = torch.nn.Sequential( + *[ + ConvNextBlock1d( + dim=dims[i], + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + torch.nn.init.constant_(m.bias, 0) + + def forward(self, x): + for us, st in zip(self.upsample_layers, self.stages): + x = us(x) + x = st(x) + return x diff --git a/aion/codecs/modules/spectrum.py b/aion/codecs/modules/spectrum.py new file mode 100644 index 0000000..3df6a4a --- /dev/null +++ b/aion/codecs/modules/spectrum.py @@ -0,0 +1,110 @@ +import torch +from jaxtyping import Float + + +def interp1d( + x: Float[torch.Tensor, " b n"], + y: Float[torch.Tensor, " b n"], + xnew: Float[torch.Tensor, " b m"], + mask_value: float | None = 0.0, +) -> Float[torch.Tensor, " b m"]: + """Linear interpolation of a 1-D tensor using torch.searchsorted. + Assumes that x and xnew are sorted in increasing order. + + Args: + x: The x-coordinates of the data points, shape [batch, N]. + y: The y-coordinates of the data points, shape [batch, N]. + xnew: The x-coordinates of the interpolated points, shape [batch, M]. + mask_value: The value to use for xnew outside the range of x. + Returns: + The y-coordinates of the interpolated points, shape [batch, M]. + """ + # Find the indices where xnew should be inserted in sorted_x + # Given a point xnew[i] in xnew, return j where x[j] is the nearest point in x such that + # x[j] < xnew[i], except if the nearest point in x has x[j] = xnew[i] then return j - 1. + indices = torch.searchsorted(x, xnew) - 1 + + # We can define a local linear approx of the grad in each interval + # between two points in x, and we would like to use this to interpolate + # y at those points in xnew which lie inside the range of x, otherwise + # interpolated_y is masked for points in xnew outside the range of x. + # There are len(x) - 1 such intervals between points in x, having indices + # ranging between 0 and len(x) - 2. Points with xnew < min(x) will be + # assigned indices of -1 and points with xnew > max(x) will be assigned + # indices equal to len(x). These are not valid segment indices, but we can + # clamp them to 0 and len(x) - 2 respectively to avoid breaking the + # calculation of the slope variable. The nonsense values we obtain outside + # the range of x will be discarded when masking. 
+    indices = torch.clamp(indices, 0, x.shape[1] - 2)
+
+    slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:])
+
+    # Interpolate the y-coordinates
+    ynew = torch.gather(y, 1, indices) + (
+        xnew - torch.gather(x, 1, indices)
+    ) * torch.gather(slopes, 1, indices)
+
+    # Mask out the values that are outside the valid range
+    mask = (xnew < x[..., 0].reshape(-1, 1)) | (xnew > x[..., -1].reshape(-1, 1))
+    ynew[mask] = mask_value
+
+    return ynew
+
+
+class LatentSpectralGrid(torch.nn.Module):
+    def __init__(self, lambda_min: float, resolution: float, num_pixels: int):
+        """Initialize a latent grid used to represent spectra from multiple resolutions.
+
+        Args:
+            lambda_min: The minimum wavelength value, in Angstrom.
+            resolution: The resolution of the spectra, in Angstrom per pixel.
+            num_pixels: The number of pixels in the spectra.
+        """
+        super().__init__()
+        self.register_buffer("lambda_min", torch.tensor(lambda_min))
+        self.register_buffer("resolution", torch.tensor(resolution))
+        self.register_buffer("length", torch.tensor(num_pixels))
+        self.register_buffer(
+            "_wavelength",
+            (torch.arange(0, num_pixels) * resolution + lambda_min).reshape(
+                1, num_pixels
+            ),
+        )
+
+    @property
+    def wavelength(self) -> Float[torch.Tensor, " n"]:
+        return self._wavelength.squeeze()
+
+    def to_observed(
+        self,
+        x_latent: Float[torch.Tensor, " b n"],
+        wavelength: Float[torch.Tensor, " b m"],
+    ) -> Float[torch.Tensor, " b m"]:
+        """Transforms the latent representation to the observed wavelength grid.
+
+        Args:
+            x_latent: The latent representation, shape [batch, num_pixels].
+            wavelength: The observed wavelength grid, shape [batch, M].
+
+        Returns:
+            The transformed representation on the observed wavelength grid.
+        """
+        b = x_latent.shape[0]
+        return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength)
+
+    def to_latent(
+        self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"]
+    ) -> Float[torch.Tensor, "b n"]:
+        """Transforms the observed representation to the latent wavelength grid.
+
+        Args:
+            x_obs: The observed representation, shape [batch, M].
+            wavelength: The observed wavelength grid, shape [batch, M].
+
+        Returns:
+            The transformed representation on the latent grid, shape [batch, num_pixels].
+        """
+        b = x_obs.shape[0]
+        return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1]))
diff --git a/aion/codecs/modules/utils.py b/aion/codecs/modules/utils.py
new file mode 100644
index 0000000..6b2907a
--- /dev/null
+++ b/aion/codecs/modules/utils.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class LayerNorm(torch.nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+
+    channels_last corresponds to inputs with the channel dimension last, e.g.
+    (batch_size, length, channels); channels_first corresponds to inputs with
+    the channel dimension first, e.g. (batch_size, channels, length). The
+    channels_first path handles any number of trailing spatial dimensions.
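+
+    Example (illustrative)::
+
+        norm = LayerNorm(64, data_format="channels_first")
+        x = torch.randn(8, 64, 256)  # (batch, channels, length)
+        y = norm(x)  # normalized over the channel dimension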
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) + self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + x = rearrange(x, "b c ... -> b ... c") + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return rearrange(x, "b ... c -> b c ...") + + +class GRN(torch.nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1,), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py index 83936b7..d52c7b1 100644 --- a/aion/codecs/quantizers/__init__.py +++ b/aion/codecs/quantizers/__init__.py @@ -3,6 +3,7 @@ import torch from jaxtyping import Float, Integer +from vector_quantize_pytorch import LFQ class Quantizer(torch.nn.Module, ABC): @@ -186,3 +187,180 @@ def forward( codes = self.encode(z_e) codebook_usage = len(torch.unique(codes)) / self.codebook_size return z_q, torch.zeros([]), codebook_usage + + +class LucidrainsLFQ(Quantizer): + def __init__( + self, + dim: int | None = None, + codebook_size: int | None = None, + inv_temperature: float = 100.0, + entropy_loss_weight: float = 0.1, + commitment_loss_weight: float = 0.25, + diversity_gamma: float = 1.0, + num_codebooks: int = 1, + keep_num_codebooks_dim: bool | None = None, + codebook_scale: float = 1.0, + frac_per_sample_entropy: float = 1.0, + use_code_agnostic_commit_loss: bool = False, + projection_has_bias: bool = True, + soft_clamp_input_value: bool | None = None, + cosine_sim_project_in: bool = False, + cosine_sim_project_in_scale: float | None = None, + ): + """Lookup Free Quantizer (LFQ) from the MagVITv2 paper + https://arxiv.org/abs/2310.05737 + + Following the implementation from vector-quantize-pytorch + """ + super().__init__() + self._inverse_temperature = inv_temperature + self._quantizer = LFQ( + dim=dim, + codebook_size=codebook_size, + entropy_loss_weight=entropy_loss_weight, + commitment_loss_weight=commitment_loss_weight, + diversity_gamma=diversity_gamma, + num_codebooks=num_codebooks, + keep_num_codebooks_dim=keep_num_codebooks_dim, + codebook_scale=codebook_scale, + frac_per_sample_entropy=frac_per_sample_entropy, + use_code_agnostic_commit_loss=use_code_agnostic_commit_loss, + projection_has_bias=projection_has_bias, + soft_clamp_input_value=soft_clamp_input_value, + cosine_sim_project_in=cosine_sim_project_in, + cosine_sim_project_in_scale=cosine_sim_project_in_scale, + ) + + def forward( + self, z_e: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs a forward pass through the vector quantizer. + Args: + z_e: Tensor (B, C, ...) + The input tensor to be quantized. + Returns: + z_q: Tensor + The quantized tensor. + loss: Tensor + The embedding loss for the quantization. 
+        """
+        # When the input is a plain sequence, move the sequence dimension
+        # last, for compatibility with the upstream quantizer
+        if len(z_e.shape) == 3:
+            z_e = z_e.movedim(1, -1)
+        z_q, indices, aux_loss = self._quantizer(
+            z_e, inv_temperature=self._inverse_temperature
+        )
+        codebook_usage = indices.unique().numel() / self.codebook_size
+        if len(z_q.shape) == 3:
+            z_q = z_q.movedim(-1, 1)
+        return z_q, aux_loss, torch.tensor(codebook_usage)
+
+    def quantize(self, z: torch.Tensor) -> torch.Tensor:
+        """Quantizes the input tensor z, returns the corresponding
+        codebook index.
+        """
+        return self.encode(z)
+
+    def encode(self, z: torch.Tensor) -> torch.Tensor:
+        """Encodes the input tensor z, returns the corresponding
+        codebook index.
+        """
+        # When the input is a plain sequence, move the sequence dimension
+        # last, for compatibility with the upstream quantizer
+        if len(z.shape) == 3:
+            z = z.movedim(1, -1)
+        _, indices, _ = self._quantizer(
+            z, inv_temperature=self._inverse_temperature
+        )
+        return indices
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decodes the input code index into the corresponding codebook entry of
+        dimension (embedding_dim).
+        """
+        z = self._quantizer.indices_to_codes(codes)
+        # For compatibility with the upstream quantizer, move the last
+        # dimension back to the sequence position
+        if len(z.shape) == 3:
+            z = z.movedim(-1, 1)
+        return z
+
+    @property
+    def codebook_size(self) -> int:
+        """Returns the size of the codebook."""
+        return len(self._quantizer.codebook)
+
+    @property
+    def embedding_dim(self) -> int:
+        """Returns the dimension of the codebook entries."""
+        return self._quantizer.codebook_dim
+
+
+class ScalarLinearQuantizer(Quantizer):
+    """A simple non-adaptive quantizer which encodes scalars by binning them
+    on a fixed grid over the specified range.
+    """
+
+    def __init__(self, codebook_size: int, range: tuple[float, float]):
+        super().__init__()
+        self.register_buffer(
+            "buckets", torch.linspace(range[0], range[1], codebook_size)
+        )
+
+    def forward(
+        self, z_e: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs a forward pass through the vector quantizer.
+
+        Args:
+            z_e: Tensor (B, C, ...)
+                The input tensor to be quantized.
+
+        Returns:
+            z_q: Tensor
+                The quantized tensor.
+            loss: Tensor
+                The embedding loss for the quantization (always zero here).
+            codebook_usage: Tensor
+                The fraction of codes used in the codebook.
+        """
+        indices = self.encode(z_e)
+        z_q = self.decode(indices)
+        codebook_usage = indices.unique().numel() / self.codebook_size
+        return z_q, torch.tensor(0.0), torch.tensor(codebook_usage)
+
+    def quantize(self, z: torch.Tensor) -> torch.Tensor:
+        """Quantizes the input tensor z, returning the nearest codebook value
+        (i.e. decode(encode(z)), not the index).
+        """
+        return self.decode(self.encode(z))
+
+    def encode(self, z: torch.Tensor) -> torch.Tensor:
+        """Encodes the input tensor z, returns the corresponding
+        codebook index.
+        """
+        return torch.clamp(
+            torch.bucketize(z, self.buckets, out_int32=True), 0, self.codebook_size - 1
+        )
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decodes the input code index into the corresponding codebook entry of
+        dimension (embedding_dim).
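+
+        Example (illustrative, assuming codebook_size=1024 and range=(-1, 5))::
+
+            quantizer = ScalarLinearQuantizer(codebook_size=1024, range=(-1, 5))
+            codes = quantizer.encode(torch.tensor([0.0, 2.0]))
+            values = quantizer.decode(codes)  # within one bin (~0.006) of the inputs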
+ """ + return self.buckets[codes] + + @property + def codebook_size(self) -> int: + """Returns the size of the codebook.""" + return len(self.buckets) + + @property + def codebook(self) -> torch.Tensor: + """Returns the codebook.""" + return self.decode(torch.arange(self.codebook_size)) + + @property + def embedding_dim(self) -> int: + """Returns the dimension of the codebook entries.""" + return 1 diff --git a/aion/codecs/tokenizers/spectrum.py b/aion/codecs/tokenizers/spectrum.py new file mode 100644 index 0000000..40a0e86 --- /dev/null +++ b/aion/codecs/tokenizers/spectrum.py @@ -0,0 +1,234 @@ +import torch +from huggingface_hub import PyTorchModelHubMixin +from jaxtyping import Float, Real +from typing import Type + +from aion.modalities import Spectrum +from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d +from aion.codecs.modules.spectrum import LatentSpectralGrid +from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer +from aion.codecs.tokenizers.base import Codec + + +class AutoencoderSpectrumCodec(Codec): + """Meta-class for autoencoder codecs for spectra, does not actually contains a network.""" + + def __init__( + self, + quantizer: Quantizer, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + normalization_quantizer: Quantizer, + lambda_min: float = 3500.0, + resolution: float = 0.8, + num_pixels: int = 8704, + latent_channels: int = 512, + embedding_dim: int = 4, + clip_ivar: float = 100, + clip_flux: float | None = None, + input_scaling: float = 0.2, + ): + super().__init__() + self._quantizer = quantizer + self.encoder = encoder + self.decoder = decoder + self.normalization_quantizer = normalization_quantizer + self.latent_grid = LatentSpectralGrid( + lambda_min=lambda_min, resolution=resolution, num_pixels=num_pixels + ) + self.embedding_dim = embedding_dim + self.clip_ivar = clip_ivar + self.clip_flux = clip_flux + self.input_scaling = input_scaling + self.pre_quant_norm = torch.nn.LayerNorm(latent_channels) + self.quant_conv = torch.nn.Conv1d(latent_channels, embedding_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(embedding_dim, latent_channels, 1) + + @property + def modality(self) -> Type[Spectrum]: + return Spectrum + + @property + def quantizer(self) -> Quantizer: + return self._quantizer + + def _encode(self, x: Spectrum) -> Float[torch.Tensor, "b c t"]: + # Extract fields from Spectrum instance + flux = x.flux + ivar = x.ivar + mask = x.mask + wavelength = x.wavelength + + # Robustify the model against NaN values in the input + # And add optional cliping of extreme values + spectrum = torch.nan_to_num(flux) + if self.clip_flux is not None: + spectrum = torch.clamp(spectrum, -self.clip_flux, self.clip_flux) + ivar = torch.nan_to_num(ivar) + if self.clip_ivar is not None: + ivar = torch.clamp(ivar, 0, self.clip_ivar) + istd = torch.sqrt(ivar) + + # Normalize input spectrum + normalization = (spectrum * (1.0 - mask.float())).sum(dim=-1) / ( + torch.count_nonzero(~mask, dim=-1) + 1.0 + ) + + normalization = torch.clamp(normalization, 0.1) + + # Compressing the range of this normalization factor + normalization = torch.log10(normalization + 1.0) + + # Apply quantization to normalization factor + normalization = self.normalization_quantizer.quantize(normalization) + + # Normalize the spectrum + n = torch.clamp((10 ** normalization[..., None] - 1.0), 0.1) + spectrum = (spectrum / n - 1.0) * self.input_scaling + istd = (istd / n) * self.input_scaling + + # Project spectra on the latent grid + spectrum = 
+        spectrum = self.latent_grid.to_latent(spectrum, wavelength)
+        istd = self.latent_grid.to_latent(istd, wavelength)
+
+        # Apply additional range compression for good measure
+        x = torch.arcsinh(torch.stack([spectrum, istd], dim=1))
+        h = self.encoder(x)
+        h = self.pre_quant_norm(h.moveaxis(1, -1)).moveaxis(-1, 1)
+        h = self.quant_conv(h)
+        return h, normalization
+
+    def encode(self, x: Spectrum) -> Real[torch.Tensor, " b code"]:
+        # Override to handle the normalization token
+        # First verify the input type
+        if not isinstance(x, self.modality):
+            raise ValueError(
+                f"Input type {type(x).__name__} does not match the modality of the codec {self.modality.__name__}"
+            )
+
+        # Get the embedding using _encode
+        embedding, normalization = self._encode(x)
+
+        # Quantize the embedding
+        embedding = self.quantizer.encode(embedding)
+
+        # Quantize the normalization
+        normalization = self.normalization_quantizer.encode(normalization)
+
+        # Prepend the normalization token to the code sequence
+        embedding = torch.cat([normalization[..., None], embedding], dim=-1)
+
+        return embedding
+
+    def decode(
+        self,
+        z: Real[torch.Tensor, " b code"],
+        wavelength: Float[torch.Tensor, " b t"] | None = None,
+    ) -> Spectrum:
+        # Override to extract the normalization token from the sequence
+        norm_token, z = z[..., 0], z[..., 1:]
+
+        normalization = self.normalization_quantizer.decode(norm_token)
+
+        z = self.quantizer.decode(z)
+
+        return self._decode(z, normalization=normalization, wavelength=wavelength)
+
+    def _decode(
+        self,
+        z: Float[torch.Tensor, " b c l"],
+        normalization: Float[torch.Tensor, " b"],
+        wavelength: Float[torch.Tensor, " b t"] | None = None,
+    ) -> Spectrum:
+        h = self.post_quant_conv(z)
+        spectra = self.decoder(h)
+
+        if spectra.shape[1] == 1:  # just flux
+            spectra = spectra.squeeze(1)
+            # No decoded mask channel: -inf logits yield an all-False mask below
+            mask = torch.ones_like(spectra) * -torch.inf
+        elif spectra.shape[1] == 2:  # flux and mask
+            spectra, mask = spectra.chunk(2, dim=1)
+            spectra, mask = spectra.squeeze(1), mask.squeeze(1)
+        else:
+            raise ValueError("Invalid number of output channels, must be 1 or 2")
+
+        # If a wavelength grid is provided, interpolate the spectrum onto the
+        # observed grid; otherwise return it on the latent grid
+        if wavelength is not None:
+            spectra = self.latent_grid.to_observed(spectra, wavelength)
+            mask = self.latent_grid.to_observed(mask, wavelength)
+        else:
+            b = spectra.shape[0]
+            wavelength = self.latent_grid.wavelength.reshape(1, -1).repeat(b, 1)
+
+        # Undo the normalization
+        if normalization is not None:
+            spectra = (spectra + 1.0) * torch.clamp(
+                10 ** normalization[..., None] - 1.0, 0.1
+            )
+
+        # Round the mask logits to booleans
+        mask = torch.round(torch.sigmoid(mask)).bool().detach()
+
+        # Return a Spectrum instance
+        return Spectrum(
+            flux=spectra,
+            ivar=torch.ones_like(spectra),  # ivar is not decoded, so set to ones
+            mask=mask,
+            wavelength=wavelength,
+        )
+
+
+class SpectrumCodec(AutoencoderSpectrumCodec, PyTorchModelHubMixin):
+    """Spectrum codec based on ConvNeXt blocks.
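+
+    Example: a usage sketch, given a batched Spectrum instance; the checkpoint
+    id is the one used in the test suite::
+
+        codec = SpectrumCodec.from_pretrained("polymathic-ai/aion-spectrum-codec")
+        tokens = codec.encode(spectrum)  # (B, 1 + num_codes): normalization token + codes
+        recon = codec.decode(tokens)  # Spectrum on the latent wavelength grid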
+    """
+
+    def __init__(
+        self,
+        encoder_depths: tuple[int, ...] = (3, 3, 9, 3),
+        encoder_dims: tuple[int, ...] = (96, 192, 384, 768),
+        decoder_depths: tuple[int, ...] = (3, 3, 9, 3),
+        decoder_dims: tuple[int, ...] = (384, 192, 96, 1),
+        lambda_min: float = 3500.0,
+        resolution: float = 0.8,
+        num_pixels: int = 8704,
+        latent_channels: int = 512,
+        embedding_dim: int = 4,
+        clip_ivar: float = 100.0,
+        clip_flux: float | None = None,
+        input_scaling: float = 0.2,
+        normalization_range: tuple[float, float] = (-1, 5),
+        codebook_size: int = 1024,
+        dim: int = 10,
+    ):
+        assert encoder_dims[-1] == latent_channels, (
+            "Last encoder dim must match latent_channels"
+        )
+        quantizer = LucidrainsLFQ(dim=dim, codebook_size=codebook_size)
+        normalization_quantizer = ScalarLinearQuantizer(
+            codebook_size=codebook_size, range=normalization_range
+        )
+        encoder = ConvNextEncoder1d(
+            in_chans=2,
+            depths=encoder_depths,
+            dims=encoder_dims,
+        )
+        decoder = ConvNextDecoder1d(
+            in_chans=latent_channels,
+            depths=decoder_depths,
+            dims=decoder_dims,
+        )
+        super().__init__(
+            quantizer=quantizer,
+            encoder=encoder,
+            decoder=decoder,
+            normalization_quantizer=normalization_quantizer,
+            lambda_min=lambda_min,
+            resolution=resolution,
+            num_pixels=num_pixels,
+            latent_channels=latent_channels,
+            embedding_dim=embedding_dim,
+            clip_ivar=clip_ivar,
+            clip_flux=clip_flux,
+            input_scaling=input_scaling,
+        )
diff --git a/aion/modalities.py b/aion/modalities.py
index 3e16ddd..14d2f26 100644
--- a/aion/modalities.py
+++ b/aion/modalities.py
@@ -2,7 +2,7 @@
 from typing import List, Union
 
 from pydantic import BaseModel, Field, ConfigDict
-from jaxtyping import Float
+from jaxtyping import Float, Bool
 from torch import Tensor
 
 
@@ -30,5 +30,29 @@ def __repr__(self) -> str:
         return repr_str
 
 
+class Spectrum(Modality):
+    """Spectrum modality data.
+
+    Represents astronomical spectra with flux measurements, inverse variance,
+    mask, and wavelength information.
+    """
+
+    flux: Float[Tensor, "batch length"] = Field(
+        description="Array of flux measurements of the spectrum."
+    )
+    ivar: Float[Tensor, "batch length"] = Field(
+        description="Array of inverse variance values for the spectrum."
+    )
+    mask: Bool[Tensor, "batch length"] = Field(
+        description="Mask array indicating valid/invalid values in the spectrum."
+    )
+    wavelength: Float[Tensor, "batch length"] = Field(
+        description="Array of wavelength values in Angstroms."
+    )
+
+    def __repr__(self) -> str:
+        repr_str = f"Spectrum(flux_shape={list(self.flux.shape)}, wavelength_range=[{self.wavelength.min().item():.1f}, {self.wavelength.max().item():.1f}])"
+        return repr_str
+
+
 # Convenience type for any modality data
-ModalityType = Union[Image,]
+ModalityType = Union[Image, Spectrum]
diff --git a/pyproject.toml b/pyproject.toml
index 0e7efde..7adba0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
     "tokenizers>=0.15.2",
     "torch>=2.4.0",
     "pydantic>=2.10.6",
+    "vector_quantize_pytorch==1.14.30",
 ]
 
 [project.optional-dependencies]
diff --git a/tests/test_data/SPECTRUM_decoded_batch.pt b/tests/test_data/SPECTRUM_decoded_batch.pt
new file mode 100644
index 0000000..57e48d4
--- /dev/null
+++ b/tests/test_data/SPECTRUM_decoded_batch.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f01193b6a2e5284e419451c614157955559d4500ae4614da0ab4405e70d85cb
+size 13371261
diff --git a/tests/test_data/SPECTRUM_encoded_batch.pt b/tests/test_data/SPECTRUM_encoded_batch.pt
new file mode 100644
index 0000000..2e113e1
--- /dev/null
+++ b/tests/test_data/SPECTRUM_encoded_batch.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67b58211ca19f193da68a2a281ff01f3ab50bddb950bf7a20b89cd638ad41d5d
+size 280871
diff --git a/tests/test_data/SPECTRUM_input_batch.pt b/tests/test_data/SPECTRUM_input_batch.pt
new file mode 100644
index 0000000..f960fac
--- /dev/null
+++ b/tests/test_data/SPECTRUM_input_batch.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:524953206a9e1f439984d31c56b49d9fe1af40bcdcafed314763f91a88426c85
+size 16975169
diff --git a/tests/tokenizers/test_spectrum_tokenizer.py b/tests/tokenizers/test_spectrum_tokenizer.py
new file mode 100644
index 0000000..07dd6f6
--- /dev/null
+++ b/tests/tokenizers/test_spectrum_tokenizer.py
@@ -0,0 +1,61 @@
+import torch
+
+from aion.modalities import Spectrum
+from aion.codecs.tokenizers.spectrum import SpectrumCodec
+
+
+def test_hf_previous_predictions(data_dir):
+    codec = SpectrumCodec.from_pretrained("polymathic-ai/aion-spectrum-codec")
+
+    input_batch = torch.load(data_dir / "SPECTRUM_input_batch.pt", weights_only=False)[
+        "spectrum"
+    ]
+    reference_encoded_output = torch.load(
+        data_dir / "SPECTRUM_encoded_batch.pt", weights_only=False
+    )
+    reference_decoded_output = torch.load(
+        data_dir / "SPECTRUM_decoded_batch.pt", weights_only=False
+    )
+
+    with torch.no_grad():
+        # Create a Spectrum modality instance
+        spectrum_input = Spectrum(
+            flux=input_batch["flux"],
+            ivar=input_batch["ivar"],
+            mask=input_batch["mask"],
+            wavelength=input_batch["lambda"],
+        )
+
+        encoded_output = codec.encode(spectrum_input)
+        assert encoded_output.shape == reference_encoded_output.shape
+        assert torch.allclose(encoded_output, reference_encoded_output)
+
+        decoded_spectrum = codec.decode(encoded_output)
+
+        assert (
+            decoded_spectrum.flux.shape
+            == reference_decoded_output["spectrum"]["flux"].shape
+        )
+        assert torch.allclose(
+            decoded_spectrum.flux,
+            reference_decoded_output["spectrum"]["flux"],
+            rtol=1e-3,
+            atol=1e-4,
+        )
+        assert (
+            decoded_spectrum.wavelength.shape
+            == reference_decoded_output["spectrum"]["lambda"].shape
+        )
+        assert torch.allclose(
+            decoded_spectrum.wavelength,
+            reference_decoded_output["spectrum"]["lambda"],
+            rtol=1e-3,
+            atol=1e-4,
+        )
+        assert (
+            decoded_spectrum.mask.shape
+            == reference_decoded_output["spectrum"]["mask"].shape
+        )
+        # torch.allclose does not support bool tensors, so compare exactly
+        assert torch.equal(
+            decoded_spectrum.mask,
+            reference_decoded_output["spectrum"]["mask"].bool(),
+        )