From 9f4147858e2cf5b9d12a3036658bca67aa82a1f9 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 21 Aug 2025 13:31:19 +0200 Subject: [PATCH 1/2] asdasd --- README.md | 8 +- .../experimental/consistency_distillation.py | 51 ++++ blaxbird/_src/experimental/edm.py | 66 +--- blaxbird/_src/experimental/nn/unet.py | 289 ++++++++++++++++++ .../_src/experimental/parameterizations.py | 105 +++++++ blaxbird/_src/experimental/rfm.py | 42 +-- blaxbird/_src/experimental/samplers.py | 5 +- pyproject.toml | 5 +- 8 files changed, 456 insertions(+), 115 deletions(-) create mode 100644 blaxbird/_src/experimental/consistency_distillation.py create mode 100644 blaxbird/_src/experimental/nn/unet.py create mode 100644 blaxbird/_src/experimental/parameterizations.py diff --git a/README.md b/README.md index 83badd3..0f19504 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,11 @@ Using `blaxbird` one can - distribute data and model weights over multiple processes or GPUs, - define hooks that are periodically called during training. -In addition, `blaxbird` offers high-quality implementation of common neural network modules and algorithms, such as: +In addition, `blaxbird` offers high-quality implementations of common neural network modules and algorithms, such as: -- MLP, Diffusion Transformer, -- Flow Matching and Denoising Score Matching (EDM schedules) with Euler and Heun samplers, -- Consistency Distillation/Matching. +- MLPs, DiTs, UNets, +- Flow Matching and Denoising Score Matching (EDM schedules) models with Euler and Heun samplers, +- Consistency Distillation/Matching models. ## Example diff --git a/blaxbird/_src/experimental/consistency_distillation.py b/blaxbird/_src/experimental/consistency_distillation.py new file mode 100644 index 0000000..df87fc5 --- /dev/null +++ b/blaxbird/_src/experimental/consistency_distillation.py @@ -0,0 +1,51 @@ +import numpy as np +from flax import nnx +from jax import numpy as jnp +from jax import random as jr + +from blaxbird._src.experimental import samplers +from blaxbird._src.experimental.parameterizations import RFMConfig + + +def _forward_process(inputs, times, noise): + new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist()) + times = times.reshape(new_shape) + inputs_t = times * inputs + (1.0 - times) * noise + return inputs_t + + +def rfm(config: RFMConfig = RFMConfig()): + """Construct rectified flow matching functions. + + Args: + config: a FlowMatchingConfig object + + Returns: + returns a tuple consisting of train_step, val_step and sampling functions + """ + parameterization = config.parameterization + + def _loss_fn(model, rng_key, batch): + inputs = batch["inputs"] + time_key, rng_key = jr.split(rng_key) + times = jr.uniform(time_key, shape=(inputs.shape[0],)) + times = ( + times * (parameterization.t_max - parameterization.t_eps) + + parameterization.t_eps + ) + noise_key, rng_key = jr.split(rng_key) + noise = jr.normal(noise_key, inputs.shape) + inputs_t = _forward_process(inputs, times, noise) + vt = model(inputs=inputs_t, times=times, context=batch.get("context")) + ut = inputs - noise + loss = jnp.mean(jnp.square(ut - vt)) + return loss + + def train_step(model, rng_key, batch, **kwargs): + return nnx.value_and_grad(_loss_fn)(model, rng_key, batch) + + def val_step(model, rng_key, batch, **kwargs): + return _loss_fn(model, rng_key, batch) + + sampler = getattr(samplers, config.sampler + "_sample_fn")(config) + return train_step, val_step, sampler diff --git a/blaxbird/_src/experimental/edm.py b/blaxbird/_src/experimental/edm.py index 5068980..cb06227 100644 --- a/blaxbird/_src/experimental/edm.py +++ b/blaxbird/_src/experimental/edm.py @@ -1,74 +1,10 @@ -import dataclasses - import numpy as np from flax import nnx from jax import numpy as jnp from jax import random as jr from blaxbird._src.experimental import samplers - - -@dataclasses.dataclass -class EDMParameterization: - n_sampling_steps: int = 25 - sigma_min: float = 0.002 - sigma_max: float = 80.0 - rho: float = 7.0 - sigma_data: float = 0.5 - P_mean: float = -1.2 - P_std: float = 1.2 - S_churn: float = 40 - S_min: float = 0.05 - S_max: float = 50 - S_noise: float = 1.003 - - def sigma(self, eps): - return jnp.exp(eps * self.P_std + self.P_mean) - - def loss_weight(self, sigma): - return (jnp.square(sigma) + jnp.square(self.sigma_data)) / jnp.square( - sigma * self.sigma_data - ) - - def skip_scaling(self, sigma): - return self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - - def out_scaling(self, sigma): - return sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 - - def in_scaling(self, sigma): - return 1 / (sigma**2 + self.sigma_data**2) ** 0.5 - - def noise_conditioning(self, sigma): - return 0.25 * jnp.log(sigma) - - def sampling_sigmas(self, num_steps): - rho_inv = 1 / self.rho - step_idxs = jnp.arange(num_steps, dtype=jnp.float32) - sigmas = ( - self.sigma_max**rho_inv - + step_idxs - / (num_steps - 1) - * (self.sigma_min**rho_inv - self.sigma_max**rho_inv) - ) ** self.rho - return jnp.concatenate([sigmas, jnp.zeros_like(sigmas[:1])]) - - def sigma_hat(self, sigma, num_steps): - gamma = ( - jnp.minimum(self.S_churn / num_steps, 2**0.5 - 1) - if self.S_min <= sigma <= self.S_max - else 0 - ) - return sigma + gamma * sigma - - -@dataclasses.dataclass -class EDMConfig: - n_sampling_steps: int = 25 - sampler: str = "heun" - parameterization: EDMParameterization = dataclasses.field( - default_factory=EDMParameterization - ) +from blaxbird._src.experimental.parameterizations import EDMConfig def edm(config: EDMConfig): diff --git a/blaxbird/_src/experimental/nn/unet.py b/blaxbird/_src/experimental/nn/unet.py new file mode 100644 index 0000000..3c11dbf --- /dev/null +++ b/blaxbird/_src/experimental/nn/unet.py @@ -0,0 +1,289 @@ +import jax +from einops import rearrange +from flax import nnx +from jax import numpy as jnp + +from blaxbird._src.experimental.nn.embedding import timestep_embedding +from blaxbird._src.experimental.nn.mlp import MLP + + +def _modulate(inputs, shift, scale): # noqa: ANN001, ANN202 + return inputs * (1.0 + scale[:, None]) + shift[:, None] + + +def get_sinusoidal_embedding_1d(length, embedding_dim): # noqa: ANN001, ANN202 + emb = timestep_embedding(length.reshape(-1), embedding_dim) + return emb + + +def sinusoidal_init(shape, dtype): # noqa: ANN001, ANN202 + def get_sinusoidal_embedding_2d(grid, embedding_dim): # noqa: ANN001, ANN202 + emb_h = get_sinusoidal_embedding_1d(grid[0], embedding_dim // 2) + emb_w = get_sinusoidal_embedding_1d(grid[1], embedding_dim // 2) + emb = jnp.concatenate([emb_h, emb_w], axis=1) + return emb + + _, n_h_patches, n_w_patches, embedding_dim = shape + grid_h = jnp.arange(n_h_patches, dtype=jnp.float32) + grid_w = jnp.arange(n_w_patches, dtype=jnp.float32) + grid = jnp.meshgrid(grid_w, grid_h) + + grid = jnp.stack(grid, axis=0) + grid = grid.reshape([2, 1, n_w_patches, n_h_patches]) + pos_embed = get_sinusoidal_embedding_2d(grid, embedding_dim) + + return jnp.expand_dims(pos_embed, 0) # (1, H*W, D) + + +class OutProjection(nnx.Module): + def __init__( # noqa: PLR0913 + self, hidden_size, n_embedding_features, patch_size, out_channels, *, rngs + ): + super().__init__() + self.ada = nnx.Sequential( + nnx.silu, nnx.Linear(n_embedding_features, 2 * hidden_size, rngs=rngs) + ) + self.norm = nnx.LayerNorm(hidden_size, rngs=rngs) + self.out = nnx.Linear( + hidden_size, patch_size * patch_size * out_channels, rngs=rngs + ) + + def __call__(self, inputs, context): + shift, scale = jnp.split(self.ada(context), 2, -1) + outs = self.out(_modulate(self.norm(inputs), shift, scale)) + return outs + + +class DiTBlock(nnx.Module): + def __init__( # noqa: PLR0913 + self, + hidden_size: int, + n_embedding_features: int, + *, + n_heads: int, + dropout_rate: float = 0.1, + rngs: nnx.rnglib.Rngs, + ): + """Diffusion-Transformer block. + + Args: + hidden_size: number of features of the hidden layers + n_embedding_features: number o features of time embedding + n_heads: number of transformer heads + dropout_rate: float + rngs: random keys + """ + super().__init__() + self.ada = nnx.Sequential( + nnx.silu, nnx.Linear(n_embedding_features, hidden_size * 6, rngs=rngs) + ) + + self.layer_norm1 = nnx.LayerNorm( + hidden_size, use_scale=False, use_bias=False, rngs=rngs + ) + self.self_attn = nnx.MultiHeadAttention( + num_heads=n_heads, in_features=hidden_size, rngs=rngs, decode=False + ) + self.layer_norm2 = nnx.LayerNorm( + hidden_size, use_scale=False, use_bias=False, rngs=rngs + ) + self.mlp = MLP( + hidden_size, + (hidden_size * 4, hidden_size), + dropout_rate=dropout_rate, + rngs=rngs, + ) + + def __call__(self, inputs: jax.Array, context: jax.Array) -> jax.Array: + """Transform inputs through the DiT block. + + Args: + inputs: input array + context: values to condition on + + Returns: + returns a jax.Array + """ + hidden = inputs + adaln_norm = self.ada(context) + attn, gate = jnp.split(adaln_norm, 2, axis=-1) + + pre_shift, pre_scale, post_scale = jnp.split(attn, 3, -1) + intermediate = _modulate(self.layer_norm1(hidden), pre_shift, pre_scale) + intermediate = self.self_attn(intermediate) + hidden = hidden + post_scale[:, None] * intermediate + + pre_shift, pre_scale, post_scale = jnp.split(gate, 3, -1) + intermediate = _modulate(self.layer_norm2(hidden), pre_shift, pre_scale) + intermediate = self.mlp(intermediate) + outputs = hidden + post_scale[:, None] * intermediate + + return outputs + + +class DiT(nnx.Module): + def __init__( # noqa: PLR0913 + self, + image_size: tuple[int, int, int], + n_hidden_channels: int, + patch_size: int, + n_layers: int, + n_heads: int, + n_embedding_features=256, + dropout_rate=0.0, + *, + rngs: nnx.rnglib.Rngs, + ): + """Diffusion-Transformer. + + Args: + image_size: size of the image, e.g., (32, 32, 3) + n_hidden_channels: number if hidden channels + patch_size: size of each path + n_layers: integer + n_heads: integer + n_embedding_features: integer + dropout_rate: float + rngs: random keys + """ + self.image_size = image_size + self.n_in_channels = image_size[-1] + self.n_embedding_features = n_embedding_features + self.patch_size = patch_size + self.time_embedding = nnx.Sequential( + nnx.Linear(n_embedding_features, n_embedding_features, rngs=rngs), + nnx.swish, + nnx.Linear(n_embedding_features, n_embedding_features, rngs=rngs), + nnx.swish, + ) + self.patchify = nnx.Conv( + self.n_in_channels, + n_hidden_channels, + (patch_size, patch_size), + (patch_size, patch_size), + padding="VALID", + kernel_init=nnx.initializers.xavier_uniform(), + rngs=rngs, + ) + self.patch_embedding = nnx.Param( + sinusoidal_init( + ( + 1, + image_size[0] // patch_size, + image_size[1] // patch_size, + n_hidden_channels, + ), + None, + ), + ) + self.dit_blocks = tuple( + [ + DiTBlock( + n_hidden_channels, + n_embedding_features, + n_heads=n_heads, + dropout_rate=dropout_rate, + rngs=rngs, + ) + for _ in range(n_layers) + ] + ) + self.out_projection = OutProjection( + n_hidden_channels, + n_embedding_features, + patch_size, + self.n_in_channels, + rngs=rngs, + ) + + def _patchify(self, inputs): + n_h_patches = self.image_size[0] // self.patch_size + n_w_patches = self.image_size[1] // self.patch_size + hidden = self.patchify(inputs) + outputs = rearrange( + hidden, "b h w c -> b (h w) c", h=n_h_patches, w=n_w_patches + ) + return outputs + + def _unpatchify(self, inputs): + H = self.image_size[0] // self.patch_size + W = self.image_size[1] // self.patch_size + P = Q = self.patch_size + hidden = jnp.reshape(inputs, (-1, H, W, P, Q, self.n_in_channels)) + outputs = rearrange( + hidden, "b h w p q c -> b (h p) (w q) c", h=H, w=W, p=P, q=Q + ) + return outputs + + def _embed(self, inputs): + return inputs + jax.lax.stop_gradient(self.patch_embedding.value) + + def __call__( + self, inputs: jax.Array, times: jax.Array, context: jax.Array = None + ): + """Transform inputs through the DiT. + + Args: + inputs: input in image form + times: one-dimensional array + context: conditioning variable in image form + + Returns: + returns a jax + """ + hidden = self._patchify(inputs) + hidden = self._embed(hidden) + times = self.time_embedding( + timestep_embedding(times, self.n_embedding_features) + ) + + for block in self.dit_blocks: + hidden = block(hidden, context=times) + + hidden = self.out_projection(hidden, times) + outputs = self._unpatchify(hidden) + return outputs + + +def SmallDiT(image_size, patch_size=2, **kwargs): + return DiT( + image_size, + n_hidden_channels=384, + patch_size=patch_size, + n_layers=12, + n_heads=6, + **kwargs, + ) + + +def BaseDiT(image_size, patch_size=2, **kwargs): + return DiT( + image_size, + n_hidden_channels=768, + patch_size=patch_size, + n_layers=12, + n_heads=12, + **kwargs, + ) + + +def LargeDiT(image_size, patch_size=2, **kwargs): + return DiT( + image_size, + n_hidden_channels=1024, + patch_size=patch_size, + n_layers=24, + n_heads=16, + **kwargs, + ) + + +def XtraLargeDiT(image_size, patch_size=2, **kwargs): + return DiT( + image_size, + n_hidden_channels=1152, + patch_size=patch_size, + n_layers=28, + n_heads=16, + **kwargs, + ) diff --git a/blaxbird/_src/experimental/parameterizations.py b/blaxbird/_src/experimental/parameterizations.py new file mode 100644 index 0000000..acf575e --- /dev/null +++ b/blaxbird/_src/experimental/parameterizations.py @@ -0,0 +1,105 @@ +import dataclasses + +from jax import numpy as jnp + + +@dataclasses.dataclass +class EDMParameterization: + n_sampling_steps: int = 25 + sigma_min: float = 0.002 + sigma_max: float = 80.0 + rho: float = 7.0 + sigma_data: float = 0.5 + P_mean: float = -1.2 + P_std: float = 1.2 + S_churn: float = 40 + S_min: float = 0.05 + S_max: float = 50 + S_noise: float = 1.003 + + def sigma(self, eps): + return jnp.exp(eps * self.P_std + self.P_mean) + + def loss_weight(self, sigma): + return (jnp.square(sigma) + jnp.square(self.sigma_data)) / jnp.square( + sigma * self.sigma_data + ) + + def skip_scaling(self, sigma): + return self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + + def out_scaling(self, sigma): + return sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + + def in_scaling(self, sigma): + return 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + + def noise_conditioning(self, sigma): + return 0.25 * jnp.log(sigma) + + def sampling_sigmas(self, num_steps): + rho_inv = 1 / self.rho + step_idxs = jnp.arange(num_steps, dtype=jnp.float32) + sigmas = ( + self.sigma_max**rho_inv + + step_idxs + / (num_steps - 1) + * (self.sigma_min**rho_inv - self.sigma_max**rho_inv) + ) ** self.rho + return jnp.concatenate([sigmas, jnp.zeros_like(sigmas[:1])]) + + def sigma_hat(self, sigma, num_steps): + gamma = ( + jnp.minimum(self.S_churn / num_steps, 2**0.5 - 1) + if self.S_min <= sigma <= self.S_max + else 0 + ) + return sigma + gamma * sigma + + +@dataclasses.dataclass +class EDMConfig: + n_sampling_steps: int = 25 + sampler: str = "heun" + parameterization: EDMParameterization = dataclasses.field( + default_factory=EDMParameterization + ) + + +@dataclasses.dataclass +class RFMParameterization: + t_eps: float = 1e-5 + t_max: float = 1.0 + + def sigma(self, eps): + return self.t_eps + (self.t_max - self.t_eps) + + def loss_weight(self, t): + return 1.0 + + def skip_scaling(self, t): + return 0.0 + + def out_scaling(self, t): + return 1.0 + + def in_scaling(self, t): + return 1.0 + + def noise_conditioning(self, t): + return t + + def sampling_sigmas(self, num_steps): + return jnp.linspace(self.t_eps, self.t_max, num_steps) + + def sigma_hat(self, t, num_steps): + return t + + +@dataclasses.dataclass +class RFMConfig: + n_sampling_steps: int = 25 + sampler: str = "euler" + parameterization: RFMParameterization = dataclasses.field( + default_factory=RFMParameterization + ) diff --git a/blaxbird/_src/experimental/rfm.py b/blaxbird/_src/experimental/rfm.py index b7c5f18..df87fc5 100644 --- a/blaxbird/_src/experimental/rfm.py +++ b/blaxbird/_src/experimental/rfm.py @@ -1,11 +1,10 @@ -import dataclasses - import numpy as np from flax import nnx from jax import numpy as jnp from jax import random as jr from blaxbird._src.experimental import samplers +from blaxbird._src.experimental.parameterizations import RFMConfig def _forward_process(inputs, times, noise): @@ -15,45 +14,6 @@ def _forward_process(inputs, times, noise): return inputs_t -@dataclasses.dataclass -class RFMParameterization: - t_eps: float = 1e-5 - t_max: float = 1.0 - - def sigma(self, eps): - return self.t_eps + (self.t_max - self.t_eps) - - def loss_weight(self, t): - return 1.0 - - def skip_scaling(self, t): - return 0.0 - - def out_scaling(self, t): - return 1.0 - - def in_scaling(self, t): - return 1.0 - - def noise_conditioning(self, t): - return t - - def sampling_sigmas(self, num_steps): - return jnp.linspace(self.t_eps, self.t_max, num_steps) - - def sigma_hat(self, t, num_steps): - return t - - -@dataclasses.dataclass -class RFMConfig: - n_sampling_steps: int = 25 - sampler: str = "euler" - parameterization: RFMParameterization = dataclasses.field( - default_factory=RFMParameterization - ) - - def rfm(config: RFMConfig = RFMConfig()): """Construct rectified flow matching functions. diff --git a/blaxbird/_src/experimental/samplers.py b/blaxbird/_src/experimental/samplers.py index 1a72535..f220567 100644 --- a/blaxbird/_src/experimental/samplers.py +++ b/blaxbird/_src/experimental/samplers.py @@ -5,10 +5,7 @@ from jax import numpy as jnp from jax import random as jr -from blaxbird._src.experimental.edm import EDMConfig -from blaxbird._src.experimental.rfm import ( - RFMConfig, -) +from blaxbird._src.experimental.parameterizations import EDMConfig, RFMConfig def euler_sample_fn(config: RFMConfig): diff --git a/pyproject.toml b/pyproject.toml index 122aca8..75376e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,11 @@ skips = ["B101", "B310"] show_error_codes = true no_implicit_optional = true -[tool.pytest] +[tool.pytest.ini_options] addopts = "-v --doctest-modules --cov=./blaxbird --cov-report=xml" +testpaths = [ + "blaxbird" +] [tool.ruff] indent-width = 2 From e0975547489d052620ee848969b082eacd6ca2d0 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 21 Aug 2025 13:34:41 +0200 Subject: [PATCH 2/2] ff --- blaxbird/__init__.py | 2 +- blaxbird/_src/experimental/nn/mlp.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/blaxbird/__init__.py b/blaxbird/__init__.py index 9d37827..3dc7f3f 100644 --- a/blaxbird/__init__.py +++ b/blaxbird/__init__.py @@ -1,6 +1,6 @@ """blaxbird: A high-level API for building and training Flax NNX models.""" -__version__ = "0.1.0" +__version__ = "0.1.1" from blaxbird._src.checkpointer import get_default_checkpointer from blaxbird._src.trainer import train_fn diff --git a/blaxbird/_src/experimental/nn/mlp.py b/blaxbird/_src/experimental/nn/mlp.py index f7bec36..0afcec2 100644 --- a/blaxbird/_src/experimental/nn/mlp.py +++ b/blaxbird/_src/experimental/nn/mlp.py @@ -5,7 +5,6 @@ class MLP(nnx.Module): - # ruff: noqa: PLR0913, ANN204, ANN101 def __init__( self, in_features: int,