diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8edf8c6..12d9b0c 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -25,3 +25,7 @@ jobs: run: | python -m pip install --upgrade pip pip install . + - name: Run tests + run: | + pip install ".[dev]" + pytest tests diff --git a/.gitignore b/.gitignore index 15201ac..de88b1c 100644 --- a/.gitignore +++ b/.gitignore @@ -147,10 +147,13 @@ venv.bak/ /site # mypy -.mypy_cache/ +.mypy_cache .dmypy.json dmypy.json +# Ruff +.ruff_cache + # Pyre type checker .pyre/ diff --git a/aion/codecs/__init__.py b/aion/codecs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aion/codecs/modules/__init__.py b/aion/codecs/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aion/codecs/modules/magvit.py b/aion/codecs/modules/magvit.py new file mode 100644 index 0000000..83ce352 --- /dev/null +++ b/aion/codecs/modules/magvit.py @@ -0,0 +1,214 @@ +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class SameConv2d(torch.nn.Module): + def __init__(self, dim_in, dim_out, kernel_size): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + padding = [k // 2 for k in kernel_size] + self.conv = torch.nn.Conv2d( + dim_in, dim_out, kernel_size=kernel_size, padding=padding + ) + + def forward(self, x: torch.Tensor): + return self.conv(x) + + +class SqueezeExcite(torch.nn.Module): + # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375) + + def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + + self.to_k = torch.nn.Conv2d(dim, 1, 1) + dim_hidden = max(dim_hidden_min, dim_out // 2) + + self.net = torch.nn.Sequential( + torch.nn.Conv2d(dim, dim_hidden, 1), + torch.nn.LeakyReLU(0.1), + torch.nn.Conv2d(dim_hidden, dim_out, 1), + torch.nn.Sigmoid(), + ) + + torch.nn.init.zeros_(self.net[-2].weight) + torch.nn.init.constant_(self.net[-2].bias, init_bias) + + def forward(self, x): + context = self.to_k(x) + + context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1) + spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)") + + out = torch.einsum("b i n, b c n -> b c i", context, spatial_flattened_input) + out = rearrange(out, "... -> ... 
1") + gates = self.net(out) + + return gates * x + + +class ResidualUnit(torch.nn.Module): + def __init__(self, dim: int, kernel_size: int | tuple[int, int, int]): + super().__init__() + self.net = torch.nn.Sequential( + SameConv2d(dim, dim, kernel_size), + torch.nn.ELU(), + torch.nn.Conv2d(dim, dim, 1), + torch.nn.ELU(), + SqueezeExcite(dim), + ) + + def forward(self, x: torch.Tensor): + return self.net(x) + x + + +class SpatialDownsample2x(torch.nn.Module): + def __init__( + self, + dim: int, + dim_out: int = None, + kernel_size: int = 3, + ): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + self.conv = torch.nn.Conv2d( + dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2 + ) + + def forward(self, x: torch.Tensor): + out = self.conv(x) + return out + + +class SpatialUpsample2x(torch.nn.Module): + def __init__(self, dim: int, dim_out: int = None): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + conv = torch.nn.Conv2d(dim, dim_out * 4, 1) + + self.net = torch.nn.Sequential( + conv, + torch.nn.SiLU(), + Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2), + ) + + self.init_conv_(conv) + + def init_conv_(self, conv: torch.nn.Module): + o, i, h, w = conv.weight.shape + conv_weight = torch.empty(o // 4, i, h, w) + torch.nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... -> (o 4) ...") + + conv.weight.data.copy_(conv_weight) + torch.nn.init.zeros_(conv.bias.data) + + def forward(self, x: torch.Tensor): + out = self.net(x) + return out + + +class MagVitAE(torch.nn.Module): + """MagViTAE implementation from Yu, et al. (2024), adapted for Pytorch. + Code borrowed from https://github.com/lucidrains/magvit2-pytorch, and adapted for images. + """ + + def __init__( + self, + n_bands: int = 3, + hidden_dims: int = 512, + residual_conv_kernel_size: int = 3, + n_compressions: int = 2, + num_consecutive: int = 2, + ): + super().__init__() + + self.encoder_layers = torch.nn.ModuleList([]) + self.decoder_layers = torch.nn.ModuleList([]) + init_dim = int(hidden_dims / 2**n_compressions) + dim = init_dim + + self.conv_in = SameConv2d(n_bands, init_dim, 7) + self.conv_out = SameConv2d(init_dim, n_bands, 3) + + # Residual layers + encoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + decoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # Compressions + for i in range(n_compressions): + dim_out = dim * 2 + encoder_layer = SpatialDownsample2x(dim, dim_out) + decoder_layer = SpatialUpsample2x(dim_out, dim) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + dim = dim_out + + # Consecutive residual layers + encoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + decoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # Add a final non-compress layer + dim_out = dim + encoder_layer = SameConv2d(dim, dim_out, 7) + decoder_layer = SameConv2d(dim_out, dim, 3) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + dim = dim_out + + # Consecutive residual layers + encoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] 
+ ) + decoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # add a final norm just before quantization layer + self.encoder_layers.append( + torch.nn.Sequential( + Rearrange("b c ... -> b ... c"), + torch.nn.LayerNorm(dim), + Rearrange("b ... c -> b c ..."), + ) + ) + + def encode(self, x: torch.Tensor): + x = self.conv_in(x) + for layer in self.encoder_layers: + x = layer(x) + return x + + def decode(self, x: torch.Tensor): + for layer in self.decoder_layers: + x = layer(x) + x = self.conv_out(x) + return x diff --git a/aion/codecs/modules/subsampler.py b/aion/codecs/modules/subsampler.py new file mode 100644 index 0000000..af66f7a --- /dev/null +++ b/aion/codecs/modules/subsampler.py @@ -0,0 +1,60 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Bool, Float + + +class SubsampledLinear(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int, subsample_in: bool = True): + """ + Subsampled linear layer for the encoder. + It takes in a zero-padded tensor and a mask. + It projects the tensor into some shared projection space. + It can also be used to reverse out of the space with the mask. + + Args: + dim_in : Number of total possible bands. + dim_out : Number of embedding dimensions. + subsample_in : Whether to subsample the input. Defaults to True. + """ + super().__init__() + self.subsample_in = subsample_in + self.dim_in = dim_in # Number of total possible bands + self.dim_out = dim_out # Number of embedding dimensions + temp_linear = torch.nn.Linear(dim_in, dim_out) + self.weight = torch.nn.Parameter(temp_linear.weight) + self.bias = torch.nn.Parameter(temp_linear.bias) + + def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]): + # Get mask + mask = labels[:, None, None, :].float() + x = x * mask + + # Normalize + label_sizes = labels.sum(dim=1, keepdim=True) + scales = ((self.dim_in / label_sizes) ** 0.5).squeeze() + + # Apply linear layer + return scales[:, None, None, None] * F.linear(x, self.weight, self.bias) + + def _subsample_out(self, x, labels): + # Get mask + mask = labels[:, None, None, :].float() + + # Apply linear layer and mask + return F.linear(x, self.weight, self.bias) * mask + + def forward( + self, x: Float[torch.Tensor, " b c h w"], labels: Bool[torch.Tensor, " b c"] + ) -> Float[torch.Tensor, " b c h w"]: + x = rearrange(x, "b c h w -> b h w c") + + if self.subsample_in: + x = self._subsample_in(x, labels) + + else: + x = self._subsample_out(x, labels) + + x = rearrange(x, "b h w c -> b c h w") + + return x diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py new file mode 100644 index 0000000..40c8137 --- /dev/null +++ b/aion/codecs/quantizers/__init__.py @@ -0,0 +1,8 @@ +from .base import Quantizer +from .scalar import FiniteScaleQuantizer, IdentityQuantizer + +__all__ = [ + "FiniteScaleQuantizer", + "IdentityQuantizer", + "Quantizer", +] diff --git a/aion/codecs/quantizers/base.py b/aion/codecs/quantizers/base.py new file mode 100644 index 0000000..752245f --- /dev/null +++ b/aion/codecs/quantizers/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod + +import torch +from jaxtyping import Float + + +class Quantizer(torch.nn.Module, ABC): + """Abstract interface for all quantizer modules.""" + + @abstractmethod + def quantize( + self, x: Float[torch.Tensor, " b c1 *input_shape"] + ) 
-> Float[torch.Tensor, " b c *code_shape"]:
+        """Quantize the input tensor."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def reconstruct(
+        self, z: Float[torch.Tensor, " b c *code_shape"]
+    ) -> Float[torch.Tensor, " b c *input_shape"]:
+        """Reconstruct the input tensor from the quantized tensor."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(
+        self, z_e: Float[torch.Tensor, " b c *input_shape"]
+    ) -> tuple[
+        Float[torch.Tensor, " b c *code_shape"],
+        Float[torch.Tensor, " b"],
+        Float[torch.Tensor, " b"],
+    ]:
+        """Performs a forward pass through the vector quantizer.
+        Args:
+            z_e: The input tensor to be quantized.
+        Returns:
+            z: The quantized tensor.
+            quantization_error: The error of the quantization.
+            codebook_usage: The fraction of codes used in the codebook.
+        """
+        raise NotImplementedError
diff --git a/aion/codecs/quantizers/scalar.py b/aion/codecs/quantizers/scalar.py
new file mode 100644
index 0000000..3834547
--- /dev/null
+++ b/aion/codecs/quantizers/scalar.py
@@ -0,0 +1,219 @@
+import math
+
+import torch
+from jaxtyping import Float, Integer
+
+from aion.codecs.quantizers.base import Quantizer
+
+
+class IdentityQuantizer(Quantizer):
+    """
+    Identity quantizer module.
+
+    The identity quantizer takes a batch of tensors and returns it unchanged.
+
+    Args:
+        codebook_size: int
+            The number of discrete labels in the codebook (0 to codebook_size - 1).
+    """
+
+    def __init__(self, codebook_size: int):
+        super().__init__()
+        self._codebook_size = codebook_size
+
+    def forward(
+        self, z_e: Float[torch.Tensor, " b c"]
+    ) -> tuple[
+        Float[torch.Tensor, " b c"],
+        Float[torch.Tensor, " b"],
+        Float[torch.Tensor, " b"],
+    ]:
+        """Performs a forward pass through the vector quantizer.
+        Args:
+            z_e: The input tensor to be quantized.
+        Returns:
+            z_q: The quantized tensor.
+            loss: The embedding loss for the quantization.
+            codebook_usage: The fraction of codes used in the codebook.
+        """
+        codebook_usage = z_e.unique().numel() / self._codebook_size
+        return self.quantize(z_e), torch.tensor(0), torch.tensor(codebook_usage)
+
+    def quantize(self, z: Float[torch.Tensor, " b c"]) -> Float[torch.Tensor, " b c"]:
+        """Quantize the input tensor z, returns the corresponding
+        codebook entry.
+        """
+        return z
+
+    def encode(self, z: Float[torch.Tensor, " b c"]) -> Integer[torch.Tensor, " b c"]:
+        return self.quantize(z)
+
+    def reconstruct(
+        self, codes: Float[torch.Tensor, " b c"]
+    ) -> Float[torch.Tensor, " b c"]:
+        """Decodes the input code index into the corresponding codebook entry of
+        dimension (embedding_dim).
+        """
+        return codes
+
+    @property
+    def codebook_size(self) -> int:
+        """Returns the size of the codebook."""
+        return self._codebook_size
+
+    @property
+    def codebook(self) -> Float[torch.Tensor, " c"]:
+        """Returns the codebook."""
+        return torch.arange(self._codebook_size)
+
+    @property
+    def embedding_dim(self) -> int:
+        """Returns the dimension of the codebook entries."""
+        return 1
+
+
+class FiniteScaleQuantizer(Quantizer):
+    def __init__(
+        self,
+        levels: list[int],
+        eps: float = 1e-3,
+    ):
+        """Finite scalar quantizer (FSQ) module
+        https://arxiv.org/pdf/2309.15505.pdf
+
+        Following the implementation from:
+        https://github.com/duchenzhuang/FSQ-pytorch/blob/main/quantizers/fsq.py
+
+        Args:
+            levels: list[int]
+                The number of levels for each dimension. The length of the list should
+                match the number of embedding dimensions.
+            eps: float
+                The epsilon value for the FSQ.
+ """ + super().__init__() + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("levels", _levels) + self._embedding_dim = len(levels) + self._basis = torch.cumprod( + torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32 + ) + self.eps = eps + + @property + def codebook_size(self): + return math.prod(self.levels) + + @property + def embedding_dim(self): + return self._embedding_dim + + def _bound( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Bound `z`, an array of shape (..., d).""" + half_l = (self.levels - 1) * (1 + self.eps) / 2 + offset = torch.where(self.levels % 2 == 1, 0.0, 0.5) + shift = torch.atanh(offset / half_l) + return torch.tanh(z + shift) * half_l - offset + + def _quantize( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Quantizes z, returns quantized codes zhat with the same shape as z. + Assumes last dimension of z is the embedding dimension. + """ + + def round_ste(z): + zhat = z.round() + return z + (zhat - z).detach() + + quantized = round_ste(self._bound(z)) + # Renormalize to [-1, 1]. + half_width = self.levels // 2 + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self.levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self.levels // 2 + return (zhat - half_width) / half_width + + def quantize( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Float[torch.Tensor, " b *c t"]: + """ + Quantizes the input tensor. + + Args: + z (Tensor): The input tensor to be quantized. + + Returns: + Tensor: The quantized tensor, same shape as input. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + return zhat.moveaxis(-1, 1) + + def encode( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Integer[torch.Tensor, " b *code"]: + """ + Encodes the input tensor `z` using quantization. + + Args: + z (Tensor): The input tensor to be encoded. + + Returns: + Tensor: integer code index. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis.to(zhat)).sum(axis=-1).to(torch.int32) + + def reconstruct( + self, codes: Integer[torch.Tensor, " b *code"] + ) -> Float[torch.Tensor, "b *c t"]: + """ + Decodes the given codes into the corresponding values. + + Args: + codes (Tensor): The codes to be decoded. + + Returns: + Tensor: The decoded tensor. + """ + indices = codes.unsqueeze(-1) + codes_non_centered = (indices // self._basis.to(indices)) % self.levels + zhat = self._scale_and_shift_inverse(codes_non_centered) + # Move the embedding dimension back to the second dimension + return zhat.moveaxis(-1, 1) + + def forward( + self, z_e: Float[torch.Tensor, " b t *codes"] + ) -> tuple[ + Float[torch.Tensor, " b t *shape"], + Float[torch.Tensor, ""], + Float[torch.Tensor, ""], + ]: + """ + Forward pass of the quantizer module. + + Args: + z_e: The input tensor. + + Returns: + tuple[Tensor, Tensor, Tensor]: A tuple containing: + - decoded (Tensor): The decoded tensor. + - loss (Tensor): In this case, no additional loss is necessary, so always returns 0. + - codebook_usage (Tensor): The ratio of unique codes used in the codebook. 
+        """
+        z_q = self.quantize(z_e)
+        codes = self.encode(z_e)
+        codebook_usage = len(torch.unique(codes)) / self.codebook_size
+        return z_q, torch.zeros([]), codebook_usage
diff --git a/aion/tokenizers.py b/aion/codecs/tokenizers/__init__.py
similarity index 100%
rename from aion/tokenizers.py
rename to aion/codecs/tokenizers/__init__.py
diff --git a/aion/codecs/tokenizers/base.py b/aion/codecs/tokenizers/base.py
new file mode 100644
index 0000000..a3194fa
--- /dev/null
+++ b/aion/codecs/tokenizers/base.py
@@ -0,0 +1,70 @@
+from abc import ABC, abstractmethod
+
+import torch
+from jaxtyping import Float
+
+from aion.codecs.quantizers import Quantizer
+
+
+class Codec(ABC):
+    """Abstract definition of a Codec.
+
+    A codec embeds a specific type of data into a sequence of either
+    discrete tokens or continuous embeddings, and then decodes it back.
+    """
+
+    @property
+    @abstractmethod
+    def modality(self) -> str:
+        """Returns the modality key that this codec can operate on."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _encode(
+        self, x: Float[torch.Tensor, " b c *input_shape"]
+    ) -> Float[torch.Tensor, " b c1 *code_shape"]:
+        """Function to be implemented by subclasses which
+        takes a batch of input samples and embeds it into a
+        latent space, before any quantization.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _decode(
+        self, z: Float[torch.Tensor, " b c1 *code_shape"]
+    ) -> Float[torch.Tensor, " b c *input_shape"]:
+        """Function to be implemented by subclasses which
+        takes a batch of latent space embeddings (after dequantization)
+        and decodes it into the original input space.
+        """
+        raise NotImplementedError
+
+    def encode(
+        self, x: Float[torch.Tensor, " b c *input_shape"]
+    ) -> Float[torch.Tensor, " b c1 *code_shape"]:
+        """Encodes a given batch of samples into latent space."""
+        return self._encode(x)
+
+    def decode(
+        self, z: Float[torch.Tensor, " b c1 *code_shape"]
+    ) -> Float[torch.Tensor, " b c *input_shape"]:
+        """Decodes a given batch of latent embeddings back into the input space."""
+        return self._decode(z)
+
+
+class QuantizedCodec(Codec):
+    def __init__(self, quantizer: Quantizer):
+        super().__init__()
+        self.quantizer = quantizer
+
+    def decode(
+        self, z: Float[torch.Tensor, " b c1 *code_shape"]
+    ) -> Float[torch.Tensor, " b c *input_shape"]:
+        z = self.quantizer.reconstruct(z)
+        return super().decode(z)
+
+    def encode(
+        self, x: Float[torch.Tensor, " b c *input_shape"]
+    ) -> Float[torch.Tensor, " b c1 *code_shape"]:
+        embedding = super().encode(x)
+        return self.quantizer.encode(embedding)
diff --git a/aion/codecs/tokenizers/image.py b/aion/codecs/tokenizers/image.py
new file mode 100644
index 0000000..abae50b
--- /dev/null
+++ b/aion/codecs/tokenizers/image.py
@@ -0,0 +1,133 @@
+import torch
+from jaxtyping import Float
+
+from aion.codecs.modules.magvit import MagVitAE
+from aion.codecs.modules.subsampler import SubsampledLinear
+from aion.codecs.quantizers import Quantizer
+from aion.codecs.tokenizers.base import QuantizedCodec
+from aion.codecs.utils import range_compression, reverse_range_compression
+
+
+class AutoencoderImageCodec(QuantizedCodec):
+    """Base class for autoencoder image codecs; the encoder and decoder networks are supplied by subclasses."""
+
+    def __init__(
+        self,
+        n_bands: int,
+        quantizer: Quantizer,
+        encoder: torch.nn.Module,
+        decoder: torch.nn.Module,
+        hidden_dims: int = 64,
+        embedding_dim: int = 5,
+        multisurvey_projection_dims: int = 54,
+        range_compression_factor: float = 0.01,
+        mult_factor: float = 10.0,
+    ):
+        super().__init__(quantizer)
+        self.range_compression_factor = range_compression_factor
+        self.mult_factor = mult_factor
+        self.n_bands = n_bands
+        self.encoder = encoder
+        self.decoder = decoder
+
+        self.subsample_in = SubsampledLinear(
+            dim_in=n_bands, dim_out=multisurvey_projection_dims, subsample_in=True
+        )
+        self.subsample_out = SubsampledLinear(
+            dim_in=multisurvey_projection_dims, dim_out=n_bands, subsample_in=False
+        )
+        # Go down to size of levels
+        self.pre_quant_proj = torch.nn.Conv2d(
+            hidden_dims, embedding_dim, kernel_size=1, stride=1, padding=0
+        )
+
+        # Go back to the original size
+        self.post_quant_proj = torch.nn.Conv2d(
+            embedding_dim, hidden_dims, kernel_size=1, stride=1, padding=0
+        )
+
+    @property
+    def modality(self) -> str:
+        return "image"
+
+    def _preprocess_sample(self, x):
+        x = range_compression(x, self.range_compression_factor)
+        x = x * self.mult_factor
+        return x
+
+    def _postprocess_sample(self, x):
+        x = x / self.mult_factor
+        x = reverse_range_compression(x, self.range_compression_factor)
+        return x
+
+    def _encode(
+        self, x: Float[torch.Tensor, " b {self.n_bands} w h"]
+    ) -> Float[torch.Tensor, " b c1 w1 h1"]:
+        x = self._preprocess_sample(x)
+        batch_size = x.shape[0]
+        channel_mask = torch.ones((batch_size, self.n_bands), device=x.device)
+        x = self.subsample_in(x, channel_mask)
+        h = self.encoder(x)
+        h = self.pre_quant_proj(h)
+        return h
+
+    def _decode(
+        self,
+        z: Float[torch.Tensor, " b c1 w1 h1"],
+    ) -> Float[torch.Tensor, " b {self.n_bands} w h"]:
+        # Decode the image
+        h = self.post_quant_proj(z)
+        dec = self.decoder(h)
+        batch_size = z.shape[0]
+        channel_mask = torch.ones((batch_size, self.n_bands), device=z.device)
+        dec = self.subsample_out(dec, channel_mask)
+        # Undo range compression if necessary
+        dec = self._postprocess_sample(dec)
+
+        return dec
+
+
+class MagViTAEImageCodec(AutoencoderImageCodec):
+    def __init__(
+        self,
+        n_bands: int,
+        quantizer: Quantizer,
+        hidden_dims: int = 512,
+        multisurvey_projection_dims: int = 54,
+        n_compressions: int = 2,  # Number of compressions in the network
+        num_consecutive: int = 4,  # Number of consecutive residual layers per compression
+        embedding_dim: int = 5,
+        range_compression_factor: float = 0.01,
+        mult_factor: float = 10.0,
+    ):
+        """
+        MagViT Autoencoder for images.
+
+        Args:
+            n_bands: Number of bands in the input images.
+            quantizer: Quantizer to use.
+            hidden_dims: Number of hidden dimensions in the network.
+            n_compressions: Number of compressions in the network.
+            num_consecutive: Number of consecutive residual layers per compression.
+            embedding_dim: Dimension of the latent space.
+            range_compression_factor: Range compression factor.
+            mult_factor: Multiplication factor.
+        """
+        # Get MagViT architecture
+        self.model = MagVitAE(
+            n_bands=multisurvey_projection_dims,
+            hidden_dims=hidden_dims,
+            n_compressions=n_compressions,
+            num_consecutive=num_consecutive,
+        )
+        super().__init__(
+            n_bands,
+            quantizer,
+            self.model.encode,
+            self.model.decode,
+            hidden_dims,
+            embedding_dim,
+            multisurvey_projection_dims,
+            range_compression_factor,
+            mult_factor,
+        )
diff --git a/aion/codecs/tokenizers/scalar.py b/aion/codecs/tokenizers/scalar.py
new file mode 100644
index 0000000..20467c2
--- /dev/null
+++ b/aion/codecs/tokenizers/scalar.py
@@ -0,0 +1,26 @@
+from aion.codecs.tokenizers.base import QuantizedCodec
+from aion.codecs.quantizers import Quantizer
+
+from jaxtyping import Float
+import torch
+
+
+class ScalarIdentityCodec(QuantizedCodec):
+    """Codec for scalar quantities.
+ + A codec that embeds scalar quantities through an identity mapping. + + """ + + def __init__(self, quantizer: Quantizer): + super().__init__(quantizer) + + @property + def modality(self): + return "label" + + def _encode(self, x: Float[torch.Tensor, " b t"]) -> Float[torch.Tensor, " b t"]: + return x + + def _decode(self, z: Float[torch.Tensor, " b c"]) -> Float[torch.Tensor, " b c"]: + return z diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py new file mode 100644 index 0000000..cd9a04f --- /dev/null +++ b/aion/codecs/utils.py @@ -0,0 +1,15 @@ +import torch + + +def range_compression( + sample: torch.Tensor, div_factor: float | int = 0.01 +) -> torch.Tensor: + """Applies arcsinh compression on each band of the input.""" + return torch.arcsinh(sample / div_factor) * div_factor + + +def reverse_range_compression( + sample: torch.Tensor, div_factor: float | int = 0.01 +) -> torch.Tensor: + """Undoes arcsinh compression on each band of the input.""" + return torch.sinh(sample / div_factor) * div_factor diff --git a/pyproject.toml b/pyproject.toml index 0042f07..06c8f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,5 +31,7 @@ dev = [ "ruff", ] -[tool.setuptools] -packages = ["aion", "aion.fourm"] +[tool.ruff.lint] +# Ignore space in shape notation for jaxtyping +# See https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error +ignore = ["F722"] diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py new file mode 100644 index 0000000..8b8d509 --- /dev/null +++ b/tests/tokenizers/test_image_tokenizer.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from aion.codecs.tokenizers.image import MagViTAEImageCodec +from aion.codecs.quantizers import FiniteScaleQuantizer + + +@pytest.mark.parametrize("n_bands", [3, 10]) +@pytest.mark.parametrize("embedding_dim", [5, 10]) +@pytest.mark.parametrize("multisurvey_projection_dims", [12, 24]) +@pytest.mark.parametrize("hidden_dims", [8, 16]) +def test_magvit_image_tokenizer( + n_bands, embedding_dim, multisurvey_projection_dims, hidden_dims +): + tokenizer = MagViTAEImageCodec( + n_bands=n_bands, + quantizer=FiniteScaleQuantizer(levels=[1] * embedding_dim), + hidden_dims=hidden_dims, + multisurvey_projection_dims=multisurvey_projection_dims, + n_compressions=2, + num_consecutive=4, + embedding_dim=embedding_dim, + range_compression_factor=0.01, + mult_factor=10, + ) + batch_size = 4 + random_input = torch.randn(batch_size, n_bands, 96, 96) + encoded = tokenizer.encode(random_input) + assert encoded.shape == (batch_size, 24, 24) + decoded = tokenizer.decode(encoded) + assert decoded.shape == random_input.shape diff --git a/tests/tokenizers/test_scalar_tokenizer.py b/tests/tokenizers/test_scalar_tokenizer.py new file mode 100644 index 0000000..9e0055e --- /dev/null +++ b/tests/tokenizers/test_scalar_tokenizer.py @@ -0,0 +1,16 @@ +import pytest +import torch + +from aion.codecs.quantizers import IdentityQuantizer +from aion.codecs.tokenizers.scalar import ScalarIdentityCodec + + +@pytest.mark.parametrize("codebook_size", [10, 20]) +@pytest.mark.parametrize("embedding_dim", [1, 4]) +def test_scalar_identity_codec(codebook_size, embedding_dim): + quantizer = IdentityQuantizer(codebook_size=codebook_size) + codec = ScalarIdentityCodec(quantizer=quantizer) + x = torch.randint(0, codebook_size, (64, embedding_dim)) + z = codec.encode(x) + assert z.shape == x.shape + assert torch.allclose(z, x)
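
For reference, a minimal usage sketch of how the new pieces compose, following tests/tokenizers/test_image_tokenizer.py; the quantization levels and sizes below are illustrative choices, not defaults introduced by this change:

    import torch
    from aion.codecs.quantizers import FiniteScaleQuantizer
    from aion.codecs.tokenizers.image import MagViTAEImageCodec

    # 5-dimensional latent with 8 quantization levels per dimension (illustrative values).
    quantizer = FiniteScaleQuantizer(levels=[8] * 5)
    codec = MagViTAEImageCodec(
        n_bands=3,
        quantizer=quantizer,
        hidden_dims=16,
        multisurvey_projection_dims=12,
        embedding_dim=5,
    )
    images = torch.randn(4, 3, 96, 96)  # (batch, bands, height, width)
    codes = codec.encode(images)        # integer codes of shape (4, 24, 24)
    recon = codec.decode(codes)         # reconstruction of shape (4, 3, 96, 96)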