4 changes: 4 additions & 0 deletions .github/workflows/build.yaml
@@ -25,3 +25,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
- name: Run tests
run: |
pip install ".[dev]"
pytest tests
5 changes: 4 additions & 1 deletion .gitignore
@@ -147,10 +147,13 @@ venv.bak/
/site

# mypy
.mypy_cache/
.mypy_cache
.dmypy.json
dmypy.json

# Ruff
.ruff_cache

# Pyre type checker
.pyre/

Empty file added aion/codecs/__init__.py
Empty file.
Empty file added aion/codecs/modules/__init__.py
Empty file.
214 changes: 214 additions & 0 deletions aion/codecs/modules/magvit.py
@@ -0,0 +1,214 @@
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)


class SameConv2d(torch.nn.Module):
def __init__(self, dim_in, dim_out, kernel_size):
super().__init__()
kernel_size = cast_tuple(kernel_size, 2)
padding = [k // 2 for k in kernel_size]
self.conv = torch.nn.Conv2d(
dim_in, dim_out, kernel_size=kernel_size, padding=padding
)

def forward(self, x: torch.Tensor):
return self.conv(x)


class SqueezeExcite(torch.nn.Module):
# global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)

def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
super().__init__()
dim_out = dim_out if dim_out is not None else dim

self.to_k = torch.nn.Conv2d(dim, 1, 1)
dim_hidden = max(dim_hidden_min, dim_out // 2)

self.net = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_hidden, 1),
torch.nn.LeakyReLU(0.1),
torch.nn.Conv2d(dim_hidden, dim_out, 1),
torch.nn.Sigmoid(),
)

torch.nn.init.zeros_(self.net[-2].weight)
torch.nn.init.constant_(self.net[-2].bias, init_bias)

    def forward(self, x):
        context = self.to_k(x)  # (b, 1, h, w): one attention logit per spatial position

        context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1)
        spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)")

        # Attention-weighted global pooling over spatial positions -> (b, c, 1)
        out = torch.einsum("b i n, b c n -> b c i", context, spatial_flattened_input)
        out = rearrange(out, "... -> ... 1")  # (b, c, 1, 1)
        gates = self.net(out)

        return gates * x


class ResidualUnit(torch.nn.Module):
    def __init__(self, dim: int, kernel_size: int | tuple[int, int]):
super().__init__()
self.net = torch.nn.Sequential(
SameConv2d(dim, dim, kernel_size),
torch.nn.ELU(),
torch.nn.Conv2d(dim, dim, 1),
torch.nn.ELU(),
SqueezeExcite(dim),
)

def forward(self, x: torch.Tensor):
return self.net(x) + x


class SpatialDownsample2x(torch.nn.Module):
def __init__(
self,
dim: int,
        dim_out: int | None = None,
kernel_size: int = 3,
):
super().__init__()
dim_out = dim_out if dim_out is not None else dim
self.conv = torch.nn.Conv2d(
dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2
)

def forward(self, x: torch.Tensor):
out = self.conv(x)
return out


class SpatialUpsample2x(torch.nn.Module):
    def __init__(self, dim: int, dim_out: int | None = None):
super().__init__()
dim_out = dim_out if dim_out is not None else dim
conv = torch.nn.Conv2d(dim, dim_out * 4, 1)

self.net = torch.nn.Sequential(
conv,
torch.nn.SiLU(),
Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
)

self.init_conv_(conv)

    def init_conv_(self, conv: torch.nn.Module):
        # Tile one Kaiming-initialized kernel across the four pixel-shuffle groups,
        # so the upsample starts out close to nearest-neighbor interpolation
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        torch.nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        torch.nn.init.zeros_(conv.bias.data)

def forward(self, x: torch.Tensor):
out = self.net(x)
return out


class MagVitAE(torch.nn.Module):
"""MagViTAE implementation from Yu, et al. (2024), adapted for Pytorch.
Code borrowed from https://github.com/lucidrains/magvit2-pytorch, and adapted for images.
"""

def __init__(
self,
n_bands: int = 3,
hidden_dims: int = 512,
residual_conv_kernel_size: int = 3,
n_compressions: int = 2,
num_consecutive: int = 2,
):
super().__init__()

self.encoder_layers = torch.nn.ModuleList([])
self.decoder_layers = torch.nn.ModuleList([])
init_dim = int(hidden_dims / 2**n_compressions)
dim = init_dim

self.conv_in = SameConv2d(n_bands, init_dim, 7)
self.conv_out = SameConv2d(init_dim, n_bands, 3)

# Residual layers
encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)

# Compressions
        for _ in range(n_compressions):
dim_out = dim * 2
encoder_layer = SpatialDownsample2x(dim, dim_out)
decoder_layer = SpatialUpsample2x(dim_out, dim)
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)
dim = dim_out

# Consecutive residual layers
encoder_layer = torch.nn.Sequential(
*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
]
)
decoder_layer = torch.nn.Sequential(
*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
]
)
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)

# Add a final non-compress layer
dim_out = dim
encoder_layer = SameConv2d(dim, dim_out, 7)
decoder_layer = SameConv2d(dim_out, dim, 3)
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)
dim = dim_out

# Consecutive residual layers
encoder_layer = torch.nn.Sequential(
*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
]
)
decoder_layer = torch.nn.Sequential(
*[
ResidualUnit(dim, residual_conv_kernel_size)
for _ in range(num_consecutive)
]
)
self.encoder_layers.append(encoder_layer)
self.decoder_layers.insert(0, decoder_layer)

        # Add a final norm just before the quantization layer
self.encoder_layers.append(
torch.nn.Sequential(
Rearrange("b c ... -> b ... c"),
torch.nn.LayerNorm(dim),
Rearrange("b ... c -> b c ..."),
)
)

def encode(self, x: torch.Tensor):
x = self.conv_in(x)
for layer in self.encoder_layers:
x = layer(x)
return x

def decode(self, x: torch.Tensor):
for layer in self.decoder_layers:
x = layer(x)
x = self.conv_out(x)
return x
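
A quick shape round-trip sanity-checks the autoencoder. The following sketch is illustrative only (not part of this diff) and assumes the default constructor arguments above; with n_compressions=2, a 64×64 input is downsampled 4× spatially while the channel count grows to hidden_dims:

import torch
from aion.codecs.modules.magvit import MagVitAE

model = MagVitAE(n_bands=3, hidden_dims=512, n_compressions=2)
x = torch.randn(2, 3, 64, 64)  # (batch, bands, height, width)

z = model.encode(x)      # (2, 512, 16, 16): spatial / 2**n_compressions, channels = hidden_dims
x_hat = model.decode(z)  # (2, 3, 64, 64): back to the input shape

assert z.shape == (2, 512, 16, 16)
assert x_hat.shape == x.shape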
60 changes: 60 additions & 0 deletions aion/codecs/modules/subsampler.py
@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Bool, Float


class SubsampledLinear(torch.nn.Module):
def __init__(self, dim_in: int, dim_out: int, subsample_in: bool = True):
"""
Subsampled linear layer for the encoder.
It takes in a zero-padded tensor and a mask.
It projects the tensor into some shared projection space.
It can also be used to reverse out of the space with the mask.

Args:
dim_in : Number of total possible bands.
dim_out : Number of embedding dimensions.
subsample_in : Whether to subsample the input. Defaults to True.
"""
super().__init__()
self.subsample_in = subsample_in
self.dim_in = dim_in # Number of total possible bands
self.dim_out = dim_out # Number of embedding dimensions
temp_linear = torch.nn.Linear(dim_in, dim_out)
self.weight = torch.nn.Parameter(temp_linear.weight)
self.bias = torch.nn.Parameter(temp_linear.bias)

    def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]):
        # Zero out inactive bands; the mask broadcasts as (b, 1, 1, c) over (b, h, w, c)
        mask = labels[:, None, None, :].float()
        x = x * mask

        # Rescale so activation magnitude is independent of how many bands are active
        label_sizes = labels.sum(dim=1, keepdim=True)
        scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1)  # (b,); squeeze(-1) keeps the batch dim when b == 1

        # Apply linear layer
        return scales[:, None, None, None] * F.linear(x, self.weight, self.bias)

    def _subsample_out(self, x, labels: Bool[torch.Tensor, " b c"]):
        # Mask broadcasts as (b, 1, 1, c) over the projected output
        mask = labels[:, None, None, :].float()

        # Apply linear layer and mask out inactive output bands
        return F.linear(x, self.weight, self.bias) * mask

def forward(
self, x: Float[torch.Tensor, " b c h w"], labels: Bool[torch.Tensor, " b c"]
) -> Float[torch.Tensor, " b c h w"]:
x = rearrange(x, "b c h w -> b h w c")

if self.subsample_in:
x = self._subsample_in(x, labels)

else:
x = self._subsample_out(x, labels)

x = rearrange(x, "b h w c -> b c h w")

return x
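
A hedged usage sketch for SubsampledLinear (illustrative, not part of this diff): six possible bands projected into a 128-dimensional embedding, with a different subset of bands active per sample; inactive bands are zeroed to honor the zero-padding contract in the docstring:

import torch
from aion.codecs.modules.subsampler import SubsampledLinear

layer = SubsampledLinear(dim_in=6, dim_out=128, subsample_in=True)

labels = torch.tensor(
    [[True, False, True, True, False, False],
     [True, True, False, False, True, False]]
)  # which of the 6 bands are present for each sample
x = torch.randn(2, 6, 32, 32) * labels[:, :, None, None]  # zero-pad inactive bands

out = layer(x, labels)  # (2, 128, 32, 32)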
8 changes: 8 additions & 0 deletions aion/codecs/quantizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .base import Quantizer
from .scalar import FiniteScaleQuantizer, IdentityQuantizer

__all__ = [
"FiniteScaleQuantizer",
"IdentityQuantizer",
"Quantizer",
]
40 changes: 40 additions & 0 deletions aion/codecs/quantizers/base.py
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

import torch
from jaxtyping import Float


class Quantizer(torch.nn.Module, ABC):
"""Abstract interface for all quantizer modules."""

@abstractmethod
def quantize(
        self, x: Float[torch.Tensor, " b c *input_shape"]
) -> Float[torch.Tensor, " b c *code_shape"]:
"""Quantize the input tensor."""
raise NotImplementedError

@abstractmethod
def reconstruct(
self, z: Float[torch.Tensor, " b c *code_shape"]
) -> Float[torch.Tensor, " b c *input_shape"]:
"""Reconstruct the input tensor from the quantized tensor."""
raise NotImplementedError

@abstractmethod
def forward(
self, z_e: Float[torch.Tensor, " b c *input_shape"]
) -> tuple[
Float[torch.Tensor, " b c *code_shape"],
Float[torch.Tensor, " b"],
Float[torch.Tensor, " b"],
]:
"""Performs a forward pass through the vector quantizer.
Args:
x: The input tensor to be quantized.
Returns:
z: The quantized tensor.
quantization_error: The error of the quantization.
codebook_usage: The fraction of codes used in the codebook.
"""
raise NotImplementedError
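
To make the contract concrete, here is a minimal hedged sketch of a subclass (illustrative only; the shipped implementations live in aion/codecs/quantizers/scalar.py): it rounds each latent to the nearest integer and applies a straight-through estimator so gradients reach the encoder:

import torch
from aion.codecs.quantizers import Quantizer

class RoundQuantizer(Quantizer):
    def quantize(self, x):
        return torch.round(x)

    def reconstruct(self, z):
        return z  # integer codes already live in the input space

    def forward(self, z_e):
        z_q = self.quantize(z_e)
        # Per-sample mean squared quantization error, shape (b,)
        quantization_error = (z_q - z_e).pow(2).flatten(1).mean(dim=1)
        # This toy codebook is unbounded, so report full usage
        codebook_usage = torch.ones(z_e.shape[0], device=z_e.device)
        # Straight-through estimator: forward uses z_q, gradient flows to z_e
        z = z_e + (z_q - z_e).detach()
        return z, quantization_error, codebook_usage

q = RoundQuantizer()
z, err, usage = q(torch.randn(2, 4, 8, 8))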