diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py new file mode 100644 index 0000000..113511a --- /dev/null +++ b/ml_mdm/models/unet_mlx.py @@ -0,0 +1,100 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import einops + +import mlx.core as mx +import mlx.nn as nn + +from ml_mdm.models.unet import ResNetConfig + + +def zero_module_mlx(module): + """ + Zero out the parameters of an MLX module and return it. + """ + # Create a new parameter dictionary with all parameters replaced by zeros + zeroed_params = { + name: mx.zeros(param.shape, dtype=param.dtype) + for name, param in module.parameters().items() + } + # Update the module's parameters with the zeroed parameters + module.update(zeroed_params) + return module + + +class MLP_MLX(nn.Module): # mlx based nn.Module + def __init__(self, channels, multiplier=4): + super().__init__() + ### use mlx layers + self.main = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, multiplier * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(multiplier * channels, channels)), + ) + + def forward(self, x): + return x + self.main(x) + + +class ResNet_MLX(nn.Module): + def __init__(self, time_emb_channels, config: ResNetConfig): + # TODO(ndjaitly): What about scales of weights. + super(ResNet_MLX, self).__init__() + self.config = config + self.num_groups = config.num_groups_norm + self.num_channels = config.num_channels + self.norm1 = nn.GroupNorm( + config.num_groups_norm, config.num_channels, pytorch_compatible=True + ) + self.conv1 = nn.Conv2d( + config.num_channels, + config.output_channels, + kernel_size=3, + padding=1, + bias=True, + ) + self.time_layer = nn.Linear(time_emb_channels, config.output_channels * 2) + self.norm2 = nn.GroupNorm( + config.num_groups_norm, config.output_channels, pytorch_compatible=True + ) + self.dropout = nn.Dropout(config.dropout) + self.conv2 = zero_module_mlx( + nn.Conv2d( + config.output_channels, + config.output_channels, + kernel_size=3, + padding=1, + bias=True, + ) + ) + if self.config.output_channels != self.config.num_channels: + self.conv3 = nn.Conv2d( + config.num_channels, config.output_channels, kernel_size=1, bias=True + ) + + def forward(self, x, temb): + print("Shape before norm:", x.shape) + # Try explicitly permuting/reshaping? + h = self.norm1(x) + print("Shape after norm:", h.shape) + h = nn.silu(h) + + h = self.conv1(h) + ta, tb = ( + self.time_layer(nn.silu(temb)).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1) + ) + if h.size(0) > ta.size(0): # HACK. repeat to match the shape. + N = h.size(0) // ta.size(0) + ta = einops.repeat(ta, "b c h w -> (b n) c h w", n=N) + tb = einops.repeat(tb, "b c h w -> (b n) c h w", n=N) + h = nn.silu(self.norm2(h) * (1 + ta) + tb) + h = self.dropout(h) + h = self.conv2(h) + if self.config.output_channels != self.config.num_channels: + x = self.conv3(x) + return h + x + + def __call__(self, x, temb): + return self.forward(x, temb) diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py new file mode 100644 index 0000000..ca83f30 --- /dev/null +++ b/tests/test_mlx_unet.py @@ -0,0 +1,107 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import numpy as np +import torch + +from ml_mdm.models.unet import MLP, ResNet, ResNetConfig +from ml_mdm.models.unet_mlx import MLP_MLX, ResNet_MLX + + +def test_pytorch_mlp(): + """ + Simple test for our MLP implementations + """ + # Define parameters + channels = 8 # Number of channels + multiplier = 4 # Multiplier for hidden dimensions + + # Create a model instance + pytorch_mlp = MLP(channels=channels, multiplier=multiplier) + mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) + + ## Start by testing pytorch version + + # Set model to evaluation mode + pytorch_mlp.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 8) + input_tensor = torch.randn(2, channels) + + # Pass the input through the model + output = pytorch_mlp(input_tensor) + + # Assertions to validate the output shape and properties + assert output.shape == input_tensor.shape, "Output shape mismatch" + assert torch.allclose( + output, input_tensor, atol=1e-5 + ), "Output should be close to input as the final layer is zero-initialized" + + ## now test mlx version + + # Convert the same input to MLX tensor + mlx_tensor = mx.array(input_tensor.numpy()) + + mlx_mlp.eval() + + mlx_output = mlx_mlp.forward(mlx_tensor) + + assert isinstance(mlx_output, mx.array) + assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" + + # Validate numerical equivalence using numpy + assert np.allclose( + output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch MLP and MLX MLP should match" + + print("Test passed for both PyTorch and MLX MLP!") + + +def test_pytorch_ResNet(): + """ + Simple test for our ResNet implementations + """ + # Define parameters + batch_size = 2 + time_emb_channels = 32 + height = 16 + width = 16 + + # Create config + config = ResNetConfig( + num_channels=64, + output_channels=128, + num_groups_norm=32, + dropout=0.0, # Set to 0 for deterministic comparison + use_attention_ffn=False, + ) + + # Create model instances + pytorch_resnet = ResNet(time_emb_channels=time_emb_channels, config=config) + mlx_resnet = ResNet_MLX(time_emb_channels=time_emb_channels, config=config) + + # Set both models to evaluation mode + pytorch_resnet.eval() + mlx_resnet.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 64, height, width = 16) + x_torch = torch.randn(batch_size, config.num_channels, height, width) + temb_torch = torch.randn(batch_size, time_emb_channels) + + # pass the input thorugh the model + output_torch, activations_torch = pytorch_resnet(x_torch, temb_torch) + + # Convert inputs to MLX tensors + x_mlx = mx.array(x_torch.numpy()) + temb_mlx = mx.array(temb_torch.numpy()) + + # Get MLX output + output_mlx, activations_mlx = mlx_resnet(x_mlx, temb_mlx) + + # Verify outputs match + assert np.allclose( + output_torch.detach().numpy(), np.array(output_mlx), atol=1e-5 + ), "PyTorch and MLX ResNet outputs should match" + + print("Test passed for ResNet implementations!")