From 83b558291314e4513ac61ef51f73d33832d73a1a Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 12:04:21 +0200 Subject: [PATCH 01/20] Add base class for quantizers --- aion/quantizers/__init__.py | 40 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 5 +++++ 2 files changed, 45 insertions(+) create mode 100644 aion/quantizers/__init__.py diff --git a/aion/quantizers/__init__.py b/aion/quantizers/__init__.py new file mode 100644 index 0000000..7598ef8 --- /dev/null +++ b/aion/quantizers/__init__.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +import torch + +from jaxtyping import Float + + +class Quantizer(torch.nn.Module, ABC): + """Abstract interface for all quantizer modules.""" + + @abstractmethod + def quantize( + self, x: Float[torch.Tensor, " b c1 *input_shape"] + ) -> Float[torch.Tensor, " b c *code_shape"]: + """Quantize the input tensor.""" + raise NotImplementedError + + @abstractmethod + def reconstruct( + self, z: Float[torch.Tensor, " b c *code_shape"] + ) -> Float[torch.Tensor, " b c *input_shape"]: + """Reconstruct the input tensor from the quantized tensor.""" + raise NotImplementedError + + @abstractmethod + def forward( + self, z_e: Float[torch.Tensor, " b c *input_shape"] + ) -> tuple[ + Float[torch.Tensor, " b c *code_shape"], + Float[torch.Tensor, " b"], + Float[torch.Tensor, " b"], + ]: + """Performs a forward pass through the vector quantizer. + Args: + x: The input tensor to be quantized. + Returns: + z: The quantized tensor. + quantization_error: The error of the quantization. + codebook_usage: The fraction of codes used in the codebook. + """ + raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index 0042f07..c61bcd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,8 @@ dev = [ [tool.setuptools] packages = ["aion", "aion.fourm"] + +[tool.ruff.lint] +# Ignore space in shape notation for jaxtyping +# See https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error +ignore = ["F722"] From 23c75bf3f342d24d25e778a8b7e6b48d42d78ed3 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 12:04:57 +0200 Subject: [PATCH 02/20] Add base class for tokenizers --- aion/tokenizers/__init__.py | 0 aion/tokenizers/base.py | 70 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 aion/tokenizers/__init__.py create mode 100644 aion/tokenizers/base.py diff --git a/aion/tokenizers/__init__.py b/aion/tokenizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aion/tokenizers/base.py b/aion/tokenizers/base.py new file mode 100644 index 0000000..f96b44c --- /dev/null +++ b/aion/tokenizers/base.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod + +import torch +from jaxtyping import Float + +from aion.quantizers import Quantizer + + +class Codec(ABC): + """Abstract definition of a Codec. + + A codec embeds a specific type of data into a sequence of either + discrete tokens or continuous embedddings, and then decode it back. + """ + + @property + @abstractmethod + def modality(self) -> str: + """Returns the modality key that this codec can operate on.""" + raise NotImplementedError + + @abstractmethod + def _encode( + self, x: Float[torch.Tensor, " b c *input_shape"] + ) -> Float[torch.Tensor, " b c1 *code_shape"]: + """Function to be implemented by subclasses which + takes a batch of input samples and embedds it into a + latent space, before any quantization. + """ + raise NotImplementedError + + @abstractmethod + def _decode( + self, z: Float[torch.Tensor, " b c1 *code_shape"] + ) -> Float[torch.Tensor, " b c *input_shape"]: + """Function to be implemented by subclasses which + takes a batch of latent space embeddings (after dequantization) + and decodes it into the original input space. + """ + raise NotImplementedError + + def encode( + self, x: Float[torch.Tensor, " b c *input_shape"] + ) -> Float[torch.Tensor, " b c1 *code_shape"]: + """Encodes a given batch of samples into latent space.""" + return self._encode(x) + + def decode( + self, z: Float[torch.Tensor, " b c1 *code_shape"] + ) -> Float[torch.Tensor, " b c *input_shape"]: + """Encodes a given batch of samples into latent space.""" + return self._decode(z) + + +class QuantizedCodec(Codec): + def __init__(self, quantizer: Quantizer): + super().__init__() + self.quantizer = quantizer + + def decode( + self, z: Float[torch.Tensor, " b c1 *code_shape"] + ) -> Float[torch.Tensor, " b c *input_shape"]: + z = self.quantizer.reconstruct(z) + return super().decode(z) + + def encode( + self, x: Float[torch.Tensor, " b c *input_shape"] + ) -> Float[torch.Tensor, " b c1 *code_shape"]: + embedding = super().encode(x) + return self.quantizer.quantize(embedding) From b250bf4ae0eebfbd149f3413ba5b12cd6a50b247 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:54:14 +0200 Subject: [PATCH 03/20] Change arborescence for tokenizers --- aion/tokenizers.py | 8 -------- aion/tokenizers/__init__.py | 8 ++++++++ 2 files changed, 8 insertions(+), 8 deletions(-) delete mode 100644 aion/tokenizers.py diff --git a/aion/tokenizers.py b/aion/tokenizers.py deleted file mode 100644 index 81a2226..0000000 --- a/aion/tokenizers.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from torch import package - - -def load_tokenizer(path: str, device: str = "cpu") -> torch.nn.Module: - importer = package.PackageImporter(path) - model = importer.load_pickle("network", "network.pkl", map_location=device) - return model diff --git a/aion/tokenizers/__init__.py b/aion/tokenizers/__init__.py index e69de29..81a2226 100644 --- a/aion/tokenizers/__init__.py +++ b/aion/tokenizers/__init__.py @@ -0,0 +1,8 @@ +import torch +from torch import package + + +def load_tokenizer(path: str, device: str = "cpu") -> torch.nn.Module: + importer = package.PackageImporter(path) + model = importer.load_pickle("network", "network.pkl", map_location=device) + return model From dc75b8c7441094f10136cf16363ddf8d9906b250 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:54:38 +0200 Subject: [PATCH 04/20] Add utils functions --- aion/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 aion/utils.py diff --git a/aion/utils.py b/aion/utils.py new file mode 100644 index 0000000..4e60367 --- /dev/null +++ b/aion/utils.py @@ -0,0 +1,15 @@ +import torch + + +def range_compression( + sample: torch.Tensor, div_factor: float | int = 0.01 +) -> torch.Tensor: + """Applies arcsinh compression on each band of the input.""" + return torch.arcsinh(sample / div_factor) * div_factor + + +def reverse_range_compression( + sample: torch.Tensor, div_factor: float | int = 0.01 +) -> torch.Tensor: + """Undos arcsinh compression on each band of the input.""" + return torch.sinh(sample / div_factor) * div_factor From 83898235279059a4f3f619fd7d3911a412fcced4 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:55:02 +0200 Subject: [PATCH 05/20] Add MagVitAE model --- aion/modules/__init__.py | 0 aion/modules/magvit.py | 214 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 aion/modules/__init__.py create mode 100644 aion/modules/magvit.py diff --git a/aion/modules/__init__.py b/aion/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aion/modules/magvit.py b/aion/modules/magvit.py new file mode 100644 index 0000000..83ce352 --- /dev/null +++ b/aion/modules/magvit.py @@ -0,0 +1,214 @@ +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class SameConv2d(torch.nn.Module): + def __init__(self, dim_in, dim_out, kernel_size): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + padding = [k // 2 for k in kernel_size] + self.conv = torch.nn.Conv2d( + dim_in, dim_out, kernel_size=kernel_size, padding=padding + ) + + def forward(self, x: torch.Tensor): + return self.conv(x) + + +class SqueezeExcite(torch.nn.Module): + # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375) + + def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + + self.to_k = torch.nn.Conv2d(dim, 1, 1) + dim_hidden = max(dim_hidden_min, dim_out // 2) + + self.net = torch.nn.Sequential( + torch.nn.Conv2d(dim, dim_hidden, 1), + torch.nn.LeakyReLU(0.1), + torch.nn.Conv2d(dim_hidden, dim_out, 1), + torch.nn.Sigmoid(), + ) + + torch.nn.init.zeros_(self.net[-2].weight) + torch.nn.init.constant_(self.net[-2].bias, init_bias) + + def forward(self, x): + context = self.to_k(x) + + context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1) + spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)") + + out = torch.einsum("b i n, b c n -> b c i", context, spatial_flattened_input) + out = rearrange(out, "... -> ... 1") + gates = self.net(out) + + return gates * x + + +class ResidualUnit(torch.nn.Module): + def __init__(self, dim: int, kernel_size: int | tuple[int, int, int]): + super().__init__() + self.net = torch.nn.Sequential( + SameConv2d(dim, dim, kernel_size), + torch.nn.ELU(), + torch.nn.Conv2d(dim, dim, 1), + torch.nn.ELU(), + SqueezeExcite(dim), + ) + + def forward(self, x: torch.Tensor): + return self.net(x) + x + + +class SpatialDownsample2x(torch.nn.Module): + def __init__( + self, + dim: int, + dim_out: int = None, + kernel_size: int = 3, + ): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + self.conv = torch.nn.Conv2d( + dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2 + ) + + def forward(self, x: torch.Tensor): + out = self.conv(x) + return out + + +class SpatialUpsample2x(torch.nn.Module): + def __init__(self, dim: int, dim_out: int = None): + super().__init__() + dim_out = dim_out if dim_out is not None else dim + conv = torch.nn.Conv2d(dim, dim_out * 4, 1) + + self.net = torch.nn.Sequential( + conv, + torch.nn.SiLU(), + Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2), + ) + + self.init_conv_(conv) + + def init_conv_(self, conv: torch.nn.Module): + o, i, h, w = conv.weight.shape + conv_weight = torch.empty(o // 4, i, h, w) + torch.nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... -> (o 4) ...") + + conv.weight.data.copy_(conv_weight) + torch.nn.init.zeros_(conv.bias.data) + + def forward(self, x: torch.Tensor): + out = self.net(x) + return out + + +class MagVitAE(torch.nn.Module): + """MagViTAE implementation from Yu, et al. (2024), adapted for Pytorch. + Code borrowed from https://github.com/lucidrains/magvit2-pytorch, and adapted for images. + """ + + def __init__( + self, + n_bands: int = 3, + hidden_dims: int = 512, + residual_conv_kernel_size: int = 3, + n_compressions: int = 2, + num_consecutive: int = 2, + ): + super().__init__() + + self.encoder_layers = torch.nn.ModuleList([]) + self.decoder_layers = torch.nn.ModuleList([]) + init_dim = int(hidden_dims / 2**n_compressions) + dim = init_dim + + self.conv_in = SameConv2d(n_bands, init_dim, 7) + self.conv_out = SameConv2d(init_dim, n_bands, 3) + + # Residual layers + encoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + decoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # Compressions + for i in range(n_compressions): + dim_out = dim * 2 + encoder_layer = SpatialDownsample2x(dim, dim_out) + decoder_layer = SpatialUpsample2x(dim_out, dim) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + dim = dim_out + + # Consecutive residual layers + encoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + decoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # Add a final non-compress layer + dim_out = dim + encoder_layer = SameConv2d(dim, dim_out, 7) + decoder_layer = SameConv2d(dim_out, dim, 3) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + dim = dim_out + + # Consecutive residual layers + encoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + decoder_layer = torch.nn.Sequential( + *[ + ResidualUnit(dim, residual_conv_kernel_size) + for _ in range(num_consecutive) + ] + ) + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + # add a final norm just before quantization layer + self.encoder_layers.append( + torch.nn.Sequential( + Rearrange("b c ... -> b ... c"), + torch.nn.LayerNorm(dim), + Rearrange("b ... c -> b c ..."), + ) + ) + + def encode(self, x: torch.Tensor): + x = self.conv_in(x) + for layer in self.encoder_layers: + x = layer(x) + return x + + def decode(self, x: torch.Tensor): + for layer in self.decoder_layers: + x = layer(x) + x = self.conv_out(x) + return x From b51e9636bc4b81a6fe119f7c314567922c6e9d1d Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:55:21 +0200 Subject: [PATCH 06/20] Add subsampler module --- aion/modules/subsampler.py | 60 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 aion/modules/subsampler.py diff --git a/aion/modules/subsampler.py b/aion/modules/subsampler.py new file mode 100644 index 0000000..af66f7a --- /dev/null +++ b/aion/modules/subsampler.py @@ -0,0 +1,60 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Bool, Float + + +class SubsampledLinear(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int, subsample_in: bool = True): + """ + Subsampled linear layer for the encoder. + It takes in a zero-padded tensor and a mask. + It projects the tensor into some shared projection space. + It can also be used to reverse out of the space with the mask. + + Args: + dim_in : Number of total possible bands. + dim_out : Number of embedding dimensions. + subsample_in : Whether to subsample the input. Defaults to True. + """ + super().__init__() + self.subsample_in = subsample_in + self.dim_in = dim_in # Number of total possible bands + self.dim_out = dim_out # Number of embedding dimensions + temp_linear = torch.nn.Linear(dim_in, dim_out) + self.weight = torch.nn.Parameter(temp_linear.weight) + self.bias = torch.nn.Parameter(temp_linear.bias) + + def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]): + # Get mask + mask = labels[:, None, None, :].float() + x = x * mask + + # Normalize + label_sizes = labels.sum(dim=1, keepdim=True) + scales = ((self.dim_in / label_sizes) ** 0.5).squeeze() + + # Apply linear layer + return scales[:, None, None, None] * F.linear(x, self.weight, self.bias) + + def _subsample_out(self, x, labels): + # Get mask + mask = labels[:, None, None, :].float() + + # Apply linear layer and mask + return F.linear(x, self.weight, self.bias) * mask + + def forward( + self, x: Float[torch.Tensor, " b c h w"], labels: Bool[torch.Tensor, " b c"] + ) -> Float[torch.Tensor, " b c h w"]: + x = rearrange(x, "b c h w -> b h w c") + + if self.subsample_in: + x = self._subsample_in(x, labels) + + else: + x = self._subsample_out(x, labels) + + x = rearrange(x, "b h w c -> b c h w") + + return x From 57926abebf954391b673c5c40815dc29cd25e189 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:56:26 +0200 Subject: [PATCH 07/20] Add FiniteScale quantizer --- aion/quantizers/__init__.py | 152 +++++++++++++++++++++++++++++++++++- 1 file changed, 150 insertions(+), 2 deletions(-) diff --git a/aion/quantizers/__init__.py b/aion/quantizers/__init__.py index 7598ef8..18b6102 100644 --- a/aion/quantizers/__init__.py +++ b/aion/quantizers/__init__.py @@ -1,7 +1,8 @@ +import math from abc import ABC, abstractmethod -import torch -from jaxtyping import Float +import torch +from jaxtyping import Float, Integer class Quantizer(torch.nn.Module, ABC): @@ -38,3 +39,150 @@ def forward( codebook_usage: The fraction of codes used in the codebook. """ raise NotImplementedError + + +class FiniteScaleQuantizer(Quantizer): + def __init__( + self, + levels: list[int], + eps: float = 1e-3, + ): + """Finite scalar quantizer (FSQ) module + https://arxiv.org/pdf/2309.15505.pdf + + Following the implementation from: + https://github.com/duchenzhuang/FSQ-pytorch/blob/main/quantizers/fsq.py + + Args: + levels: list[int] + The number of levels for each dimension. Length of the list should match + the number of embedding dimensions. + eps: float + The epsilon value for the FSQ. + """ + super().__init__() + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("levels", _levels) + self._embedding_dim = len(levels) + self._basis = torch.cumprod( + torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32 + ) + self.eps = eps + + @property + def codebook_size(self): + return math.prod(self.levels) + + @property + def embedding_dim(self): + return self._embedding_dim + + def _bound( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Bound `z`, an array of shape (..., d).""" + half_l = (self.levels - 1) * (1 + self.eps) / 2 + offset = torch.where(self.levels % 2 == 1, 0.0, 0.5) + shift = torch.atanh(offset / half_l) + return torch.tanh(z + shift) * half_l - offset + + def _quantize( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Quantizes z, returns quantized codes zhat with the same shape as z. + Assumes last dimension of z is the embedding dimension. + """ + + def round_ste(z): + zhat = z.round() + return z + (zhat - z).detach() + + quantized = round_ste(self._bound(z)) + # Renormalize to [-1, 1]. + half_width = self.levels // 2 + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self.levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self.levels // 2 + return (zhat - half_width) / half_width + + def quantize( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Float[torch.Tensor, " b *c t"]: + """ + Quantizes the input tensor. + + Args: + z (Tensor): The input tensor to be quantized. + + Returns: + Tensor: The quantized tensor, same shape as input. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + return zhat.moveaxis(-1, 1) + + def encode( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Integer[torch.Tensor, " b *code"]: + """ + Encodes the input tensor `z` using quantization. + + Args: + z (Tensor): The input tensor to be encoded. + + Returns: + Tensor: integer code index. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis.to(zhat)).sum(axis=-1).to(torch.int32) + + def reconstruct( + self, codes: Integer[torch.Tensor, " b *code"] + ) -> Float[torch.Tensor, "b *c t"]: + """ + Decodes the given codes into the corresponding values. + + Args: + codes (Tensor): The codes to be decoded. + + Returns: + Tensor: The decoded tensor. + """ + indices = codes.unsqueeze(-1) + codes_non_centered = (indices // self._basis.to(indices)) % self.levels + zhat = self._scale_and_shift_inverse(codes_non_centered) + # Move the embedding dimension back to the second dimension + return zhat.moveaxis(-1, 1) + + def forward( + self, z_e: Float[torch.Tensor, " b t *codes"] + ) -> tuple[ + Float[torch.Tensor, " b t *shape"], + Float[torch.Tensor, ""], + Float[torch.Tensor, ""], + ]: + """ + Forward pass of the quantizer module. + + Args: + z_e: The input tensor. + + Returns: + tuple[Tensor, Tensor, Tensor]: A tuple containing: + - decoded (Tensor): The decoded tensor. + - loss (Tensor): In this case, no additional loss is necessary, so always returns 0. + - codebook_usage (Tensor): The ratio of unique codes used in the codebook. + """ + z_q = self.quantize(z_e) + codes = self.encode(z_e) + codebook_usage = len(torch.unique(codes)) / self.codebook_size + return z_q, torch.zeros([]), codebook_usage From 2733ce7ff31f2b8c12c15904fa60a64a3e3119ab Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:56:48 +0200 Subject: [PATCH 08/20] Use quantizer encode instead of quantize TODO: Clarify the difference between encode and quantize --- aion/tokenizers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aion/tokenizers/base.py b/aion/tokenizers/base.py index f96b44c..9230917 100644 --- a/aion/tokenizers/base.py +++ b/aion/tokenizers/base.py @@ -67,4 +67,4 @@ def encode( self, x: Float[torch.Tensor, " b c *input_shape"] ) -> Float[torch.Tensor, " b c1 *code_shape"]: embedding = super().encode(x) - return self.quantizer.quantize(embedding) + return self.quantizer.encode(embedding) From c190989b4d69e31b44c0af166a424b352245f8af Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 21:57:21 +0200 Subject: [PATCH 09/20] Add MagVitAE image tokenizer --- aion/tokenizers/image.py | 133 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 aion/tokenizers/image.py diff --git a/aion/tokenizers/image.py b/aion/tokenizers/image.py new file mode 100644 index 0000000..b9c28b9 --- /dev/null +++ b/aion/tokenizers/image.py @@ -0,0 +1,133 @@ +import torch +from jaxtyping import Float + +from aion.modules.magvit import MagVitAE +from aion.modules.subsampler import SubsampledLinear +from aion.quantizers import Quantizer +from aion.tokenizers.base import QuantizedCodec +from aion.utils import range_compression, reverse_range_compression + + +class AutoencoderImageCodec(QuantizedCodec): + """Meta-class for autoencoder codecs for images, does not actually contain a network.""" + + def __init__( + self, + n_bands: int, + quantizer: Quantizer, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + hidden_dims: int = 64, + embedding_dim: int = 5, + multisurvey_projection_dims: int = 54, + range_compression_factor: float = 0.01, + mult_factor: float = 10.0, + ): + super().__init__(quantizer) + self.range_compression_factor = range_compression_factor + self.mult_factor = mult_factor + self.n_bands = n_bands + self.encoder = encoder + self.decoder = decoder + + self.subsample_in = SubsampledLinear( + dim_in=n_bands, dim_out=multisurvey_projection_dims, subsample_in=True + ) + self.subsample_out = SubsampledLinear( + dim_in=multisurvey_projection_dims, dim_out=n_bands, subsample_in=False + ) + # Go down to size of levels + self.pre_quant_proj = torch.nn.Conv2d( + hidden_dims, embedding_dim, kernel_size=1, stride=1, padding=0 + ) + + # Go back to the original size + self.post_quant_proj = torch.nn.Conv2d( + embedding_dim, hidden_dims, kernel_size=1, stride=1, padding=0 + ) + + @property + def modality(self) -> str: + return "image" + + def _preprocess_sample(self, x): + x = range_compression(x, self.range_compression_factor) + x = x * self.mult_factor + return x + + def _postprocess_sample(self, x): + x = x / self.mult_factor + x = reverse_range_compression(x, self.range_compression_factor) + return x + + def _encode( + self, x: Float[torch.Tensor, " b {self.n_bands} w h"] + ) -> Float[torch.Tensor, " b c1 w1 h1"]: + x = self._preprocess_sample(x) + batch_size = x.shape[0] + channel_mask = torch.zeros((batch_size, self.n_bands), device=x.device) + x = self.subsample_in(x, channel_mask) + h = self.encoder(x) + h = self.pre_quant_proj(h) + return h + + def _decode( + self, + z: Float[torch.Tensor, " b c1 w1 h1"], + ) -> Float[torch.Tensor, " b {self.n_bands} w h"]: + # Decode the image + h = self.post_quant_proj(z) + dec = self.decoder(h) + batch_size = z.shape[0] + channel_mask = torch.ones((batch_size, self.n_bands), device=z.device) + dec = self.subsample_out(dec, channel_mask) + # Undo range compression if necessary + dec = self._postprocess_sample(dec) + + return dec + + +class MagViTAEImageCodec(AutoencoderImageCodec): + def __init__( + self, + n_bands: int, + quantizer: Quantizer, + hidden_dims: int = 512, + multisurvey_projection_dims: int = 54, + n_compressions: int = 2, # Number of compressions in the network + num_consecutive: int = 4, # Number of consecutive residual layers per compression + embedding_dim: int = 5, + range_compression_factor: float = 0.01, + mult_factor: float = 10.0, + ): + """ + MagViT Autoencoder for images. + + Args: + n_bands: Number of bands in the input images. + quantizer: Quantizer to use. + hidden_dims: Number of hidden dimensions in the network. + n_compressions: Number of compressions in the network. + num_consecutive: Number of consecutive residual layers per compression. + embedding_dim: Dimension of the latent space. + range_compression_factor: Range compression factor. + mult_factor: Multiplication factor. + """ + # Get MagViT architecture + self.model = MagVitAE( + n_bands=multisurvey_projection_dims, + hidden_dims=hidden_dims, + n_compressions=n_compressions, + num_consecutive=num_consecutive, + ) + super().__init__( + n_bands, + quantizer, + self.model.encode, + self.model.decode, + hidden_dims, + embedding_dim, + multisurvey_projection_dims, + range_compression_factor, + mult_factor, + ) From cd1f705c6a7232359e23b91b39a1b49808fed270 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 22:08:53 +0200 Subject: [PATCH 10/20] Add test for MagVitAE image tokenizer --- tests/tokenizers/test_image_tokenizer.py | 31 ++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/tokenizers/test_image_tokenizer.py diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py new file mode 100644 index 0000000..41f50a8 --- /dev/null +++ b/tests/tokenizers/test_image_tokenizer.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from aion.quantizers import FiniteScaleQuantizer +from aion.tokenizers.image import MagViTAEImageCodec + + +@pytest.mark.parametrize("n_bands", [3, 10]) +@pytest.mark.parametrize("embedding_dim", [5, 10]) +@pytest.mark.parametrize("multisurvey_projection_dims", [12, 24]) +@pytest.mark.parametrize("hidden_dims", [8, 16]) +def test_magvit_image_tokenizer( + n_bands, embedding_dim, multisurvey_projection_dims, hidden_dims +): + tokenizer = MagViTAEImageCodec( + n_bands=n_bands, + quantizer=FiniteScaleQuantizer(levels=[1] * embedding_dim), + hidden_dims=hidden_dims, + multisurvey_projection_dims=multisurvey_projection_dims, + n_compressions=2, + num_consecutive=4, + embedding_dim=embedding_dim, + range_compression_factor=0.01, + mult_factor=10, + ) + batch_size = 4 + random_input = torch.randn(batch_size, n_bands, 96, 96) + encoded = tokenizer.encode(random_input) + assert encoded.shape == (batch_size, 24, 24) + decoded = tokenizer.decode(encoded) + assert decoded.shape == random_input.shape From 71ea594c5058bdfd419d852292f5d52fd3312e9f Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 22:09:53 +0200 Subject: [PATCH 11/20] Add tests to CI --- .github/workflows/build.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8edf8c6..105baaf 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -25,3 +25,7 @@ jobs: run: | python -m pip install --upgrade pip pip install . + - name: Run tests + run: | + pip install ".dev" + pytest tests From 2f531e0f7e6b26c4a12b19fdbe8aa36e3911b8bb Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 22:27:21 +0200 Subject: [PATCH 12/20] Fix github CI --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 105baaf..12d9b0c 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -27,5 +27,5 @@ jobs: pip install . - name: Run tests run: | - pip install ".dev" + pip install ".[dev]" pytest tests From c22e5fb99fb3d124e150d083d221122d612d6545 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Mon, 7 Apr 2025 22:33:47 +0200 Subject: [PATCH 13/20] Remove unnecessary package listing to fix tests --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c61bcd8..06c8f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,6 @@ dev = [ "ruff", ] -[tool.setuptools] -packages = ["aion", "aion.fourm"] - [tool.ruff.lint] # Ignore space in shape notation for jaxtyping # See https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error From 7ec31f92ec4427f56986b29a5ee5526509dea9c9 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Tue, 8 Apr 2025 11:10:40 -0400 Subject: [PATCH 14/20] Update aion/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- aion/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aion/utils.py b/aion/utils.py index 4e60367..cd9a04f 100644 --- a/aion/utils.py +++ b/aion/utils.py @@ -11,5 +11,5 @@ def range_compression( def reverse_range_compression( sample: torch.Tensor, div_factor: float | int = 0.01 ) -> torch.Tensor: - """Undos arcsinh compression on each band of the input.""" + """Undoes arcsinh compression on each band of the input.""" return torch.sinh(sample / div_factor) * div_factor From f4c3bfff6d5e0ce4b1884d0083aab2fe8654583b Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Tue, 8 Apr 2025 11:12:38 -0400 Subject: [PATCH 15/20] Update aion/tokenizers/base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- aion/tokenizers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aion/tokenizers/base.py b/aion/tokenizers/base.py index 9230917..f870c64 100644 --- a/aion/tokenizers/base.py +++ b/aion/tokenizers/base.py @@ -10,7 +10,7 @@ class Codec(ABC): """Abstract definition of a Codec. A codec embeds a specific type of data into a sequence of either - discrete tokens or continuous embedddings, and then decode it back. + discrete tokens or continuous embeddings, and then decode it back. """ @property From 806da4b4cd0c98c26ce7e0923fc9e26bf301a19d Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Wed, 9 Apr 2025 11:20:16 +0200 Subject: [PATCH 16/20] Add ruff cache to gitignore --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 15201ac..de88b1c 100644 --- a/.gitignore +++ b/.gitignore @@ -147,10 +147,13 @@ venv.bak/ /site # mypy -.mypy_cache/ +.mypy_cache .dmypy.json dmypy.json +# Ruff +.ruff_cache + # Pyre type checker .pyre/ From 0b5c6e4b95a4291dc2cbdb5f866942e10321f09f Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Wed, 9 Apr 2025 11:21:54 +0200 Subject: [PATCH 17/20] Move tokenizers to dedicated codecs module --- aion/{modules => codecs}/__init__.py | 0 aion/codecs/modules/__init__.py | 0 aion/{ => codecs}/modules/magvit.py | 0 aion/{ => codecs}/modules/subsampler.py | 0 aion/{ => codecs}/quantizers/__init__.py | 0 aion/{ => codecs}/tokenizers/__init__.py | 0 aion/{ => codecs}/tokenizers/base.py | 2 +- aion/{ => codecs}/tokenizers/image.py | 10 +++++----- aion/{ => codecs}/utils.py | 0 tests/tokenizers/test_image_tokenizer.py | 4 ++-- 10 files changed, 8 insertions(+), 8 deletions(-) rename aion/{modules => codecs}/__init__.py (100%) create mode 100644 aion/codecs/modules/__init__.py rename aion/{ => codecs}/modules/magvit.py (100%) rename aion/{ => codecs}/modules/subsampler.py (100%) rename aion/{ => codecs}/quantizers/__init__.py (100%) rename aion/{ => codecs}/tokenizers/__init__.py (100%) rename aion/{ => codecs}/tokenizers/base.py (97%) rename aion/{ => codecs}/tokenizers/image.py (93%) rename aion/{ => codecs}/utils.py (100%) diff --git a/aion/modules/__init__.py b/aion/codecs/__init__.py similarity index 100% rename from aion/modules/__init__.py rename to aion/codecs/__init__.py diff --git a/aion/codecs/modules/__init__.py b/aion/codecs/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aion/modules/magvit.py b/aion/codecs/modules/magvit.py similarity index 100% rename from aion/modules/magvit.py rename to aion/codecs/modules/magvit.py diff --git a/aion/modules/subsampler.py b/aion/codecs/modules/subsampler.py similarity index 100% rename from aion/modules/subsampler.py rename to aion/codecs/modules/subsampler.py diff --git a/aion/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py similarity index 100% rename from aion/quantizers/__init__.py rename to aion/codecs/quantizers/__init__.py diff --git a/aion/tokenizers/__init__.py b/aion/codecs/tokenizers/__init__.py similarity index 100% rename from aion/tokenizers/__init__.py rename to aion/codecs/tokenizers/__init__.py diff --git a/aion/tokenizers/base.py b/aion/codecs/tokenizers/base.py similarity index 97% rename from aion/tokenizers/base.py rename to aion/codecs/tokenizers/base.py index f870c64..a3194fa 100644 --- a/aion/tokenizers/base.py +++ b/aion/codecs/tokenizers/base.py @@ -3,7 +3,7 @@ import torch from jaxtyping import Float -from aion.quantizers import Quantizer +from aion.codecs.quantizers import Quantizer class Codec(ABC): diff --git a/aion/tokenizers/image.py b/aion/codecs/tokenizers/image.py similarity index 93% rename from aion/tokenizers/image.py rename to aion/codecs/tokenizers/image.py index b9c28b9..abae50b 100644 --- a/aion/tokenizers/image.py +++ b/aion/codecs/tokenizers/image.py @@ -1,11 +1,11 @@ import torch from jaxtyping import Float -from aion.modules.magvit import MagVitAE -from aion.modules.subsampler import SubsampledLinear -from aion.quantizers import Quantizer -from aion.tokenizers.base import QuantizedCodec -from aion.utils import range_compression, reverse_range_compression +from aion.codecs.modules.magvit import MagVitAE +from aion.codecs.modules.subsampler import SubsampledLinear +from aion.codecs.quantizers import Quantizer +from aion.codecs.tokenizers.base import QuantizedCodec +from aion.codecs.utils import range_compression, reverse_range_compression class AutoencoderImageCodec(QuantizedCodec): diff --git a/aion/utils.py b/aion/codecs/utils.py similarity index 100% rename from aion/utils.py rename to aion/codecs/utils.py diff --git a/tests/tokenizers/test_image_tokenizer.py b/tests/tokenizers/test_image_tokenizer.py index 41f50a8..8b8d509 100644 --- a/tests/tokenizers/test_image_tokenizer.py +++ b/tests/tokenizers/test_image_tokenizer.py @@ -1,8 +1,8 @@ import pytest import torch -from aion.quantizers import FiniteScaleQuantizer -from aion.tokenizers.image import MagViTAEImageCodec +from aion.codecs.tokenizers.image import MagViTAEImageCodec +from aion.codecs.quantizers import FiniteScaleQuantizer @pytest.mark.parametrize("n_bands", [3, 10]) From 4dd6d98fc3bb3c7804fb7238c898835b555c6460 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Wed, 9 Apr 2025 12:32:20 +0200 Subject: [PATCH 18/20] Reorganize quantizers and add identity quantizer --- aion/codecs/quantizers/__init__.py | 196 ++------------------------ aion/codecs/quantizers/base.py | 40 ++++++ aion/codecs/quantizers/scalar.py | 219 +++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 188 deletions(-) create mode 100644 aion/codecs/quantizers/base.py create mode 100644 aion/codecs/quantizers/scalar.py diff --git a/aion/codecs/quantizers/__init__.py b/aion/codecs/quantizers/__init__.py index 18b6102..40c8137 100644 --- a/aion/codecs/quantizers/__init__.py +++ b/aion/codecs/quantizers/__init__.py @@ -1,188 +1,8 @@ -import math -from abc import ABC, abstractmethod - -import torch -from jaxtyping import Float, Integer - - -class Quantizer(torch.nn.Module, ABC): - """Abstract interface for all quantizer modules.""" - - @abstractmethod - def quantize( - self, x: Float[torch.Tensor, " b c1 *input_shape"] - ) -> Float[torch.Tensor, " b c *code_shape"]: - """Quantize the input tensor.""" - raise NotImplementedError - - @abstractmethod - def reconstruct( - self, z: Float[torch.Tensor, " b c *code_shape"] - ) -> Float[torch.Tensor, " b c *input_shape"]: - """Reconstruct the input tensor from the quantized tensor.""" - raise NotImplementedError - - @abstractmethod - def forward( - self, z_e: Float[torch.Tensor, " b c *input_shape"] - ) -> tuple[ - Float[torch.Tensor, " b c *code_shape"], - Float[torch.Tensor, " b"], - Float[torch.Tensor, " b"], - ]: - """Performs a forward pass through the vector quantizer. - Args: - x: The input tensor to be quantized. - Returns: - z: The quantized tensor. - quantization_error: The error of the quantization. - codebook_usage: The fraction of codes used in the codebook. - """ - raise NotImplementedError - - -class FiniteScaleQuantizer(Quantizer): - def __init__( - self, - levels: list[int], - eps: float = 1e-3, - ): - """Finite scalar quantizer (FSQ) module - https://arxiv.org/pdf/2309.15505.pdf - - Following the implementation from: - https://github.com/duchenzhuang/FSQ-pytorch/blob/main/quantizers/fsq.py - - Args: - levels: list[int] - The number of levels for each dimension. Length of the list should match - the number of embedding dimensions. - eps: float - The epsilon value for the FSQ. - """ - super().__init__() - _levels = torch.tensor(levels, dtype=torch.int32) - self.register_buffer("levels", _levels) - self._embedding_dim = len(levels) - self._basis = torch.cumprod( - torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32 - ) - self.eps = eps - - @property - def codebook_size(self): - return math.prod(self.levels) - - @property - def embedding_dim(self): - return self._embedding_dim - - def _bound( - self, z: Float[torch.Tensor, " b t *c"] - ) -> Float[torch.Tensor, " b t *c"]: - """Bound `z`, an array of shape (..., d).""" - half_l = (self.levels - 1) * (1 + self.eps) / 2 - offset = torch.where(self.levels % 2 == 1, 0.0, 0.5) - shift = torch.atanh(offset / half_l) - return torch.tanh(z + shift) * half_l - offset - - def _quantize( - self, z: Float[torch.Tensor, " b t *c"] - ) -> Float[torch.Tensor, " b t *c"]: - """Quantizes z, returns quantized codes zhat with the same shape as z. - Assumes last dimension of z is the embedding dimension. - """ - - def round_ste(z): - zhat = z.round() - return z + (zhat - z).detach() - - quantized = round_ste(self._bound(z)) - # Renormalize to [-1, 1]. - half_width = self.levels // 2 - return quantized / half_width - - def _scale_and_shift(self, zhat_normalized): - half_width = self.levels // 2 - return (zhat_normalized * half_width) + half_width - - def _scale_and_shift_inverse(self, zhat): - half_width = self.levels // 2 - return (zhat - half_width) / half_width - - def quantize( - self, z: Float[torch.Tensor, " b *c t"] - ) -> Float[torch.Tensor, " b *c t"]: - """ - Quantizes the input tensor. - - Args: - z (Tensor): The input tensor to be quantized. - - Returns: - Tensor: The quantized tensor, same shape as input. - """ - # Move the embedding dimension to the last dimension for easier broadcasting - z = z.moveaxis(1, -1) - zhat = self._quantize(z) - return zhat.moveaxis(-1, 1) - - def encode( - self, z: Float[torch.Tensor, " b *c t"] - ) -> Integer[torch.Tensor, " b *code"]: - """ - Encodes the input tensor `z` using quantization. - - Args: - z (Tensor): The input tensor to be encoded. - - Returns: - Tensor: integer code index. - """ - # Move the embedding dimension to the last dimension for easier broadcasting - z = z.moveaxis(1, -1) - zhat = self._quantize(z) - zhat = self._scale_and_shift(zhat) - return (zhat * self._basis.to(zhat)).sum(axis=-1).to(torch.int32) - - def reconstruct( - self, codes: Integer[torch.Tensor, " b *code"] - ) -> Float[torch.Tensor, "b *c t"]: - """ - Decodes the given codes into the corresponding values. - - Args: - codes (Tensor): The codes to be decoded. - - Returns: - Tensor: The decoded tensor. - """ - indices = codes.unsqueeze(-1) - codes_non_centered = (indices // self._basis.to(indices)) % self.levels - zhat = self._scale_and_shift_inverse(codes_non_centered) - # Move the embedding dimension back to the second dimension - return zhat.moveaxis(-1, 1) - - def forward( - self, z_e: Float[torch.Tensor, " b t *codes"] - ) -> tuple[ - Float[torch.Tensor, " b t *shape"], - Float[torch.Tensor, ""], - Float[torch.Tensor, ""], - ]: - """ - Forward pass of the quantizer module. - - Args: - z_e: The input tensor. - - Returns: - tuple[Tensor, Tensor, Tensor]: A tuple containing: - - decoded (Tensor): The decoded tensor. - - loss (Tensor): In this case, no additional loss is necessary, so always returns 0. - - codebook_usage (Tensor): The ratio of unique codes used in the codebook. - """ - z_q = self.quantize(z_e) - codes = self.encode(z_e) - codebook_usage = len(torch.unique(codes)) / self.codebook_size - return z_q, torch.zeros([]), codebook_usage +from .base import Quantizer +from .scalar import FiniteScaleQuantizer, IdentityQuantizer + +__all__ = [ + "FiniteScaleQuantizer", + "IdentityQuantizer", + "Quantizer", +] diff --git a/aion/codecs/quantizers/base.py b/aion/codecs/quantizers/base.py new file mode 100644 index 0000000..752245f --- /dev/null +++ b/aion/codecs/quantizers/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod + +import torch +from jaxtyping import Float + + +class Quantizer(torch.nn.Module, ABC): + """Abstract interface for all quantizer modules.""" + + @abstractmethod + def quantize( + self, x: Float[torch.Tensor, " b c1 *input_shape"] + ) -> Float[torch.Tensor, " b c *code_shape"]: + """Quantize the input tensor.""" + raise NotImplementedError + + @abstractmethod + def reconstruct( + self, z: Float[torch.Tensor, " b c *code_shape"] + ) -> Float[torch.Tensor, " b c *input_shape"]: + """Reconstruct the input tensor from the quantized tensor.""" + raise NotImplementedError + + @abstractmethod + def forward( + self, z_e: Float[torch.Tensor, " b c *input_shape"] + ) -> tuple[ + Float[torch.Tensor, " b c *code_shape"], + Float[torch.Tensor, " b"], + Float[torch.Tensor, " b"], + ]: + """Performs a forward pass through the vector quantizer. + Args: + x: The input tensor to be quantized. + Returns: + z: The quantized tensor. + quantization_error: The error of the quantization. + codebook_usage: The fraction of codes used in the codebook. + """ + raise NotImplementedError diff --git a/aion/codecs/quantizers/scalar.py b/aion/codecs/quantizers/scalar.py new file mode 100644 index 0000000..3834547 --- /dev/null +++ b/aion/codecs/quantizers/scalar.py @@ -0,0 +1,219 @@ +import math + +import torch +from jaxtyping import Float, Integer + +from aion.codecs.quantizers.base import Quantizer + + +class IdentityQuantizer(Quantizer): + """ + Identity quantizer module. + + The identity quantizer module takes a batch of tensors and returns the same tensor. + + Args: + codebook_size: int + The number of labels to be used as signature for the codebook. + """ + + def __init__(self, codebook_size: int): + super().__init__() + self._codebook_size = codebook_size + + def forward( + self, z_e: Float[torch.Tensor, " b c"] + ) -> tuple[ + Float[torch.Tensor, " b c"], + Float[torch.Tensor, " b"], + Float[torch.Tensor, " b"], + ]: + """Performs a forward pass through the vector quantizer. + Args: + z_e: The input tensor to be quantized. + Returns: + z_q: The quantized tensor. + loss: The embedding loss for the quantization. + codebook_usage: The fraction of codes used in the codebook. + """ + codebook_usage = z_e.unique().numel() / self._codebook_size.item() + return self.quantize(z_e), torch.tensor(0), torch.tensor(codebook_usage) + + def quantize(self, z: Float[torch.Tensor, " b c"]) -> Float[torch.Tensor, " b c"]: + """Quantize the input tensor z, returns corresponding + codebook entry. + """ + return z + + def encode(self, z: Float[torch.Tensor, " b c"]) -> Integer[torch.Tensor, " b c"]: + return self.quantize(z) + + def reconstruct( + self, codes: Float[torch.Tensor, " b c"] + ) -> Float[torch.Tensor, " b c"]: + """Decodes the input code index into corresponding codebook entry of + dimension (embedding_dim). + """ + return codes + + @property + def codebook_size(self) -> int: + """Returns the size of the codebook.""" + return int(self._codebook_size.item()) + + @property + def codebook(self) -> Float[torch.Tensor, " c"]: + """Returns the codebook.""" + return torch.arange(self._codebook_size.item()) + + @property + def embedding_dim(self) -> int: + """Returns the dimension of the codebook entries.""" + return 1 + + +class FiniteScaleQuantizer(Quantizer): + def __init__( + self, + levels: list[int], + eps: float = 1e-3, + ): + """Finite scalar quantizer (FSQ) module + https://arxiv.org/pdf/2309.15505.pdf + + Following the implementation from: + https://github.com/duchenzhuang/FSQ-pytorch/blob/main/quantizers/fsq.py + + Args: + levels: list[int] + The number of levels for each dimension. Length of the list should match + the number of embedding dimensions. + eps: float + The epsilon value for the FSQ. + """ + super().__init__() + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("levels", _levels) + self._embedding_dim = len(levels) + self._basis = torch.cumprod( + torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32 + ) + self.eps = eps + + @property + def codebook_size(self): + return math.prod(self.levels) + + @property + def embedding_dim(self): + return self._embedding_dim + + def _bound( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Bound `z`, an array of shape (..., d).""" + half_l = (self.levels - 1) * (1 + self.eps) / 2 + offset = torch.where(self.levels % 2 == 1, 0.0, 0.5) + shift = torch.atanh(offset / half_l) + return torch.tanh(z + shift) * half_l - offset + + def _quantize( + self, z: Float[torch.Tensor, " b t *c"] + ) -> Float[torch.Tensor, " b t *c"]: + """Quantizes z, returns quantized codes zhat with the same shape as z. + Assumes last dimension of z is the embedding dimension. + """ + + def round_ste(z): + zhat = z.round() + return z + (zhat - z).detach() + + quantized = round_ste(self._bound(z)) + # Renormalize to [-1, 1]. + half_width = self.levels // 2 + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self.levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self.levels // 2 + return (zhat - half_width) / half_width + + def quantize( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Float[torch.Tensor, " b *c t"]: + """ + Quantizes the input tensor. + + Args: + z (Tensor): The input tensor to be quantized. + + Returns: + Tensor: The quantized tensor, same shape as input. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + return zhat.moveaxis(-1, 1) + + def encode( + self, z: Float[torch.Tensor, " b *c t"] + ) -> Integer[torch.Tensor, " b *code"]: + """ + Encodes the input tensor `z` using quantization. + + Args: + z (Tensor): The input tensor to be encoded. + + Returns: + Tensor: integer code index. + """ + # Move the embedding dimension to the last dimension for easier broadcasting + z = z.moveaxis(1, -1) + zhat = self._quantize(z) + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis.to(zhat)).sum(axis=-1).to(torch.int32) + + def reconstruct( + self, codes: Integer[torch.Tensor, " b *code"] + ) -> Float[torch.Tensor, "b *c t"]: + """ + Decodes the given codes into the corresponding values. + + Args: + codes (Tensor): The codes to be decoded. + + Returns: + Tensor: The decoded tensor. + """ + indices = codes.unsqueeze(-1) + codes_non_centered = (indices // self._basis.to(indices)) % self.levels + zhat = self._scale_and_shift_inverse(codes_non_centered) + # Move the embedding dimension back to the second dimension + return zhat.moveaxis(-1, 1) + + def forward( + self, z_e: Float[torch.Tensor, " b t *codes"] + ) -> tuple[ + Float[torch.Tensor, " b t *shape"], + Float[torch.Tensor, ""], + Float[torch.Tensor, ""], + ]: + """ + Forward pass of the quantizer module. + + Args: + z_e: The input tensor. + + Returns: + tuple[Tensor, Tensor, Tensor]: A tuple containing: + - decoded (Tensor): The decoded tensor. + - loss (Tensor): In this case, no additional loss is necessary, so always returns 0. + - codebook_usage (Tensor): The ratio of unique codes used in the codebook. + """ + z_q = self.quantize(z_e) + codes = self.encode(z_e) + codebook_usage = len(torch.unique(codes)) / self.codebook_size + return z_q, torch.zeros([]), codebook_usage From 72d3bf28b10d06d270dbd1db57a1f066ef79d17d Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Wed, 9 Apr 2025 12:32:48 +0200 Subject: [PATCH 19/20] Add identity scalar codec --- aion/codecs/tokenizers/scalar.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 aion/codecs/tokenizers/scalar.py diff --git a/aion/codecs/tokenizers/scalar.py b/aion/codecs/tokenizers/scalar.py new file mode 100644 index 0000000..20467c2 --- /dev/null +++ b/aion/codecs/tokenizers/scalar.py @@ -0,0 +1,26 @@ +from aion.codecs.tokenizers.base import QuantizedCodec +from aion.codecs.quantizers import Quantizer + +from jaxtyping import Float +import torch + + +class ScalarIdentityCodec(QuantizedCodec): + """Codec for scalar quantities. + + A codec that embeds scalar quantities through an identity mapping. + + """ + + def __init__(self, quantizer: Quantizer): + super().__init__(quantizer) + + @property + def modality(self): + return "label" + + def _encode(self, x: Float[torch.Tensor, " b t"]) -> Float[torch.Tensor, " b t"]: + return x + + def _decode(self, z: Float[torch.Tensor, " b c"]) -> Float[torch.Tensor, " b c"]: + return z From d1a5566a632f0d77a0a0e7e23fb8c4636ec30b97 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Wed, 9 Apr 2025 12:33:10 +0200 Subject: [PATCH 20/20] Add test for scalar identity codec --- tests/tokenizers/test_scalar_tokenizer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/tokenizers/test_scalar_tokenizer.py diff --git a/tests/tokenizers/test_scalar_tokenizer.py b/tests/tokenizers/test_scalar_tokenizer.py new file mode 100644 index 0000000..9e0055e --- /dev/null +++ b/tests/tokenizers/test_scalar_tokenizer.py @@ -0,0 +1,16 @@ +import pytest +import torch + +from aion.codecs.quantizers import IdentityQuantizer +from aion.codecs.tokenizers.scalar import ScalarIdentityCodec + + +@pytest.mark.parametrize("codebook_size", [10, 20]) +@pytest.mark.parametrize("embedding_dim", [1, 4]) +def test_scalar_identity_codec(codebook_size, embedding_dim): + quantizer = IdentityQuantizer(codebook_size=codebook_size) + codec = ScalarIdentityCodec(quantizer=quantizer) + x = torch.randint(0, codebook_size, (64, embedding_dim)) + z = codec.encode(x) + assert z.shape == x.shape + assert torch.allclose(z, x)