From 7c316085af2b3bff9a723f12a8ee4c1e01d3f1e1 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 20 Jun 2025 11:41:34 -0700 Subject: [PATCH] [ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator ## Changes * Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output. ## Motivation Required for the group norm operator. ## Future Work Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) will have the same storage type and memory layout. As more operators are being added, there are more exceptions to this rule. The pass may need an update in the near future to make it possible to specify required storage types and memory layouts on a more granular level. Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/) [ghstack-poisoned] --- .../vulkan/_passes/tag_memory_meta_pass.py | 32 ++++++--- backends/vulkan/op_registry.py | 26 ++++++++ backends/vulkan/test/test_vulkan_delegate.py | 66 +++++++++++++++++++ backends/vulkan/utils.py | 16 ++++- 4 files changed, 129 insertions(+), 11 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 836a0c6ef7d..691e62407ca 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -138,15 +138,23 @@ def propose_node_storage( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): storage = utils.get_node_storage_type(arg) + # Some operators which return multiple output tensors may specify a + # different storage type for each output. In this case, the storage type + # for the first output is used. + if isinstance(storage, (list, tuple)): + storage = storage[0] if storage is not None and storage in valid_storage_types: return storage # If no storage type has been resolved yet, assume the optimal storage type of # the first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_storage(user) - if optimal_storage is not None: - return optimal_storage + storage = self.propose_node_storage(user) + # See above + if isinstance(storage, (list, tuple)): + storage = storage[0] + if storage is not None: + return storage if self.default_storage in valid_storage_types: return self.default_storage @@ -179,15 +187,23 @@ def propose_node_layout( for arg in node.args: if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): layout = utils.get_node_memory_layout(arg) + # Some operators which return multiple output tensors may specify a + # different memory layout for each output. In this case, the storage + # type for the first output is used. + if isinstance(layout, (list, tuple)): + layout = layout[0] if layout is not None and layout in valid_layouts: return layout - # If no storage type has been resolved yet, assume the optimal storage type of - # the first opinionated user. This search is recursive. + # If no memory layout has been resolved yet, assume the optimal layout of the + # first opinionated user. This search is recursive. for user in node.users: - optimal_storage = self.propose_node_layout(user, storage) - if optimal_storage is not None: - return optimal_storage + layout = self.propose_node_layout(user, storage) + # See above comment + if isinstance(layout, (list, tuple)): + layout = layout[0] + if layout is not None: + return layout # As a last resort, return the default storage type that should be used. if self.default_layout in valid_layouts: diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9333f34430e..851429c32dc 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.aten.native_group_norm.default, + ] +) +def register_ported_ops_with_prepacking(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims={PackedDim.CHANNELS}, + ) + features.handles_own_prepacking = True + + features.optimal_storage = [ + VkStorageType.TEXTURE_3D, + VkStorageType.BUFFER, + VkStorageType.BUFFER, + ] + + features.optimal_layout = [ + VkMemoryLayout.TENSOR_CHANNELS_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_WIDTH_PACKED, + ] + + return features + + # Ported ops that support their own prepacking. @update_features( [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 0096834f3c6..04adf183e55 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1898,3 +1898,69 @@ def forward(self, x): dynamic_shapes=dynamic_shapes, test_inputs=test_inputs, ) + + def test_vulkan_backend_group_norm(self): + class ConvGroupNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Conv2d: 3 input channels -> 16 output channels + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + bias=True, + ) + # GroupNorm: 4 groups for 16 channels (16 % 4 == 0) + self.group_norm = torch.nn.GroupNorm( + num_groups=4, + num_channels=16, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + x = self.conv(x) + x = self.group_norm(x) + return x + + # Create sample inputs: [batch, channels, height, width] + sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),) + + # Test with static shapes first + self.lower_module_and_test_output( + ConvGroupNormModule(), + sample_inputs, + ) + + def test_vulkan_backend_group_norm_different_groups(self): + class GroupNormModule(torch.nn.Module): + def __init__(self, num_groups, num_channels): + super().__init__() + self.group_norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=num_channels, + eps=1e-5, + affine=True, + ) + + def forward(self, x): + return self.group_norm(x) + + # Test different group configurations + test_configs = [ + (2, 8), # 2 groups, 8 channels + (4, 16), # 4 groups, 16 channels + (8, 32), # 8 groups, 32 channels + ] + + for num_groups, num_channels in test_configs: + with self.subTest(num_groups=num_groups, num_channels=num_channels): + sample_inputs = ( + torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32), + ) + + self.lower_module_and_test_output( + GroupNormModule(num_groups, num_channels), + sample_inputs, + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 642f7c5f495..5d57ce1e7be 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value): if isinstance(spec, TensorSpec): setattr(spec, attr, value) elif isinstance(spec, (list, tuple)): - for s in spec: - assert isinstance(s, TensorSpec) - setattr(s, attr, value) + # Special case if value is a list/tuple of the same length as the + # collection of tensors in the node. In this case, treat the value list + # as a list of values to set indivudually for each tensor in the node + if isinstance(value, (list, tuple)) and len(spec) == len(value): + assert len(spec) == len(value) + for s, v in zip(spec, value): + assert isinstance(s, TensorSpec) + setattr(s, attr, v) + # Otherwise, set the attribute to value for all tensors in the list + else: + for s in spec: + assert isinstance(s, TensorSpec) + setattr(s, attr, value) else: raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")