From c9a83e887512144d5fdcffa4613548483d6a824f Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Thu, 6 Feb 2025 11:42:49 -0500 Subject: [PATCH 1/5] mlp mlx conversion and test --- ml_mdm/models/unet_mlx.py | 34 ++++++++++++++++++ tests/test_files/test_mlx_unet.py | 58 +++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 ml_mdm/models/unet_mlx.py create mode 100644 tests/test_files/test_mlx_unet.py diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py new file mode 100644 index 0000000..0084ccc --- /dev/null +++ b/ml_mdm/models/unet_mlx.py @@ -0,0 +1,34 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import mlx.nn as nn + + +def zero_module_mlx(module): + """ + Zero out the parameters of an MLX module and return it. + """ + # Create a new parameter dictionary with all parameters replaced by zeros + zeroed_params = { + name: mx.zeros(param.shape, dtype=param.dtype) + for name, param in module.parameters().items() + } + # Update the module's parameters with the zeroed parameters + module.update(zeroed_params) + return module + + +class MLP_MLX(nn.Module): # mlx based nn.Module + def __init__(self, channels, multiplier=4): + super().__init__() + ### use mlx layers + self.main = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, multiplier * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(multiplier * channels, channels)), + ) + + def forward(self, x): + return x + self.main(x) diff --git a/tests/test_files/test_mlx_unet.py b/tests/test_files/test_mlx_unet.py new file mode 100644 index 0000000..a58a60e --- /dev/null +++ b/tests/test_files/test_mlx_unet.py @@ -0,0 +1,58 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import numpy as np +import torch + +from ml_mdm.models.unet import MLP +from ml_mdm.models.unet_mlx import MLP_MLX + + +def test_pytorch_mlp(): + """ + Simple test for our MLP implementations + """ + # Define parameters + channels = 8 # Number of channels + multiplier = 4 # Multiplier for hidden dimensions + + # Create a model instance + pytorch_mlp = MLP(channels=channels, multiplier=multiplier) + mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) + + ## Start by testing pytorch version + + # Set model to evaluation mode + pytorch_mlp.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 8) + input_tensor = torch.randn(2, channels) + + # Pass the input through the model + output = pytorch_mlp(input_tensor) + + # Assertions to validate the output shape and properties + assert output.shape == input_tensor.shape, "Output shape mismatch" + assert torch.allclose( + output, input_tensor, atol=1e-5 + ), "Output should be close to input as the final layer is zero-initialized" + + ## now test mlx version + + # Convert the same input to MLX tensor + mlx_tensor = mx.array(input_tensor.numpy()) + + mlx_mlp.eval() + + mlx_output = mlx_mlp.forward(mlx_tensor) + + assert isinstance(mlx_output, mx.array) + assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" + + # Validate numerical equivalence using numpy + assert np.allclose( + output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch MLP and MLX MLP should match" + + print("Test passed for both PyTorch and MLX MLP!") From 9e164a563b67231d1a50acfee59410f4edd399ba Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Thu, 6 Feb 2025 11:48:00 -0500 Subject: [PATCH 2/5] fixed file path for test --- tests/test_mlx_unet.py | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/test_mlx_unet.py diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py new file mode 100644 index 0000000..a58a60e --- /dev/null +++ b/tests/test_mlx_unet.py @@ -0,0 +1,58 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import numpy as np +import torch + +from ml_mdm.models.unet import MLP +from ml_mdm.models.unet_mlx import MLP_MLX + + +def test_pytorch_mlp(): + """ + Simple test for our MLP implementations + """ + # Define parameters + channels = 8 # Number of channels + multiplier = 4 # Multiplier for hidden dimensions + + # Create a model instance + pytorch_mlp = MLP(channels=channels, multiplier=multiplier) + mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) + + ## Start by testing pytorch version + + # Set model to evaluation mode + pytorch_mlp.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 8) + input_tensor = torch.randn(2, channels) + + # Pass the input through the model + output = pytorch_mlp(input_tensor) + + # Assertions to validate the output shape and properties + assert output.shape == input_tensor.shape, "Output shape mismatch" + assert torch.allclose( + output, input_tensor, atol=1e-5 + ), "Output should be close to input as the final layer is zero-initialized" + + ## now test mlx version + + # Convert the same input to MLX tensor + mlx_tensor = mx.array(input_tensor.numpy()) + + mlx_mlp.eval() + + mlx_output = mlx_mlp.forward(mlx_tensor) + + assert isinstance(mlx_output, mx.array) + assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" + + # Validate numerical equivalence using numpy + assert np.allclose( + output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch MLP and MLX MLP should match" + + print("Test passed for both PyTorch and MLX MLP!") From ff39b7a72b3bbe02a9f1940bce490f689e1c807b Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Mon, 10 Feb 2025 09:44:24 -0800 Subject: [PATCH 3/5] correcting file path for tests_mlx_unet.py --- tests/test_files/test_mlx_unet.py | 58 ------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 tests/test_files/test_mlx_unet.py diff --git a/tests/test_files/test_mlx_unet.py b/tests/test_files/test_mlx_unet.py deleted file mode 100644 index a58a60e..0000000 --- a/tests/test_files/test_mlx_unet.py +++ /dev/null @@ -1,58 +0,0 @@ -# For licensing see accompanying LICENSE file. -# Copyright (C) 2024 Apple Inc. All rights reserved. - -import mlx.core as mx -import numpy as np -import torch - -from ml_mdm.models.unet import MLP -from ml_mdm.models.unet_mlx import MLP_MLX - - -def test_pytorch_mlp(): - """ - Simple test for our MLP implementations - """ - # Define parameters - channels = 8 # Number of channels - multiplier = 4 # Multiplier for hidden dimensions - - # Create a model instance - pytorch_mlp = MLP(channels=channels, multiplier=multiplier) - mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) - - ## Start by testing pytorch version - - # Set model to evaluation mode - pytorch_mlp.eval() - - # Create a dummy pytorch input tensor (batch size = 2, channels = 8) - input_tensor = torch.randn(2, channels) - - # Pass the input through the model - output = pytorch_mlp(input_tensor) - - # Assertions to validate the output shape and properties - assert output.shape == input_tensor.shape, "Output shape mismatch" - assert torch.allclose( - output, input_tensor, atol=1e-5 - ), "Output should be close to input as the final layer is zero-initialized" - - ## now test mlx version - - # Convert the same input to MLX tensor - mlx_tensor = mx.array(input_tensor.numpy()) - - mlx_mlp.eval() - - mlx_output = mlx_mlp.forward(mlx_tensor) - - assert isinstance(mlx_output, mx.array) - assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" - - # Validate numerical equivalence using numpy - assert np.allclose( - output.detach().numpy(), np.array(mlx_output), atol=1e-5 - ), "Outputs of PyTorch MLP and MLX MLP should match" - - print("Test passed for both PyTorch and MLX MLP!") From 6e5ed07b5e5ae3e90d106408d8ac0c34a271a304 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Mon, 10 Feb 2025 10:29:23 -0800 Subject: [PATCH 4/5] ResNet bug --- ml_mdm/models/unet_mlx.py | 55 +++++++++++++++++++++++++++++++++++++++ tests/test_mlx_unet.py | 53 +++++++++++++++++++++++++++++++++++-- 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py index 0084ccc..064a88a 100644 --- a/ml_mdm/models/unet_mlx.py +++ b/ml_mdm/models/unet_mlx.py @@ -1,9 +1,13 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. +import einops + import mlx.core as mx import mlx.nn as nn +from ml_mdm.models.unet import ResNetConfig + def zero_module_mlx(module): """ @@ -32,3 +36,54 @@ def __init__(self, channels, multiplier=4): def forward(self, x): return x + self.main(x) + + +class ResNet_MLX(nn.Module): + def __init__(self, time_emb_channels, config: ResNetConfig): + # TODO(ndjaitly): What about scales of weights. + super(ResNet_MLX, self).__init__() + self.config = config + self.norm1 = nn.GroupNorm(config.num_groups_norm, config.num_channels) + self.conv1 = nn.Conv2d( + config.num_channels, + config.output_channels, + kernel_size=3, + padding=1, + bias=True, + ) + self.time_layer = nn.Linear(time_emb_channels, config.output_channels * 2) + self.norm2 = nn.GroupNorm(config.num_groups_norm, config.output_channels) + self.dropout = nn.Dropout(config.dropout) + self.conv2 = zero_module_mlx( + nn.Conv2d( + config.output_channels, + config.output_channels, + kernel_size=3, + padding=1, + bias=True, + ) + ) + if self.config.output_channels != self.config.num_channels: + self.conv3 = nn.Conv2d( + config.num_channels, config.output_channels, kernel_size=1, bias=True + ) + + def forward(self, x, temb): + h = nn.silu(self.norm1(x)) + h = self.conv1(h) + ta, tb = ( + self.time_layer(nn.silu(temb)).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1) + ) + if h.size(0) > ta.size(0): # HACK. repeat to match the shape. + N = h.size(0) // ta.size(0) + ta = einops.repeat(ta, "b c h w -> (b n) c h w", n=N) + tb = einops.repeat(tb, "b c h w -> (b n) c h w", n=N) + h = nn.silu(self.norm2(h) * (1 + ta) + tb) + h = self.dropout(h) + h = self.conv2(h) + if self.config.output_channels != self.config.num_channels: + x = self.conv3(x) + return h + x + + def __call__(self, x, temb): + return self.forward(x, temb) diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py index a58a60e..ca83f30 100644 --- a/tests/test_mlx_unet.py +++ b/tests/test_mlx_unet.py @@ -5,8 +5,8 @@ import numpy as np import torch -from ml_mdm.models.unet import MLP -from ml_mdm.models.unet_mlx import MLP_MLX +from ml_mdm.models.unet import MLP, ResNet, ResNetConfig +from ml_mdm.models.unet_mlx import MLP_MLX, ResNet_MLX def test_pytorch_mlp(): @@ -56,3 +56,52 @@ def test_pytorch_mlp(): ), "Outputs of PyTorch MLP and MLX MLP should match" print("Test passed for both PyTorch and MLX MLP!") + + +def test_pytorch_ResNet(): + """ + Simple test for our ResNet implementations + """ + # Define parameters + batch_size = 2 + time_emb_channels = 32 + height = 16 + width = 16 + + # Create config + config = ResNetConfig( + num_channels=64, + output_channels=128, + num_groups_norm=32, + dropout=0.0, # Set to 0 for deterministic comparison + use_attention_ffn=False, + ) + + # Create model instances + pytorch_resnet = ResNet(time_emb_channels=time_emb_channels, config=config) + mlx_resnet = ResNet_MLX(time_emb_channels=time_emb_channels, config=config) + + # Set both models to evaluation mode + pytorch_resnet.eval() + mlx_resnet.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 64, height, width = 16) + x_torch = torch.randn(batch_size, config.num_channels, height, width) + temb_torch = torch.randn(batch_size, time_emb_channels) + + # pass the input thorugh the model + output_torch, activations_torch = pytorch_resnet(x_torch, temb_torch) + + # Convert inputs to MLX tensors + x_mlx = mx.array(x_torch.numpy()) + temb_mlx = mx.array(temb_torch.numpy()) + + # Get MLX output + output_mlx, activations_mlx = mlx_resnet(x_mlx, temb_mlx) + + # Verify outputs match + assert np.allclose( + output_torch.detach().numpy(), np.array(output_mlx), atol=1e-5 + ), "PyTorch and MLX ResNet outputs should match" + + print("Test passed for ResNet implementations!") From 8f8c45af622837553b05187cc087bd30e71a3d6e Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Sat, 15 Feb 2025 13:02:05 -0500 Subject: [PATCH 5/5] debug --- ml_mdm/models/unet_mlx.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py index 064a88a..113511a 100644 --- a/ml_mdm/models/unet_mlx.py +++ b/ml_mdm/models/unet_mlx.py @@ -43,7 +43,11 @@ def __init__(self, time_emb_channels, config: ResNetConfig): # TODO(ndjaitly): What about scales of weights. super(ResNet_MLX, self).__init__() self.config = config - self.norm1 = nn.GroupNorm(config.num_groups_norm, config.num_channels) + self.num_groups = config.num_groups_norm + self.num_channels = config.num_channels + self.norm1 = nn.GroupNorm( + config.num_groups_norm, config.num_channels, pytorch_compatible=True + ) self.conv1 = nn.Conv2d( config.num_channels, config.output_channels, @@ -52,7 +56,9 @@ def __init__(self, time_emb_channels, config: ResNetConfig): bias=True, ) self.time_layer = nn.Linear(time_emb_channels, config.output_channels * 2) - self.norm2 = nn.GroupNorm(config.num_groups_norm, config.output_channels) + self.norm2 = nn.GroupNorm( + config.num_groups_norm, config.output_channels, pytorch_compatible=True + ) self.dropout = nn.Dropout(config.dropout) self.conv2 = zero_module_mlx( nn.Conv2d( @@ -69,7 +75,12 @@ def __init__(self, time_emb_channels, config: ResNetConfig): ) def forward(self, x, temb): - h = nn.silu(self.norm1(x)) + print("Shape before norm:", x.shape) + # Try explicitly permuting/reshaping? + h = self.norm1(x) + print("Shape after norm:", h.shape) + h = nn.silu(h) + h = self.conv1(h) ta, tb = ( self.time_layer(nn.silu(temb)).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)