From ab8359d1ad9ae5c03c0ae9589d6abd5c8187728c Mon Sep 17 00:00:00 2001 From: TheoCheng100 Date: Wed, 1 Apr 2026 15:35:47 +0800 Subject: [PATCH] update cutlass4.4.2 --- 3rd/cutlass/include/cute/algorithm/axpby.hpp | 2 +- 3rd/cutlass/include/cute/algorithm/clear.hpp | 2 +- .../cute/algorithm/cooperative_copy.hpp | 2 +- .../cute/algorithm/cooperative_gemm.hpp | 2 +- 3rd/cutlass/include/cute/algorithm/copy.hpp | 15 +- 3rd/cutlass/include/cute/algorithm/fill.hpp | 2 +- .../include/cute/algorithm/functional.hpp | 2 +- 3rd/cutlass/include/cute/algorithm/gemm.hpp | 2 +- 3rd/cutlass/include/cute/algorithm/prefer.hpp | 2 +- .../include/cute/algorithm/prefetch.hpp | 2 +- .../cute/algorithm/tensor_algorithms.hpp | 2 +- .../include/cute/algorithm/tensor_reduce.hpp | 2 +- .../cute/algorithm/tuple_algorithms.hpp | 2 +- .../include/cute/arch/cluster_sm100.hpp | 2 +- .../include/cute/arch/cluster_sm90.hpp | 2 +- 3rd/cutlass/include/cute/arch/config.hpp | 71 +- 3rd/cutlass/include/cute/arch/copy.hpp | 2 +- 3rd/cutlass/include/cute/arch/copy_sm100.hpp | 47 +- .../include/cute/arch/copy_sm100_tma.hpp | 4 +- 3rd/cutlass/include/cute/arch/copy_sm50.hpp | 2 +- 3rd/cutlass/include/cute/arch/copy_sm75.hpp | 35 +- 3rd/cutlass/include/cute/arch/copy_sm80.hpp | 2 +- 3rd/cutlass/include/cute/arch/copy_sm90.hpp | 2 +- .../include/cute/arch/copy_sm90_desc.hpp | 10 +- .../include/cute/arch/copy_sm90_tma.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma_sm100.hpp | 2 +- .../include/cute/arch/mma_sm100_desc.hpp | 8 +- .../include/cute/arch/mma_sm100_umma.hpp | 439 ++++- 3rd/cutlass/include/cute/arch/mma_sm120.hpp | 26 +- .../include/cute/arch/mma_sm120_sparse.hpp | 25 +- 3rd/cutlass/include/cute/arch/mma_sm61.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma_sm70.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma_sm75.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma_sm80.hpp | 2 +- 3rd/cutlass/include/cute/arch/mma_sm89.hpp | 118 +- 
3rd/cutlass/include/cute/arch/mma_sm90.hpp | 2 +- .../include/cute/arch/mma_sm90_desc.hpp | 2 +- .../include/cute/arch/mma_sm90_gmma.hpp | 2 +- .../include/cute/arch/mma_sm90_gmma_ext.hpp | 2 +- .../cute/arch/mma_sm90_gmma_sparse.hpp | 4 +- .../cute/arch/mma_sm90_gmma_sparse_ext.hpp | 2 +- 3rd/cutlass/include/cute/arch/simd_sm100.hpp | 2 +- .../cute/arch/tmem_allocator_sm100.hpp | 2 +- 3rd/cutlass/include/cute/arch/util.hpp | 2 +- 3rd/cutlass/include/cute/atom/copy_atom.hpp | 152 +- 3rd/cutlass/include/cute/atom/copy_traits.hpp | 2 +- .../include/cute/atom/copy_traits_sm100.hpp | 32 +- .../cute/atom/copy_traits_sm100_im2col.hpp | 2 +- .../cute/atom/copy_traits_sm100_tma.hpp | 16 +- .../include/cute/atom/copy_traits_sm50.hpp | 2 +- .../include/cute/atom/copy_traits_sm75.hpp | 18 +- .../include/cute/atom/copy_traits_sm80.hpp | 2 +- .../include/cute/atom/copy_traits_sm90.hpp | 2 +- .../cute/atom/copy_traits_sm90_im2col.hpp | 13 +- .../cute/atom/copy_traits_sm90_tma.hpp | 71 +- .../atom/copy_traits_sm90_tma_swizzle.hpp | 14 +- 3rd/cutlass/include/cute/atom/mma_atom.hpp | 448 +---- 3rd/cutlass/include/cute/atom/mma_traits.hpp | 10 +- .../include/cute/atom/mma_traits_sm100.hpp | 952 +++++++++- .../include/cute/atom/mma_traits_sm120.hpp | 2 +- .../cute/atom/mma_traits_sm120_sparse.hpp | 2 +- .../include/cute/atom/mma_traits_sm61.hpp | 2 +- .../include/cute/atom/mma_traits_sm70.hpp | 2 +- .../include/cute/atom/mma_traits_sm75.hpp | 2 +- .../include/cute/atom/mma_traits_sm80.hpp | 2 +- .../include/cute/atom/mma_traits_sm89.hpp | 50 +- .../include/cute/atom/mma_traits_sm90.hpp | 2 +- .../cute/atom/mma_traits_sm90_gmma.hpp | 2 +- .../cute/atom/mma_traits_sm90_gmma_ext.hpp | 2 +- .../cute/atom/mma_traits_sm90_gmma_sparse.hpp | 2 +- .../atom/mma_traits_sm90_gmma_sparse_ext.hpp | 2 +- 3rd/cutlass/include/cute/atom/partitioner.hpp | 5 +- 3rd/cutlass/include/cute/config.hpp | 2 +- .../include/cute/container/alignment.hpp | 2 +- 3rd/cutlass/include/cute/container/array.hpp | 20 +- 
.../include/cute/container/array_aligned.hpp | 2 +- .../include/cute/container/array_subbyte.hpp | 20 +- .../include/cute/container/bit_field.hpp | 2 +- .../include/cute/container/cuda_types.hpp | 2 +- 3rd/cutlass/include/cute/container/tuple.hpp | 16 +- .../include/cute/container/type_list.hpp | 20 +- 3rd/cutlass/include/cute/int_tuple.hpp | 2 +- 3rd/cutlass/include/cute/layout.hpp | 235 +-- 3rd/cutlass/include/cute/layout_composed.hpp | 2 +- .../include/cute/numeric/arithmetic_tuple.hpp | 284 +-- 3rd/cutlass/include/cute/numeric/complex.hpp | 2 +- 3rd/cutlass/include/cute/numeric/int.hpp | 25 +- .../include/cute/numeric/integer_sequence.hpp | 2 +- .../cute/numeric/integral_constant.hpp | 4 +- .../include/cute/numeric/integral_ratio.hpp | 23 +- 3rd/cutlass/include/cute/numeric/math.hpp | 2 +- .../include/cute/numeric/numeric_types.hpp | 71 +- 3rd/cutlass/include/cute/numeric/real.hpp | 2 +- 3rd/cutlass/include/cute/pointer.hpp | 28 +- 3rd/cutlass/include/cute/pointer_base.hpp | 2 +- 3rd/cutlass/include/cute/pointer_flagged.hpp | 19 +- 3rd/cutlass/include/cute/pointer_sparse.hpp | 2 +- 3rd/cutlass/include/cute/pointer_swizzle.hpp | 2 +- 3rd/cutlass/include/cute/stride.hpp | 2 +- 3rd/cutlass/include/cute/swizzle.hpp | 2 +- 3rd/cutlass/include/cute/swizzle_layout.hpp | 11 +- 3rd/cutlass/include/cute/tensor.hpp | 8 +- 3rd/cutlass/include/cute/tensor_impl.hpp | 122 +- 3rd/cutlass/include/cute/tensor_zip.hpp | 2 +- 3rd/cutlass/include/cute/underscore.hpp | 2 +- 3rd/cutlass/include/cute/util/debug.hpp | 2 +- 3rd/cutlass/include/cute/util/print.hpp | 24 +- 3rd/cutlass/include/cute/util/print_latex.hpp | 438 +++++ 3rd/cutlass/include/cute/util/print_svg.hpp | 257 +++ .../include/cute/util/print_tensor.hpp | 197 ++ 3rd/cutlass/include/cute/util/type_traits.hpp | 37 +- 3rd/cutlass/include/cutlass/aligned_buffer.h | 2 +- 3rd/cutlass/include/cutlass/arch/arch.h | 6 +- 3rd/cutlass/include/cutlass/arch/barrier.h | 213 ++- .../include/cutlass/arch/cache_operation.h | 2 +- 
3rd/cutlass/include/cutlass/arch/config.h | 71 +- .../cutlass/arch/grid_dependency_control.h | 10 +- 3rd/cutlass/include/cutlass/arch/memory.h | 2 +- .../include/cutlass/arch/memory_sm75.h | 2 +- .../include/cutlass/arch/memory_sm80.h | 3 +- 3rd/cutlass/include/cutlass/arch/mma.h | 2 +- 3rd/cutlass/include/cutlass/arch/mma_sm100.h | 120 ++ 3rd/cutlass/include/cutlass/arch/mma_sm50.h | 2 +- 3rd/cutlass/include/cutlass/arch/mma_sm60.h | 2 +- 3rd/cutlass/include/cutlass/arch/mma_sm61.h | 2 +- 3rd/cutlass/include/cutlass/arch/mma_sm70.h | 8 +- 3rd/cutlass/include/cutlass/arch/mma_sm75.h | 8 +- 3rd/cutlass/include/cutlass/arch/mma_sm80.h | 9 +- 3rd/cutlass/include/cutlass/arch/mma_sm89.h | 9 +- 3rd/cutlass/include/cutlass/arch/mma_sm90.h | 10 +- .../include/cutlass/arch/mma_sparse_sm80.h | 8 +- .../include/cutlass/arch/mma_sparse_sm89.h | 8 +- .../include/cutlass/arch/reg_reconfig.h | 11 +- 3rd/cutlass/include/cutlass/arch/simd.h | 2 +- 3rd/cutlass/include/cutlass/arch/simd_sm60.h | 2 +- 3rd/cutlass/include/cutlass/arch/simd_sm61.h | 2 +- 3rd/cutlass/include/cutlass/arch/synclog.hpp | 134 +- 3rd/cutlass/include/cutlass/arch/wmma.h | 4 +- 3rd/cutlass/include/cutlass/arch/wmma_sm70.h | 10 +- 3rd/cutlass/include/cutlass/arch/wmma_sm72.h | 12 +- 3rd/cutlass/include/cutlass/arch/wmma_sm75.h | 12 +- 3rd/cutlass/include/cutlass/array.h | 20 +- .../include/cutlass/array_planar_complex.h | 2 +- 3rd/cutlass/include/cutlass/array_subbyte.h | 4 +- 3rd/cutlass/include/cutlass/barrier.h | 2 +- 3rd/cutlass/include/cutlass/bfloat16.h | 10 +- 3rd/cutlass/include/cutlass/blas3.h | 2 +- 3rd/cutlass/include/cutlass/blas3_types.h | 2 +- 3rd/cutlass/include/cutlass/block_striped.h | 2 +- .../include/cutlass/cluster_launch.hpp | 5 +- 3rd/cutlass/include/cutlass/complex.h | 39 +- 3rd/cutlass/include/cutlass/constants.h | 2 +- .../conv/collective/builders/sm100_common.inl | 2 +- .../builders/sm100_umma_builder.inl | 2 +- .../conv/collective/builders/sm90_common.inl | 2 +- 
.../collective/builders/sm90_gmma_builder.inl | 2 +- .../conv/collective/collective_builder.hpp | 2 +- .../conv/collective/collective_conv.hpp | 2 +- .../cutlass/conv/collective/detail.hpp | 2 +- ...100_implicit_gemm_umma_warpspecialized.hpp | 2 +- ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 2 +- .../cutlass/conv/conv2d_problem_size.h | 2 +- .../cutlass/conv/conv3d_problem_size.h | 2 +- .../cutlass/conv/convnd_problem_shape.hpp | 2 +- .../include/cutlass/conv/convolution.h | 2 +- 3rd/cutlass/include/cutlass/conv/detail.hpp | 2 +- .../conv/device/conv_universal_adapter.hpp | 2 +- .../cutlass/conv/device/direct_convolution.h | 2 +- .../conv/device/implicit_gemm_convolution.h | 2 +- .../device/implicit_gemm_convolution_fusion.h | 2 +- .../include/cutlass/conv/dispatch_policy.hpp | 2 +- .../cutlass/conv/kernel/conv_universal.hpp | 2 +- .../cutlass/conv/kernel/default_conv2d.h | 2 +- .../conv/kernel/default_conv2d_dgrad.h | 2 +- .../conv/kernel/default_conv2d_fprop.h | 6 +- .../conv/kernel/default_conv2d_fprop_fusion.h | 4 +- .../kernel/default_conv2d_fprop_with_absmax.h | 2 +- .../default_conv2d_fprop_with_broadcast.h | 2 +- .../default_conv2d_fprop_with_reduction.h | 2 +- .../conv/kernel/default_conv2d_group_fprop.h | 2 +- .../conv/kernel/default_conv2d_wgrad.h | 2 +- .../conv/kernel/default_conv2d_wgrad_fusion.h | 2 +- .../conv/kernel/default_conv3d_dgrad.h | 2 +- .../conv/kernel/default_conv3d_fprop.h | 2 +- .../conv/kernel/default_conv3d_fprop_fusion.h | 4 +- .../default_conv3d_fprop_with_broadcast.h | 2 +- .../conv/kernel/default_conv3d_wgrad.h | 2 +- .../cutlass/conv/kernel/default_deconv2d.h | 2 +- .../kernel/default_deconv2d_with_broadcast.h | 2 +- .../cutlass/conv/kernel/default_deconv3d.h | 2 +- .../kernel/default_deconv3d_with_broadcast.h | 2 +- .../conv/kernel/default_depthwise_fprop.h | 2 +- .../cutlass/conv/kernel/direct_convolution.h | 2 +- .../conv/kernel/implicit_gemm_convolution.h | 2 +- .../kernel/implicit_gemm_convolution_fusion.h | 2 +- 
.../implicit_gemm_convolution_strided_dgrad.h | 2 +- .../implicit_gemm_convolution_with_absmax.h | 2 +- ...cit_gemm_convolution_with_fused_epilogue.h | 2 +- ...m100_implicit_gemm_tma_warpspecialized.hpp | 23 +- ...sm90_implicit_gemm_tma_warpspecialized.hpp | 2 +- .../cutlass/conv/thread/depthwise_mma.h | 2 +- ...rad_filter_tile_access_iterator_analytic.h | 2 +- ...ad_filter_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...vation_tile_access_iterator_few_channels.h | 2 +- ...tion_tile_access_iterator_fixed_channels.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...rop_filter_tile_access_iterator_analytic.h | 2 +- ...filter_tile_access_iterator_few_channels.h | 2 +- ...lter_tile_access_iterator_fixed_channels.h | 2 +- ...op_filter_tile_access_iterator_optimized.h | 2 +- .../cutlass/conv/threadblock/conv2d_params.h | 2 +- .../conv/threadblock/conv2d_tile_iterator.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- ...rad_filter_tile_access_iterator_analytic.h | 2 +- ...ad_filter_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...rop_filter_tile_access_iterator_analytic.h | 2 +- ...op_filter_tile_access_iterator_optimized.h | 2 +- .../cutlass/conv/threadblock/conv3d_params.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- .../depthwise_direct_conv_params.h | 2 +- 
...erator_direct_conv_fixed_stride_dilation.h | 2 +- ...le_access_iterator_direct_conv_optimized.h | 2 +- .../depthwise_fprop_direct_conv_multistage.h | 2 +- ...le_access_iterator_direct_conv_optimized.h | 2 +- .../threadblock/depthwise_fprop_pipelined.h | 2 +- .../conv/threadblock/depthwise_mma_base.h | 2 +- ...depthwise_mma_core_with_lane_access_size.h | 2 +- .../implicit_gemm_fprop_fusion_multistage.h | 2 +- .../threadblock/implicit_gemm_multistage.h | 2 +- .../threadblock/implicit_gemm_pipelined.h | 2 +- .../implicit_gemm_wgrad_fusion_multistage.h | 2 +- ...icated_scale_bias_vector_access_iterator.h | 2 +- .../predicated_scale_bias_vector_iterator.h | 2 +- .../conv/threadblock/threadblock_swizzle.h | 2 +- .../cutlass/conv/warp/mma_depthwise_simt.h | 2 +- .../warp/mma_depthwise_simt_tile_iterator.h | 2 +- .../conv/warp/scale_bias_relu_transform.h | 2 +- 3rd/cutlass/include/cutlass/coord.h | 11 +- 3rd/cutlass/include/cutlass/core_io.h | 2 +- .../include/cutlass/cuda_host_adapter.hpp | 6 +- 3rd/cutlass/include/cutlass/cutlass.h | 5 +- .../cutlass/detail/blockwise_scale_layout.hpp | 14 +- .../include/cutlass/detail/cluster.hpp | 2 +- .../include/cutlass/detail/collective.hpp | 2 +- .../detail/collective/mixed_input_utils.hpp | 369 +++- .../detail/collective/moe_stride_utils.hpp | 99 + .../detail/collective/sm103_kernel_type.hpp | 45 + .../cutlass/detail/dependent_false.hpp | 2 +- .../include/cutlass/detail/helper_macros.hpp | 2 +- 3rd/cutlass/include/cutlass/detail/layout.hpp | 2 +- .../mainloop_fusion_helper_scale_factor.hpp | 2 +- 3rd/cutlass/include/cutlass/detail/mma.hpp | 2 +- .../detail/sm100_blockscaled_layout.hpp | 24 +- .../sm100_mixed_dtype_blockwise_layout.hpp | 182 ++ .../cutlass/detail/sm100_tmem_helper.hpp | 2 +- .../detail/sm103_blockscaled_layout.hpp | 117 ++ 3rd/cutlass/include/cutlass/device_kernel.h | 3 +- .../collective/builders/sm100_builder.inl | 559 ++++-- .../collective/builders/sm103_builder.inl | 108 ++ 
.../collective/builders/sm120_builder.inl | 6 +- .../collective/builders/sm120_common.inl | 2 +- .../collective/builders/sm90_builder.inl | 6 +- .../collective/builders/sm90_common.inl | 2 +- .../collective/collective_builder.hpp | 3 +- .../collective/collective_epilogue.hpp | 2 +- .../epilogue/collective/default_epilogue.hpp | 2 +- .../collective/default_epilogue_array.hpp | 2 +- .../cutlass/epilogue/collective/detail.hpp | 203 +- .../collective/epilogue_tensor_broadcast.hpp | 2 +- .../sm100_epilogue_array_nosmem.hpp | 590 +++++- ...0_epilogue_array_planar_complex_nosmem.hpp | 345 ++++ ...ray_planar_complex_tma_warpspecialized.hpp | 1161 ++++++++++++ ...100_epilogue_array_tma_warpspecialized.hpp | 189 +- .../collective/sm100_epilogue_nosmem.hpp | 62 +- ...gue_planar_complex_tma_warpspecialized.hpp | 897 +++++++++ .../sm100_epilogue_tma_warpspecialized.hpp | 2 +- .../collective/sm70_epilogue_vectorized.hpp | 2 +- .../sm70_epilogue_vectorized_array.hpp | 2 +- ...m90_epilogue_array_tma_warpspecialized.hpp | 11 +- .../sm90_epilogue_tma_warpspecialized.hpp | 23 +- ...e_tma_warpspecialized_bias_elementwise.hpp | 2 +- .../cutlass/epilogue/dispatch_policy.hpp | 79 +- .../cutlass/epilogue/fusion/callbacks.hpp | 2 +- .../cutlass/epilogue/fusion/operations.hpp | 21 +- .../sm100_callbacks_tma_warpspecialized.hpp | 35 +- ...00_visitor_compute_tma_warpspecialized.hpp | 2 +- ...m100_visitor_store_tma_warpspecialized.hpp | 2 +- .../sm120_callbacks_tma_warpspecialized.hpp | 2 +- ...m120_visitor_store_tma_warpspecialized.hpp | 2 +- .../sm90_callbacks_tma_warpspecialized.hpp | 110 +- ...90_visitor_compute_tma_warpspecialized.hpp | 6 +- .../sm90_visitor_load_tma_warpspecialized.hpp | 32 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 12 +- .../sm90_visitor_tma_warpspecialized.hpp | 2 +- .../fusion/sm90_visitor_topk_softmax.hpp | 29 +- .../cutlass/epilogue/thread/activation.h | 2 +- .../cutlass/epilogue/thread/conversion_op.h | 18 +- .../cutlass/epilogue/thread/detail.hpp | 2 +- 
.../epilogue/thread/linear_combination.h | 2 +- .../linear_combination_bias_elementwise.h | 2 +- .../thread/linear_combination_bias_relu.h | 2 +- .../thread/linear_combination_clamp.h | 2 +- .../thread/linear_combination_dgelu.h | 2 +- .../thread/linear_combination_drelu.h | 2 +- .../epilogue/thread/linear_combination_gelu.h | 2 +- .../thread/linear_combination_generic.h | 2 +- .../linear_combination_generic_with_scaling.h | 2 +- .../thread/linear_combination_hardswish.h | 2 +- .../thread/linear_combination_leaky_relu.h | 2 +- .../thread/linear_combination_params.h | 2 +- .../linear_combination_planar_complex.h | 2 +- .../epilogue/thread/linear_combination_relu.h | 2 +- .../thread/linear_combination_relu0.h | 2 +- .../linear_combination_residual_block.h | 2 +- .../thread/linear_combination_sigmoid.h | 2 +- .../epilogue/thread/linear_combination_silu.h | 2 +- .../linear_combination_tensor_broadcast.hpp | 2 +- .../linear_combination_with_elementwise.h | 2 +- .../cutlass/epilogue/thread/reduction_op.h | 2 +- .../cutlass/epilogue/thread/scale_type.h | 2 +- .../default_epilogue_complex_tensor_op.h | 2 +- ...default_epilogue_complex_tensor_op_blas3.h | 2 +- .../default_epilogue_direct_store.h | 2 +- .../default_epilogue_planar_complex.h | 2 +- .../threadblock/default_epilogue_simt.h | 2 +- .../threadblock/default_epilogue_tensor_op.h | 2 +- .../default_epilogue_tensor_op_blas3.h | 2 +- .../default_epilogue_volta_tensor_op.h | 2 +- .../default_epilogue_with_absmax.h | 2 +- .../default_epilogue_with_broadcast.h | 2 +- .../default_epilogue_with_reduction.h | 2 +- .../default_epilogue_wmma_tensor_op.h | 2 +- .../threadblock/default_thread_map_simt.h | 2 +- .../default_thread_map_tensor_op.h | 2 +- .../default_thread_map_volta_tensor_op.h | 2 +- .../default_thread_map_wmma_tensor_op.h | 2 +- .../direct_store_epilogue_iterator.h | 2 +- .../cutlass/epilogue/threadblock/epilogue.h | 9 +- .../epilogue/threadblock/epilogue_base.h | 10 +- .../threadblock/epilogue_base_streamk.h | 2 
+- .../epilogue/threadblock/epilogue_depthwise.h | 2 +- .../threadblock/epilogue_direct_store.h | 2 +- .../threadblock/epilogue_gemm_k_reduction.h | 6 +- .../threadblock/epilogue_planar_complex.h | 2 +- .../threadblock/epilogue_smem_accumulator.h | 6 +- .../epilogue_streamk_with_broadcast.h | 8 +- .../epilogue_visitor_with_softmax.h | 2 +- .../threadblock/epilogue_with_absmax.h | 9 +- .../threadblock/epilogue_with_broadcast.h | 9 +- .../threadblock/epilogue_with_reduction.h | 6 +- .../epilogue_with_scaling_factor.h | 231 +++ .../threadblock/epilogue_with_visitor.h | 2 +- .../epilogue_with_visitor_callbacks.h | 2 +- .../epilogue/threadblock/epilogue_workspace.h | 2 +- .../threadblock/fusion/visitor_2x.hpp | 2 +- .../threadblock/fusion/visitor_compute.hpp | 2 +- .../threadblock/fusion/visitor_load.hpp | 2 +- .../threadblock/fusion/visitor_store.hpp | 2 +- .../epilogue/threadblock/fusion/visitors.hpp | 2 +- .../threadblock/interleaved_epilogue.h | 2 +- .../threadblock/output_iterator_parameter.h | 2 +- .../threadblock/output_tile_thread_map.h | 2 +- .../threadblock/predicated_tile_iterator.h | 2 +- .../predicated_tile_iterator_affine.h | 2 +- ...cated_tile_iterator_affine_layout_params.h | 2 +- .../predicated_tile_iterator_blas3.h | 2 +- .../predicated_tile_iterator_conv.h | 2 +- .../predicated_tile_iterator_direct_conv.h | 2 +- .../predicated_tile_iterator_params.h | 2 +- .../predicated_tile_iterator_predicates.h | 2 +- .../predicated_tile_iterator_strided_dgrad.h | 2 +- .../threadblock/shared_load_iterator.h | 2 +- .../threadblock/shared_load_iterator_mixed.h | 2 +- .../shared_load_iterator_pitch_linear.h | 2 +- .../fragment_iterator_complex_tensor_op.h | 2 +- ...ment_iterator_gaussian_complex_tensor_op.h | 2 +- .../epilogue/warp/fragment_iterator_simt.h | 2 +- .../warp/fragment_iterator_tensor_op.h | 2 +- .../warp/fragment_iterator_volta_tensor_op.h | 2 +- .../warp/fragment_iterator_wmma_tensor_op.h | 2 +- .../cutlass/epilogue/warp/simt_policy.h | 2 +- 
.../cutlass/epilogue/warp/tensor_op_policy.h | 2 +- .../epilogue/warp/tile_iterator_simt.h | 2 +- .../epilogue/warp/tile_iterator_tensor_op.h | 2 +- .../warp/tile_iterator_tensor_op_mixed.h | 2 +- .../warp/tile_iterator_volta_tensor_op.h | 2 +- .../warp/tile_iterator_wmma_tensor_op.h | 2 +- .../epilogue/warp/volta_tensor_op_policy.h | 2 +- .../epilogue/warp/wmma_tensor_op_policy.h | 2 +- 3rd/cutlass/include/cutlass/exmy_base.h | 6 +- .../distributed/device/detail.hpp | 2 +- .../device/dist_gemm_universal_wrapper.hpp | 51 +- .../distributed/device/full_barrier.hpp | 2 +- .../distributed/kernel/detail.hpp | 2 +- .../kernel/dist_gemm_kernel_wrapper.hpp | 2 +- .../distributed/kernel/full_barrier.hpp | 2 +- .../schedules/dist_gemm_1d_schedules.hpp | 2 +- .../schedules/dist_gemm_base_schedule.hpp | 2 +- 3rd/cutlass/include/cutlass/fast_math.h | 9 +- 3rd/cutlass/include/cutlass/float8.h | 24 +- 3rd/cutlass/include/cutlass/float_subbyte.h | 20 +- .../include/cutlass/floating_point_nvrtc.h | 2 +- 3rd/cutlass/include/cutlass/functional.h | 56 +- ...xBF16_interleaved_complex_umma_builder.inl | 298 +++ .../builders/sm100_9xBF16_umma_builder.inl | 70 +- ...kscaled_mixed_tma_cpasync_umma_builder.inl | 274 +++ .../sm100_blockscaled_sparse_umma_builder.inl | 16 +- .../sm100_blockscaled_umma_builder.inl | 43 +- .../builders/sm100_blockwise_umma_builder.inl | 24 +- .../gemm/collective/builders/sm100_common.inl | 296 ++- .../builders/sm100_cpasync_umma_builder.inl | 179 ++ ...sm100_interleaved_complex_umma_builder.inl | 264 +++ .../sm100_mixed_input_umma_builder.inl | 349 ++++ .../sm100_mixed_tma_cpasync_umma_builder.inl | 171 ++ .../builders/sm100_pipeline_carveout.inl | 19 +- .../sm100_planar_complex_umma_builder.inl | 182 ++ .../builders/sm100_simt_builder.inl | 219 +++ .../builders/sm100_sparse_umma_builder.inl | 16 +- .../builders/sm100_umma_builder.inl | 46 +- .../sm103_blockscaled_umma_builder.inl | 550 ++++++ .../sm120_blockscaled_mma_builder.inl | 39 +- 
.../sm120_blockscaled_sparse_mma_builder.inl | 8 +- .../builders/sm120_blockwise_mma_builder.inl | 2 +- .../gemm/collective/builders/sm120_common.inl | 2 +- .../collective/builders/sm120_mma_builder.inl | 2 +- .../builders/sm120_sparse_mma_builder.inl | 2 +- .../gemm/collective/builders/sm1xx_common.inl | 19 +- .../builders/sm1xx_sparse_config.inl | 2 +- .../gemm/collective/builders/sm90_common.inl | 2 +- .../collective/builders/sm90_gmma_builder.inl | 42 +- .../builders/sm90_sparse_config.inl | 2 +- .../builders/sm90_sparse_gmma_builder.inl | 2 +- .../gemm/collective/collective_builder.hpp | 19 +- .../collective/collective_builder_decl.hpp | 2 +- .../gemm/collective/collective_mma.hpp | 29 +- .../gemm/collective/collective_mma_decl.hpp | 2 +- .../gemm/collective/fp8_accumulation.hpp | 85 +- ..._blockscaled_mma_array_warpspecialized.hpp | 115 +- ...aled_mma_array_warpspecialized_rcggemm.hpp | 1294 +++++++++++++ ..._mma_mixed_tma_cpasync_warpspecialized.hpp | 1032 ++++++++++ .../sm100_blockscaled_mma_warpspecialized.hpp | 11 +- ...blockscaled_sparse_mma_warpspecialized.hpp | 10 +- .../sm100_mma_array_warpspecialized.hpp | 85 +- ...rray_warpspecialized_blockwise_scaling.hpp | 2 +- ...100_mma_array_warpspecialized_emulated.hpp | 7 +- ...ecialized_interleaved_complex_emulated.hpp | 1202 ++++++++++++ ...rpspecialized_interleaved_complex_tf32.hpp | 992 ++++++++++ ...a_array_warpspecialized_planar_complex.hpp | 963 ++++++++++ ...m100_mma_array_warpspecialized_rcggemm.hpp | 900 +++++++++ .../sm100_mma_cpasync_warpspecialized.hpp | 588 ++++++ ..._mma_mixed_tma_cpasync_warpspecialized.hpp | 752 ++++++++ .../collective/sm100_mma_warpspecialized.hpp | 2 +- ..._mma_warpspecialized_blockwise_scaling.hpp | 109 +- .../sm100_mma_warpspecialized_emulated.hpp | 7 +- ...ecialized_interleaved_complex_emulated.hpp | 1077 +++++++++++ ...rpspecialized_interleaved_complex_tf32.hpp | 880 +++++++++ .../sm100_mma_warpspecialized_mixed_input.hpp | 1294 +++++++++++++ 
...100_mma_warpspecialized_planar_complex.hpp | 829 ++++++++ .../sm100_sparse_mma_warpspecialized.hpp | 2 +- ..._blockscaled_mma_array_warpspecialized.hpp | 1685 +++++++++++++++++ .../sm103_blockscaled_mma_warpspecialized.hpp | 1276 +++++++++++++ .../sm120_blockscaled_mma_array_tma.hpp | 6 +- .../collective/sm120_blockscaled_mma_tma.hpp | 6 +- .../sm120_blockscaled_sparse_mma_tma.hpp | 15 +- .../sm120_mma_array_tma_blockwise_scaling.hpp | 2 +- .../cutlass/gemm/collective/sm120_mma_tma.hpp | 2 +- .../sm120_mma_tma_blockwise_scaling.hpp | 2 +- .../gemm/collective/sm120_sparse_mma_tma.hpp | 6 +- .../gemm/collective/sm70_mma_twostage.hpp | 46 +- .../collective/sm80_mma_array_multistage.hpp | 412 ++++ .../gemm/collective/sm80_mma_multistage.hpp | 4 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 9 +- ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 2 +- ..._array_tma_gmma_ss_warpspecialized_fp8.hpp | 12 +- ..._warpspecialized_fp8_blockwise_scaling.hpp | 398 ++-- ...mma_multistage_gmma_rs_warpspecialized.hpp | 2 +- ...mma_multistage_gmma_ss_warpspecialized.hpp | 2 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 2 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 8 +- .../gemm/collective/sm90_mma_tma_gmma_ss.hpp | 2 +- .../sm90_mma_tma_gmma_ss_warpspecialized.hpp | 2 +- ...90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 12 +- ..._warpspecialized_fp8_blockwise_scaling.hpp | 405 ++-- ...sparse_mma_tma_gmma_ss_warpspecialized.hpp | 2 +- ...se_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 12 +- .../cutlass/gemm/device/base_grouped.h | 2 +- .../gemm/device/default_gemm_configuration.h | 2 +- .../include/cutlass/gemm/device/ell_gemm.h | 6 +- .../include/cutlass/gemm/device/gemm.h | 4 +- .../include/cutlass/gemm/device/gemm_array.h | 4 +- .../cutlass/gemm/device/gemm_batched.h | 4 +- .../cutlass/gemm/device/gemm_blockwise.h | 761 ++++++++ .../cutlass/gemm/device/gemm_complex.h | 4 +- .../cutlass/gemm/device/gemm_grouped.h | 2 +- .../device/gemm_layernorm_mainloop_fusion.h | 2 +- 
.../include/cutlass/gemm/device/gemm_sparse.h | 2 +- .../gemm/device/gemm_sparse_universal.h | 2 +- .../gemm_sparse_universal_with_absmax.h | 2 +- .../gemm/device/gemm_sparse_with_absmax.h | 2 +- .../gemm/device/gemm_sparse_with_visitor.h | 2 +- .../gemm/device/gemm_splitk_parallel.h | 4 +- .../cutlass/gemm/device/gemm_universal.h | 4 +- .../gemm/device/gemm_universal_adapter.h | 22 +- .../cutlass/gemm/device/gemm_universal_base.h | 9 +- .../gemm_universal_streamk_with_broadcast.h | 4 +- .../gemm/device/gemm_universal_with_absmax.h | 4 +- .../device/gemm_universal_with_broadcast.h | 4 +- .../gemm/device/gemm_with_k_reduction.h | 4 +- .../include/cutlass/gemm/device/gemv.h | 2 +- .../cutlass/gemm/device/gemv_blockscaled.h | 183 ++ .../include/cutlass/gemm/device/rank_2k.h | 4 +- .../cutlass/gemm/device/rank_2k_grouped.h | 2 +- .../include/cutlass/gemm/device/rank_k.h | 4 +- .../include/cutlass/gemm/device/symm.h | 4 +- .../include/cutlass/gemm/device/trmm.h | 8 +- .../include/cutlass/gemm/dispatch_policy.hpp | 408 +++- 3rd/cutlass/include/cutlass/gemm/gemm.h | 2 +- .../cutlass/gemm/gemm_enumerated_types.h | 2 +- .../gemm/group_array_problem_shape.hpp | 64 +- .../cutlass/gemm/kernel/default_ell_gemm.h | 2 +- .../cutlass/gemm/kernel/default_gemm.h | 2 +- .../gemm/kernel/default_gemm_complex.h | 2 +- .../gemm/kernel/default_gemm_grouped.h | 2 +- .../default_gemm_grouped_per_group_scale.h | 2 +- ...ult_gemm_grouped_softmax_mainloop_fusion.h | 2 +- .../default_gemm_layernorm_mainloop_fusion.h | 2 +- .../default_gemm_planar_complex_universal.h | 2 +- .../cutlass/gemm/kernel/default_gemm_sparse.h | 2 +- .../kernel/default_gemm_sparse_universal.h | 2 +- ...efault_gemm_sparse_universal_with_absmax.h | 2 +- .../kernel/default_gemm_sparse_with_absmax.h | 2 +- .../kernel/default_gemm_sparse_with_visitor.h | 2 +- .../kernel/default_gemm_splitk_parallel.h | 2 +- .../default_gemm_streamk_with_broadcast.h | 2 +- .../gemm/kernel/default_gemm_universal.h | 2 +- 
.../default_gemm_universal_with_visitor.h | 2 +- .../gemm/kernel/default_gemm_with_absmax.h | 2 +- .../gemm/kernel/default_gemm_with_broadcast.h | 2 +- .../kernel/default_gemm_with_k_reduction.h | 2 +- .../gemm/kernel/default_gemm_with_reduction.h | 2 +- .../cutlass/gemm/kernel/default_gemv.h | 2 +- .../cutlass/gemm/kernel/default_rank_2k.h | 2 +- .../gemm/kernel/default_rank_2k_complex.h | 2 +- .../gemm/kernel/default_rank_2k_grouped.h | 2 +- .../gemm/kernel/default_rank_2k_universal.h | 2 +- .../cutlass/gemm/kernel/default_rank_k.h | 2 +- .../gemm/kernel/default_rank_k_complex.h | 2 +- .../gemm/kernel/default_rank_k_universal.h | 2 +- .../cutlass/gemm/kernel/default_symm.h | 2 +- .../gemm/kernel/default_symm_complex.h | 6 +- .../gemm/kernel/default_symm_universal.h | 2 +- .../cutlass/gemm/kernel/default_trmm.h | 2 +- .../gemm/kernel/default_trmm_complex.h | 2 +- .../gemm/kernel/default_trmm_universal.h | 2 +- .../include/cutlass/gemm/kernel/ell_gemm.h | 2 +- .../include/cutlass/gemm/kernel/gemm.h | 2 +- .../include/cutlass/gemm/kernel/gemm_array.h | 2 +- .../cutlass/gemm/kernel/gemm_batched.h | 2 +- .../cutlass/gemm/kernel/gemm_blockwise.h | 223 +++ .../cutlass/gemm/kernel/gemm_grouped.h | 2 +- .../kernel/gemm_grouped_per_group_scale.h | 2 +- .../kernel/gemm_grouped_problem_visitor.h | 2 +- .../gemm_grouped_softmax_mainloop_fusion.h | 2 +- .../kernel/gemm_layernorm_mainloop_fusion.h | 2 +- .../include/cutlass/gemm/kernel/gemm_params.h | 2 +- .../cutlass/gemm/kernel/gemm_pipelined.h | 2 +- .../cutlass/gemm/kernel/gemm_planar_complex.h | 2 +- .../gemm/kernel/gemm_planar_complex_array.h | 2 +- .../gemm/kernel/gemm_sparse_universal.h | 2 +- .../gemm_sparse_universal_with_absmax.h | 2 +- .../gemm/kernel/gemm_splitk_parallel.h | 2 +- .../kernel/gemm_streamk_with_fused_epilogue.h | 2 +- .../gemm/kernel/gemm_transpose_operands.h | 2 +- .../cutlass/gemm/kernel/gemm_universal.h | 2 +- .../cutlass/gemm/kernel/gemm_universal.hpp | 9 +- 
.../gemm/kernel/gemm_universal_blockwise.h | 359 ++++ .../cutlass/gemm/kernel/gemm_universal_decl.h | 2 +- .../gemm/kernel/gemm_universal_streamk.h | 4 +- .../gemm/kernel/gemm_universal_with_visitor.h | 2 +- .../gemm_universal_with_visitor_streamk.h | 2 +- .../cutlass/gemm/kernel/gemm_with_absmax.h | 2 +- .../gemm/kernel/gemm_with_fused_epilogue.h | 2 +- .../gemm/kernel/gemm_with_k_reduction.h | 2 +- .../include/cutlass/gemm/kernel/gemv.h | 6 +- .../gemm/kernel/gemv_batched_strided.h | 4 +- .../cutlass/gemm/kernel/gemv_blockscaled.h | 885 +++++++++ .../gemm/kernel/grouped_problem_visitor.h | 2 +- .../cutlass/gemm/kernel/params_sparse_base.h | 2 +- .../gemm/kernel/params_universal_base.h | 2 +- .../cutlass/gemm/kernel/rank_2k_grouped.h | 2 +- .../kernel/rank_2k_grouped_problem_visitor.h | 2 +- .../gemm/kernel/rank_2k_transpose_operands.h | 2 +- .../cutlass/gemm/kernel/rank_2k_universal.h | 2 +- .../cutlass/gemm/kernel/rank_k_universal.h | 2 +- .../sm100_gemm_array_tma_warpspecialized.hpp | 606 +++++- ...ay_tma_warpspecialized_input_transform.hpp | 57 +- ...rray_tma_warpspecialized_mma_transform.hpp | 18 +- .../sm100_gemm_cpasync_warpspecialized.hpp | 794 ++++++++ ...gemm_mixed_tma_cpasync_warpspecialized.hpp | 1000 ++++++++++ .../kernel/sm100_gemm_tma_warpspecialized.hpp | 7 +- ...mm_tma_warpspecialized_input_transform.hpp | 120 +- ..._warpspecialized_mixed_input_transform.hpp | 1090 +++++++++++ ...gemm_tma_warpspecialized_mma_transform.hpp | 7 +- .../sm100_sparse_gemm_tma_warpspecialized.hpp | 4 +- .../kernel/sm100_static_tile_scheduler.hpp | 2 +- .../gemm/kernel/sm100_tile_scheduler.hpp | 4 +- .../kernel/sm100_tile_scheduler_group.hpp | 58 +- .../kernel/sm100_tile_scheduler_stream_k.hpp | 46 +- ...kscaled_gemm_array_tma_warpspecialized.hpp | 1330 +++++++++++++ ...3_blockscaled_gemm_tma_warpspecialized.hpp | 1121 +++++++++++ ...specialized_cooperative_asymmetric_dma.hpp | 4 +- .../include/cutlass/gemm/kernel/sm70_gemm.hpp | 2 +- 
.../cutlass/gemm/kernel/sm70_gemm_array.hpp | 279 +++ ..._array_tma_warpspecialized_cooperative.hpp | 49 +- ...emm_array_tma_warpspecialized_pingpong.hpp | 49 +- .../cutlass/gemm/kernel/sm90_gemm_tma.hpp | 9 +- .../kernel/sm90_gemm_tma_warpspecialized.hpp | 12 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 43 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 49 +- .../gemm/kernel/sm90_gemm_warpspecialized.hpp | 9 +- .../sm90_gemm_warpspecialized_cooperative.hpp | 8 +- .../sm90_gemm_warpspecialized_pingpong.hpp | 10 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 2 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 95 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 13 +- .../include/cutlass/gemm/kernel/sparse_gemm.h | 2 +- .../gemm/kernel/sparse_gemm_with_absmax.h | 2 +- .../gemm/kernel/sparse_gemm_with_visitor.h | 2 +- .../gemm/kernel/static_tile_scheduler.hpp | 2 +- .../cutlass/gemm/kernel/symm_universal.h | 2 +- .../cutlass/gemm/kernel/tile_scheduler.hpp | 62 +- .../gemm/kernel/tile_scheduler_detail.hpp | 2 +- .../gemm/kernel/tile_scheduler_params.h | 54 +- .../cutlass/gemm/kernel/trmm_universal.h | 2 +- 3rd/cutlass/include/cutlass/gemm/thread/mma.h | 2 +- .../include/cutlass/gemm/thread/mma_sm50.h | 2 +- .../include/cutlass/gemm/thread/mma_sm60.h | 2 +- .../include/cutlass/gemm/thread/mma_sm61.h | 2 +- .../gemm/threadblock/default_ell_mma.h | 8 +- .../gemm/threadblock/default_gemv_core.h | 2 +- .../cutlass/gemm/threadblock/default_mma.h | 8 +- .../gemm/threadblock/default_mma_core.h | 2 +- .../gemm/threadblock/default_mma_core_simt.h | 2 +- .../gemm/threadblock/default_mma_core_sm70.h | 2 +- .../gemm/threadblock/default_mma_core_sm75.h | 2 +- .../gemm/threadblock/default_mma_core_sm80.h | 5 +- .../default_mma_core_sparse_sm80.h | 2 +- .../default_mma_core_with_access_size.h | 2 +- .../default_mma_core_with_reduction.h | 2 +- .../gemm/threadblock/default_mma_core_wmma.h | 2 +- .../default_mma_layernorm_mainloop_fusion.h | 4 +- 
.../default_mma_multistage_blockwise.h | 212 +++ .../default_mma_planar_complex_multistage.h | 2 +- .../default_mma_planar_complex_pipelined.h | 2 +- .../default_mma_softmax_mainloop_fusion.h | 4 +- .../threadblock/default_mma_with_reduction.h | 4 +- .../default_multistage_mma_complex.h | 2 +- .../default_multistage_mma_complex_core.h | 2 +- ...default_multistage_mma_complex_core_sm80.h | 2 +- .../default_multistage_trmm_complex.h | 2 +- .../gemm/threadblock/default_sparse_mma.h | 6 +- .../cutlass/gemm/threadblock/default_trmm.h | 12 +- .../gemm/threadblock/ell_mma_multistage.h | 2 +- .../gemm/threadblock/ell_mma_pipelined.h | 2 +- .../include/cutlass/gemm/threadblock/gemv.h | 4 +- .../cutlass/gemm/threadblock/index_remat.h | 2 +- .../cutlass/gemm/threadblock/mma_base.h | 2 +- .../gemm/threadblock/mma_blas3_multistage.h | 2 +- ...mma_layernorm_mainloop_fusion_multistage.h | 2 +- .../cutlass/gemm/threadblock/mma_multistage.h | 5 +- .../threadblock/mma_multistage_blockwise.h | 449 +++++ .../cutlass/gemm/threadblock/mma_pipelined.h | 2 +- .../threadblock/mma_planar_complex_base.h | 2 +- .../mma_planar_complex_multistage.h | 4 +- .../mma_planar_complex_pipelined.h | 2 +- .../gemm/threadblock/mma_singlestage.h | 4 +- .../mma_softmax_mainloop_fusion_multistage.h | 2 +- .../gemm/threadblock/mma_sparse_base.h | 2 +- .../gemm/threadblock/mma_sparse_multistage.h | 2 +- .../mma_with_reduction_multistage.h | 2 +- .../gemm/threadblock/threadblock_swizzle.h | 2 +- .../threadblock/threadblock_swizzle_streamk.h | 4 +- .../gemm/warp/default_mma_complex_tensor_op.h | 2 +- .../gemm/warp/default_mma_sparse_tensor_op.h | 2 +- .../cutlass/gemm/warp/default_mma_tensor_op.h | 2 +- .../gemm/warp/default_mma_tensor_op_sm80.h | 2 +- .../default_mma_with_reduction_tensor_op.h | 2 +- .../gemm/warp/default_mma_wmma_tensor_op.h | 2 +- .../warp/layernorm_scale_bias_transform.h | 2 +- 3rd/cutlass/include/cutlass/gemm/warp/mma.h | 2 +- .../cutlass/gemm/warp/mma_complex_tensor_op.h | 4 +- 
.../warp/mma_complex_tensor_op_fast_f32.h | 4 +- ...mma_complex_tensor_op_tile_iterator_sm80.h | 2 +- .../warp/mma_gaussian_complex_tensor_op.h | 2 +- ...ian_complex_tensor_op_tile_iterator_sm80.h | 2 +- .../gemm/warp/mma_mixed_input_tensor_op.h | 8 +- .../cutlass/gemm/warp/mma_planar_complex.h | 2 +- .../include/cutlass/gemm/warp/mma_simt.h | 2 +- .../cutlass/gemm/warp/mma_simt_policy.h | 2 +- .../gemm/warp/mma_simt_tile_iterator.h | 2 +- .../cutlass/gemm/warp/mma_sparse_tensor_op.h | 4 +- .../include/cutlass/gemm/warp/mma_tensor_op.h | 2 +- .../gemm/warp/mma_tensor_op_fast_f32.h | 2 +- .../warp/mma_tensor_op_fragment_iterator.h | 4 +- .../cutlass/gemm/warp/mma_tensor_op_policy.h | 2 +- .../cutlass/gemm/warp/mma_tensor_op_sm70.h | 2 +- .../warp/mma_tensor_op_tile_access_iterator.h | 4 +- .../gemm/warp/mma_tensor_op_tile_iterator.h | 2 +- .../warp/mma_tensor_op_tile_iterator_sm70.h | 2 +- .../warp/mma_tensor_op_tile_iterator_sm80.h | 2 +- .../warp/mma_tensor_op_tile_iterator_sparse.h | 2 +- .../warp/mma_tensor_op_tile_iterator_wmma.h | 2 +- .../cutlass/gemm/warp/mma_tensor_op_wmma.h | 2 +- .../gemm/warp/mma_with_reduction_tensor_op.h | 2 +- .../gemm/warp/scale_bias_tile_iterator.h | 2 +- .../gemm/warp/softmax_scale_bias_transform.h | 2 +- .../gemm/warp/tile_iterator_planar_complex.h | 2 +- 3rd/cutlass/include/cutlass/gemm_coord.h | 2 +- 3rd/cutlass/include/cutlass/gemm_coord.hpp | 2 +- 3rd/cutlass/include/cutlass/half.h | 8 +- 3rd/cutlass/include/cutlass/integer_subbyte.h | 29 +- .../include/cutlass/kernel_hardware_info.h | 4 +- .../include/cutlass/kernel_hardware_info.hpp | 2 +- 3rd/cutlass/include/cutlass/kernel_launch.h | 2 +- 3rd/cutlass/include/cutlass/layout/layout.h | 2 +- 3rd/cutlass/include/cutlass/layout/matrix.h | 2 +- 3rd/cutlass/include/cutlass/layout/permute.h | 7 +- .../include/cutlass/layout/pitch_linear.h | 2 +- 3rd/cutlass/include/cutlass/layout/tensor.h | 5 +- .../layout/tensor_op_multiplicand_sm70.h | 2 +- 
.../layout/tensor_op_multiplicand_sm75.h | 2 +- .../layout/tensor_op_multiplicand_sm80.h | 2 +- 3rd/cutlass/include/cutlass/layout/vector.h | 2 +- 3rd/cutlass/include/cutlass/matrix.h | 282 +-- 3rd/cutlass/include/cutlass/matrix_coord.h | 2 +- 3rd/cutlass/include/cutlass/matrix_shape.h | 2 +- .../include/cutlass/numeric_conversion.h | 1295 +++++++++++-- 3rd/cutlass/include/cutlass/numeric_size.h | 10 +- 3rd/cutlass/include/cutlass/numeric_types.h | 3 +- .../include/cutlass/pipeline/pipeline.hpp | 2 +- .../cutlass/pipeline/sm100_pipeline.hpp | 79 +- .../cutlass/pipeline/sm90_pipeline.hpp | 13 +- .../include/cutlass/pitch_linear_coord.h | 2 +- .../include/cutlass/platform/platform.h | 90 +- .../include/cutlass/predicate_vector.h | 11 +- 3rd/cutlass/include/cutlass/quaternion.h | 2 +- 3rd/cutlass/include/cutlass/real.h | 2 +- .../cutlass/reduction/device/reduce_split_k.h | 2 +- .../cutlass/reduction/device/tensor_reduce.h | 2 +- .../device/tensor_reduce_affine_contiguous.h | 2 +- .../device/tensor_reduce_affine_strided.h | 2 +- .../reduction/kernel/reduce_softmax_final.h | 2 +- .../cutlass/reduction/kernel/reduce_split_k.h | 2 +- .../kernel/tensor_reduce_affine_contiguous.h | 2 +- .../kernel/tensor_reduce_affine_strided.h | 2 +- .../include/cutlass/reduction/thread/reduce.h | 2 +- .../reduction/thread/reduction_operators.h | 2 +- .../cutlass/reduction/threadblock_swizzle.h | 2 +- .../include/cutlass/relatively_equal.h | 3 +- 3rd/cutlass/include/cutlass/semaphore.h | 2 +- .../include/cutlass/subbyte_reference.h | 2 +- 3rd/cutlass/include/cutlass/tensor_coord.h | 2 +- 3rd/cutlass/include/cutlass/tensor_ref.h | 2 +- .../cutlass/tensor_ref_planar_complex.h | 2 +- 3rd/cutlass/include/cutlass/tensor_view.h | 2 +- .../cutlass/tensor_view_planar_complex.h | 2 +- 3rd/cutlass/include/cutlass/tfloat32.h | 6 +- 3rd/cutlass/include/cutlass/thread/matrix.h | 2 +- 3rd/cutlass/include/cutlass/trace.h | 2 +- .../collective/sm90_wgmma_transpose.hpp | 6 +- 
.../device/transform_universal_adapter.hpp | 2 +- .../kernel/filter_format_transformer.hpp | 2 +- .../kernel/sm90_sparse_gemm_compressor.hpp | 76 +- .../kernel/sparse_gemm_compressor.hpp | 2 +- .../transform/pitch_linear_thread_map.h | 14 +- .../cutlass/transform/thread/transpose.h | 2 +- .../cutlass/transform/thread/unary_op.h | 2 +- .../transform/threadblock/ell_iterator.h | 2 +- .../ell_predicated_tile_access_iterator.h | 2 +- .../ell_predicated_tile_iterator.h | 4 +- ...icated_scale_bias_vector_access_iterator.h | 2 +- .../predicated_scale_bias_vector_iterator.h | 2 +- .../predicated_tile_access_iterator.h | 4 +- ...icated_tile_access_iterator_2dthreadtile.h | 2 +- .../predicated_tile_access_iterator_params.h | 2 +- ...d_tile_access_iterator_triangular_matrix.h | 2 +- .../threadblock/predicated_tile_iterator.h | 4 +- .../predicated_tile_iterator_2dthreadtile.h | 8 +- ...edicated_tile_iterator_triangular_matrix.h | 6 +- .../predicated_vector_access_iterator.h | 2 +- ...egular_scale_bias_vector_access_iterator.h | 2 +- .../regular_tile_access_iterator.h | 2 +- ...egular_tile_access_iterator_pitch_linear.h | 2 +- ...access_iterator_pitch_linear_direct_conv.h | 2 +- .../regular_tile_access_iterator_tensor_op.h | 2 +- ...ular_tile_access_iterator_tensor_op_sm80.h | 2 +- .../threadblock/regular_tile_iterator.h | 2 +- .../regular_tile_iterator_pitch_linear.h | 2 +- ..._tile_iterator_pitch_linear_2dthreadtile.h | 2 +- .../regular_tile_iterator_tensor_op.h | 2 +- .../regular_tile_iterator_tensor_op_sm70.h | 2 +- .../transform/threadblock/vector_iterator.h | 2 +- .../transform/warp/vector_fragment_iterator.h | 4 +- 3rd/cutlass/include/cutlass/uint128.h | 96 +- 3rd/cutlass/include/cutlass/uint256.h | 93 + 3rd/cutlass/include/cutlass/version.h | 6 +- 3rd/cutlass/include/cutlass/wmma_array.h | 2 +- 3rd/cutlass/include/cutlass/workspace.h | 5 +- 3rd/update-cutlass.sh | 2 +- 823 files changed, 40478 insertions(+), 4319 deletions(-) create mode 100644 
3rd/cutlass/include/cute/util/print_latex.hpp create mode 100644 3rd/cutlass/include/cute/util/print_svg.hpp create mode 100644 3rd/cutlass/include/cute/util/print_tensor.hpp create mode 100644 3rd/cutlass/include/cutlass/arch/mma_sm100.h create mode 100644 3rd/cutlass/include/cutlass/detail/collective/moe_stride_utils.hpp create mode 100644 3rd/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp create mode 100644 3rd/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp create mode 100644 3rd/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp create mode 100644 3rd/cutlass/include/cutlass/epilogue/collective/builders/sm103_builder.inl create mode 100644 3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_nosmem.hpp create mode 100644 3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_planar_complex_tma_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl create mode 100644 
3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/device/gemm_blockwise.h create mode 100644 
3rd/cutlass/include/cutlass/gemm/device/gemv_blockscaled.h create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/gemm_blockwise.h create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_blockwise.h create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/gemv_blockscaled.h create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp create mode 100644 3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_multistage_blockwise.h create mode 100644 3rd/cutlass/include/cutlass/gemm/threadblock/mma_multistage_blockwise.h create mode 100644 3rd/cutlass/include/cutlass/uint256.h diff --git a/3rd/cutlass/include/cute/algorithm/axpby.hpp b/3rd/cutlass/include/cute/algorithm/axpby.hpp index 60d5b46..7bf2dd6 100644 --- a/3rd/cutlass/include/cute/algorithm/axpby.hpp +++ b/3rd/cutlass/include/cute/algorithm/axpby.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/clear.hpp b/3rd/cutlass/include/cute/algorithm/clear.hpp index 225c46e..adffd9d 100644 --- a/3rd/cutlass/include/cute/algorithm/clear.hpp +++ b/3rd/cutlass/include/cute/algorithm/clear.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/cooperative_copy.hpp b/3rd/cutlass/include/cute/algorithm/cooperative_copy.hpp index 2653916..a06c4cf 100644 --- a/3rd/cutlass/include/cute/algorithm/cooperative_copy.hpp +++ b/3rd/cutlass/include/cute/algorithm/cooperative_copy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/cooperative_gemm.hpp b/3rd/cutlass/include/cute/algorithm/cooperative_gemm.hpp index f0a9935..4bfee04 100644 --- a/3rd/cutlass/include/cute/algorithm/cooperative_gemm.hpp +++ b/3rd/cutlass/include/cute/algorithm/cooperative_gemm.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/copy.hpp b/3rd/cutlass/include/cute/algorithm/copy.hpp index d05b170..1859138 100644 --- a/3rd/cutlass/include/cute/algorithm/copy.hpp +++ b/3rd/cutlass/include/cute/algorithm/copy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -254,19 +254,16 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, if constexpr (common_elem > 1) { constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); - constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); - if constexpr ((vec_bits % 8) == 0) + if constexpr ((vec_bits % 8) == 0 && sizeof_bits_v < Int{}) { - // If more than one element vectorizes to 8bits or more, then recast and copy + // If more than one element vectorizes to a multiple of 8bits that is larger than the value_type, then recast and copy using VecType = uint_bit_t; - // Preserve volatility - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; // Recast - Tensor src_v = recast(src); - Tensor dst_v = recast(dst); + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); return copy_if(constant_fn{}, src_v, dst_v); } else { return copy_if(constant_fn{}, src, dst); diff --git a/3rd/cutlass/include/cute/algorithm/fill.hpp b/3rd/cutlass/include/cute/algorithm/fill.hpp index 37b97f1..5b744a0 100644 --- a/3rd/cutlass/include/cute/algorithm/fill.hpp +++ 
b/3rd/cutlass/include/cute/algorithm/fill.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/functional.hpp b/3rd/cutlass/include/cute/algorithm/functional.hpp index 5c56eb5..051c33b 100644 --- a/3rd/cutlass/include/cute/algorithm/functional.hpp +++ b/3rd/cutlass/include/cute/algorithm/functional.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/gemm.hpp b/3rd/cutlass/include/cute/algorithm/gemm.hpp index 97839c0..aabbc76 100644 --- a/3rd/cutlass/include/cute/algorithm/gemm.hpp +++ b/3rd/cutlass/include/cute/algorithm/gemm.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/prefer.hpp b/3rd/cutlass/include/cute/algorithm/prefer.hpp index 0a1c53e..a837fcc 100644 --- a/3rd/cutlass/include/cute/algorithm/prefer.hpp +++ b/3rd/cutlass/include/cute/algorithm/prefer.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/prefetch.hpp b/3rd/cutlass/include/cute/algorithm/prefetch.hpp index 265da12..f2c340f 100644 --- a/3rd/cutlass/include/cute/algorithm/prefetch.hpp +++ b/3rd/cutlass/include/cute/algorithm/prefetch.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/tensor_algorithms.hpp b/3rd/cutlass/include/cute/algorithm/tensor_algorithms.hpp index f47becb..9d8abbd 100644 --- a/3rd/cutlass/include/cute/algorithm/tensor_algorithms.hpp +++ b/3rd/cutlass/include/cute/algorithm/tensor_algorithms.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/tensor_reduce.hpp b/3rd/cutlass/include/cute/algorithm/tensor_reduce.hpp index a6f1373..2ceb4a4 100644 --- a/3rd/cutlass/include/cute/algorithm/tensor_reduce.hpp +++ b/3rd/cutlass/include/cute/algorithm/tensor_reduce.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/algorithm/tuple_algorithms.hpp b/3rd/cutlass/include/cute/algorithm/tuple_algorithms.hpp index 311e105..966ae8f 100644 --- a/3rd/cutlass/include/cute/algorithm/tuple_algorithms.hpp +++ b/3rd/cutlass/include/cute/algorithm/tuple_algorithms.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/cluster_sm100.hpp b/3rd/cutlass/include/cute/arch/cluster_sm100.hpp index 0bcf19b..3b8e4e3 100755 --- a/3rd/cutlass/include/cute/arch/cluster_sm100.hpp +++ b/3rd/cutlass/include/cute/arch/cluster_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/cluster_sm90.hpp b/3rd/cutlass/include/cute/arch/cluster_sm90.hpp index 524a47e..8d05c5c 100644 --- a/3rd/cutlass/include/cute/arch/cluster_sm90.hpp +++ b/3rd/cutlass/include/cute/arch/cluster_sm90.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/config.hpp b/3rd/cutlass/include/cute/arch/config.hpp index 8ec8ffb..d9cecf9 100644 --- a/3rd/cutlass/include/cute/arch/config.hpp +++ b/3rd/cutlass/include/cute/arch/config.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -54,13 +54,15 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED @@ -68,11 +70,12 @@ # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED @@ -83,32 +86,59 @@ # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +#if 
(defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + + +// SM110 specific configs +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TCGEN05_TMEM_ENABLED #endif -#if 
(defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TMA_SM100_ENABLED #endif @@ -120,21 +150,26 @@ # define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_MMA_SM120_ENABLED # define CUTE_ARCH_TMA_SM120_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) # define CUTE_ARCH_F8F6F4_MMA_ENABLED # define CUTE_ARCH_MXF8F6F4_MMA_ENABLED # define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED # define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED # endif +# if (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 1) +# define CUTE_ARCH_MXF4NVF4_4X_UE8M0_MMA_ENABLED +# endif #endif -#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED # define CUTE_ARCH_TCGEN05_TMEM_ENABLED @@ -149,14 +184,16 @@ # define CUTE_ARCH_TMA_SM100_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) 
||\ defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) # define CUTE_ARCH_LOAD256_SM100A_ENABLED # define CUTE_ARCH_STORE256_SM100A_ENABLED @@ -168,3 +205,7 @@ #define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif +#if defined(CUTLASS_ARCH_MMA_SM103_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED +#endif + diff --git a/3rd/cutlass/include/cute/arch/copy.hpp b/3rd/cutlass/include/cute/arch/copy.hpp index 8b62fa9..87e1030 100644 --- a/3rd/cutlass/include/cute/arch/copy.hpp +++ b/3rd/cutlass/include/cute/arch/copy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/copy_sm100.hpp b/3rd/cutlass/include/cute/arch/copy_sm100.hpp index aa969af..c8109b3 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm100.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -41,6 +41,51 @@ namespace cute { //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Global Memory Load and Store PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_LOAD_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint256_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint256_t const& gmem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { + #if defined(CUTE_ARCH_LOAD256_SM100A_ENABLED) + asm volatile("ld.global.L1::no_allocate.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "l"(&gmem_addr) ); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use LOAD.256 without CUTE_ARCH_LOAD256_SM100A_ENABLED."); + #endif + } +}; + +struct SM100_STORE_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint256_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint256_t& gmem_addr) + { + #if defined(CUTE_ARCH_STORE256_SM100A_ENABLED) + asm volatile("st.global.L1::no_allocate.v8.f32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "l"(&gmem_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7)); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use stg.256 without CUTE_ARCH_STORE256_SM100A_ENABLED."); + #endif + } +}; + 
//////////////////////////////////////////////////////////////////////////////////////////////////// // // LDSM PTX definitions diff --git a/3rd/cutlass/include/cute/arch/copy_sm100_tma.hpp b/3rd/cutlass/include/cute/arch/copy_sm100_tma.hpp index f69cbff..178995d 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm100_tma.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm100_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2020 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,6 +37,8 @@ #include #include +#include "cutlass/arch/synclog.hpp" + namespace cute { diff --git a/3rd/cutlass/include/cute/arch/copy_sm50.hpp b/3rd/cutlass/include/cute/arch/copy_sm50.hpp index 12518fc..1ccadec 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm50.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm50.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/copy_sm75.hpp b/3rd/cutlass/include/cute/arch/copy_sm75.hpp index 0e4821b..0d22b40 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm75.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm75.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -41,11 +41,13 @@ // * https://reviews.llvm.org/D121666 // * https://reviews.llvm.org/D126846 #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) + #define CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75 (__clang_major__ >= 15) #endif #if defined(__NVCC__) || defined(__CUDACC_RTC__) // ldmatrix PTX instruction added in CUDA 10.2+ #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) + #define CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) #endif #if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) @@ -60,6 +62,19 @@ #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 #endif +#if ! defined(CUTE_ARCH_MOVM_SM75_SUPPORTED) + #define CUTE_ARCH_MOVM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75) +#endif + +#if ! 
defined(CUTE_ARCH_MOVM_SM75_ENABLED) + #define CUTE_ARCH_MOVM_SM75_ENABLED (CUTE_ARCH_MOVM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_MOVM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 +#endif + + namespace cute { @@ -183,6 +198,24 @@ struct SM75_U16x8_LDSM_T } }; +struct SM75_U32x1_MOVM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t src, + uint32_t &dst) + { +#if CUTE_ARCH_MOVM_SM75_ACTIVATED + asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "=r"(dst) + : "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use movmatrix without CUTE_ARCH_MOVM_SM75_ACTIVATED."); +#endif + } +}; // // Legacy LDSM interfaces that aren't very useful // diff --git a/3rd/cutlass/include/cute/arch/copy_sm80.hpp b/3rd/cutlass/include/cute/arch/copy_sm80.hpp index 71a7b3a..3871085 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm80.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm80.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/copy_sm90.hpp b/3rd/cutlass/include/cute/arch/copy_sm90.hpp index 5c0745d..b7f39cb 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm90.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm90.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/copy_sm90_desc.hpp b/3rd/cutlass/include/cute/arch/copy_sm90_desc.hpp index 095cde5..2a8dc0b 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm90_desc.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm90_desc.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -239,7 +239,7 @@ to_CUtensorMapDataType() { inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { switch (t) { - default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); case SmemSwizzleBits::DISABLE: assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); return CU_TENSOR_MAP_SWIZZLE_NONE; @@ -251,7 +251,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { return CU_TENSOR_MAP_SWIZZLE_64B; case SmemSwizzleBits::B128: switch (b) { - default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) @@ -265,7 +265,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { inline CUtensorMapFloatOOBfill to_CUtensorMapFloatOOBfill(OOBFill const& t) { switch(t) { - 
default: assert(false && "Unknown OOBFill!"); + default: throw std::runtime_error("Unknown OOBFill!"); case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; } @@ -274,7 +274,7 @@ to_CUtensorMapFloatOOBfill(OOBFill const& t) { inline CUtensorMapL2promotion to_CUtensorMapL2promotion(L2Promotion const& t) { switch(t) { - default: assert(false && "Unknown L2Promotion!"); + default: throw std::runtime_error("Unknown L2Promotion!"); case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; diff --git a/3rd/cutlass/include/cute/arch/copy_sm90_tma.hpp b/3rd/cutlass/include/cute/arch/copy_sm90_tma.hpp index ec15644..a0b5faa 100644 --- a/3rd/cutlass/include/cute/arch/copy_sm90_tma.hpp +++ b/3rd/cutlass/include/cute/arch/copy_sm90_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/mma.hpp b/3rd/cutlass/include/cute/arch/mma.hpp index 8b97f50..5a63563 100644 --- a/3rd/cutlass/include/cute/arch/mma.hpp +++ b/3rd/cutlass/include/cute/arch/mma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/mma_sm100.hpp b/3rd/cutlass/include/cute/arch/mma_sm100.hpp index 749da81..0dd3fac 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm100.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/mma_sm100_desc.hpp b/3rd/cutlass/include/cute/arch/mma_sm100_desc.hpp index f15108a..887b589 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm100_desc.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm100_desc.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -449,8 +449,8 @@ union InstrDescriptorBlockScaled : 1, // b_sf_id_ : 2, // bit [ 4, 6) : Matrix B Scale Factor ID : 1, // - a_format_ : 3, // bit [ 7, 9) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean - b_format_ : 3, // bit [10,12) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. 
BMMA: 0 Boolean + a_format_ : 3, // bit [ 7, 10): MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + b_format_ : 3, // bit [10,13) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats @@ -459,7 +459,7 @@ union InstrDescriptorBlockScaled scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0 m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID - : 1; // + k_size_ : 1; // bit [31,32) : MMA-K Dim. MXF8F6F4Format: 0=[dense: K32, sparse: K64]. S8Format: 0=[dense: K32, sparse: invalid]. MXF4Format: 0=[dense: K64, sparse: K128], 1=[dense: K96, sparse: invalid]. }; // Decay to a uint32_t diff --git a/3rd/cutlass/include/cute/arch/mma_sm100_umma.hpp b/3rd/cutlass/include/cute/arch/mma_sm100_umma.hpp index f754e26..d7dfb71 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm100_umma.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm100_umma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -46,10 +46,8 @@ template +struct SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN A from TMEM can't be transposed"); + static_assert(b_major == UMMA::Major::K, "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN B from SMEM requires non-transpose"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_SS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_SS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + 
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_TF32_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS_SCALED without CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED"); +#endif + } +}; + template struct SM100_MMA_F16BF16_SS_SCALED { static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || - (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), - "SM100_MMA_F16BF16_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ - or a multiple of 16 between 16 and 256 for M=128."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -255,6 +334,51 @@ struct SM100_MMA_F16BF16_SS_SCALED } }; +template +struct SM100_MMA_TF32_TS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_TS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster 
MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_TF32_TS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_TS_SCALED without CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED"); +#endif + } +}; + template +struct SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN A from TMEM can't be transposed"); + static_assert(b_major == UMMA::Major::K, 
"SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN B from SMEM requires non-transpose"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS_SCALED N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 
p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS_SCALED without CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED"); +#endif + } +}; + template struct SM100_MMA_F16BF16_2x1SM_SS_SCALED { static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SCALED N-mode size should be a multiple of 16 between 16 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -591,6 +794,49 @@ struct SM100_MMA_F16BF16_2x1SM_SS_SCALED } }; +template +struct SM100_MMA_TF32_2x1SM_TS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_TS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_TS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t idescE) + { +#if defined(CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or 
compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_TS_SCALED without CUTE_ARCH_TCGEN05_UTFMMA_SCALED_ENABLED"); +#endif + } +}; + template +struct SM103_MXF4_ULTRA_SS_VS +{ + static_assert(M == 128, "MMA M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 
2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; + + +template +struct SM103_MXF4_ULTRA_2x1SM_SS_VS +{ + static_assert(M == 128 || M == 256, "MMA 
M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), 
"l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_2x1SM_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; +} // namespace SM103 + } // end namespace cute diff --git a/3rd/cutlass/include/cute/arch/mma_sm120.hpp b/3rd/cutlass/include/cute/arch/mma_sm120.hpp index 1433a2c..1468c83 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm120.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm120.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -3131,6 +3131,29 @@ struct SM120_16x8x64_TN_VS static constexpr uint16_t bidB = 0; CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); + if constexpr ( VS == 16 ) { +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE8M0_MMA_ENABLED"); +#endif + } else if constexpr ( VS == 32 ) { + #if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) asm volatile( 
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " @@ -3151,6 +3174,7 @@ struct SM120_16x8x64_TN_VS #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); #endif + } } }; diff --git a/3rd/cutlass/include/cute/arch/mma_sm120_sparse.hpp b/3rd/cutlass/include/cute/arch/mma_sm120_sparse.hpp index a695003..a95af10 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm120_sparse.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm120_sparse.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -3324,6 +3324,28 @@ struct SM120_SPARSE_16x8x128_TN_VS // CUTE_HOST_DEVICE #include // GMMA::Major, etc. +#include "cutlass/arch/synclog.hpp" + namespace cute { namespace SM90::GMMA::SPARSE { diff --git a/3rd/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp b/3rd/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp index 8945551..be8f9aa 100644 --- a/3rd/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp +++ b/3rd/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/simd_sm100.hpp b/3rd/cutlass/include/cute/arch/simd_sm100.hpp index 1c07a31..7c36611 100644 --- a/3rd/cutlass/include/cute/arch/simd_sm100.hpp +++ b/3rd/cutlass/include/cute/arch/simd_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/tmem_allocator_sm100.hpp b/3rd/cutlass/include/cute/arch/tmem_allocator_sm100.hpp index 680e237..3afdf80 100644 --- a/3rd/cutlass/include/cute/arch/tmem_allocator_sm100.hpp +++ b/3rd/cutlass/include/cute/arch/tmem_allocator_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/arch/util.hpp b/3rd/cutlass/include/cute/arch/util.hpp index 6a3883e..058813b 100644 --- a/3rd/cutlass/include/cute/arch/util.hpp +++ b/3rd/cutlass/include/cute/arch/util.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_atom.hpp b/3rd/cutlass/include/cute/atom/copy_atom.hpp index 5c455cc..e24cd78 100644 --- a/3rd/cutlass/include/cute/atom/copy_atom.hpp +++ b/3rd/cutlass/include/cute/atom/copy_atom.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -208,11 +208,11 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // (M,N,...) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) // where - // ThrV: The threads local to a COPY_ATOM Src. - // ThrX: The threads tiled across COPY_ATOMs Src. + // Thr: The logical threads within the tiled copy. // FrgV: The values local to a COPY_ATOM Src. + // FrgX: The values tiled across COPY_ATOMs Src. // RestM: The values tiled in M. // RestN: The values tiled in N. template @@ -229,11 +229,11 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // (M,N,...) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) // where - // ThrV: The threads local to a COPY_ATOM Dst. - // ThrX: The threads tiled across COPY_ATOMs Dst. + // Thr: The logical threads within the tiled copy. // FrgV: The values local to a COPY_ATOM Dst. + // FrgX: The values tiled across COPY_ATOMs Dst. // RestM: The values tiled in M. // RestN: The values tiled in N. 
template @@ -250,7 +250,7 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // ((TileM,TileN,...), (RestM,RestN,...)) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) template CUTE_HOST_DEVICE constexpr static auto @@ -325,21 +325,6 @@ struct TiledCopy : Copy_Atom return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutS_MN() - { - // (thr_idx,val_idx) -> (M,N) - auto layoutS_TV = get_layoutS_TV(); - // (M,K) -> (thr_idx,val_idx) - auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{})); - - // athrid = (v,m,k) -> thr_idx - auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); - - return cute::make_tuple(layoutS_MK, thrID_S); - } - CUTE_HOST_DEVICE constexpr static auto get_layoutD_TV() @@ -350,21 +335,6 @@ struct TiledCopy : Copy_Atom return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutD_MN() - { - // (thr_idx,val_idx) -> (M,N) - auto layoutD_TV = get_layoutD_TV(); - // (M,K) -> (thr_idx,val_idx) - auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{})); - - // athrid = (v,m,k) -> thr_idx - auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); - - return cute::make_tuple(layoutD_MK, thrID_D); - } - template ::value)> CUTE_HOST_DEVICE static @@ -436,7 +406,7 @@ template CUTE_HOST_DEVICE -auto +auto constexpr make_tiled_copy_impl(Copy_Atom const& atom, LayoutCopy_TV const&, Tiler const&) @@ -450,7 +420,7 @@ make_tiled_copy_impl(Copy_Atom const& atom, template CUTE_HOST_DEVICE -auto +auto constexpr make_tiled_copy_A(Copy_Atom const& copy_atom, TiledMMA const& mma) { @@ -459,7 +429,7 @@ make_tiled_copy_A(Copy_Atom const& copy_atom, template CUTE_HOST_DEVICE -auto +auto constexpr make_tiled_copy_B(Copy_Atom const& copy_atom, TiledMMA const& mma) { @@ -521,7 +491,7 @@ 
template > CUTE_HOST_DEVICE -auto +auto constexpr make_tiled_copy(Copy_Atom const& copy_atom, ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx ValLayout const& val_layout = {}) // (m,n) -> val_idx @@ -569,7 +539,8 @@ make_cotiled_copy(Copy_Atom const& copy_atom, auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); // Check validity - CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + // Append 1:0 to data_layout so that OOB coordinates get the stride-0 + CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); // // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them @@ -680,101 +651,6 @@ print(ThrCopy const& thr_copy) print(TiledCopy{}); } -// TiledCopy to LaTeX TikZ -template -CUTE_HOST_DEVICE -auto -print_latex(TiledCopy const& copy, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); - auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); - - print_latex_copy(layoutS_MN, thrID_S, - layoutD_MN, thrID_D); -} - -// MNK Copy Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); - - assert(size<0>(S) == size<0>(D)); - assert(size<1>(S) == size<1>(D)); - - // Commented prints - printf("%% LayoutS: "); print(S); printf("\n"); - printf("%% ThrIDS : "); print(TS); printf("\n"); - printf("%% LayoutD: "); print(D); printf("\n"); - printf("%% ThrIDD : "); 
print(TD); printf("\n\n"); - - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // S starting at 0,0 - for (int i = 0; i < size<0>(S); ++i) { - for (int j = 0; j < size<1>(S); ++j) { - int thrid = S(i,j) % size(TS); - int val_idx = S(i,j) / size(TS); - int thr_idx = TS(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - i, j, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(size<0>(S)), int(size<1>(S))); - // S Labels - for (int i = 0, j = -1; i < size<0>(S); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int i = -1, j = 0; j < size<1>(S); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // D starting at 0,size<1>(S)+3 - for (int i = 0; i < size<0>(D); ++i) { - for (int j = 0; j < size<1>(D); ++j) { - int thrid = D(i,j) % size(TD); - int val_idx = D(i,j) / size(TD); - int thr_idx = TD(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - i, j + size<1>(S) + 3, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3)); - // D Labels - for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); - } - for (int i = -1, j = 0; j < size<1>(D); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - } // end namespace cute 
//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cute/atom/copy_traits.hpp b/3rd/cutlass/include/cute/atom/copy_traits.hpp index 9117a1f..3f0b23d 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm100.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm100.hpp index 594149d..7996fa3 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm100.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,6 +43,36 @@ namespace cute { +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + template <> struct Copy_Traits { diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp index cd3bf98..aa21402 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp index 0212db1..f62971f 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2021 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -135,6 +135,13 @@ struct Copy_Traits uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -223,6 +230,13 @@ struct Copy_Traits uint16_t, // multicast mask uint64_t // cache hint > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; //////////////////////////////////// diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm50.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm50.hpp index 5299894..79ccfc4 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm50.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm50.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm75.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm75.hpp index 416938b..2f44422 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm75.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm75.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -140,4 +140,20 @@ struct Copy_Traits using RefLayout = DstLayout; }; +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; } // end namespace cute diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm80.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm80.hpp index ab8d128..8d614bb 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm80.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm80.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm90.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm90.hpp index ad479df..1a7fab9 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm90.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm90.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp index e4d1e3f..f720b87 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -471,6 +471,17 @@ make_im2col_tma_copy_desc( tma_l2Promotion, tma_oob_fill); + int driver_version = 0; + cudaError_t driver_version_err = cudaDriverGetVersion(&driver_version); + assert(driver_version_err == cudaSuccess); + if (driver_version <= 13010) { + if (cute::bits_to_bytes( + cute::cosize(tensor_cwhdn.layout()) * + cute::sizeof_bits::value) < 131072) { + reinterpret_cast(&tma_desc)[1] &= ~(1llu << 21); + } + } + // The extra asserts help indicate the error's cause. assert(encode_result != CUDA_ERROR_DEINITIALIZED); assert(encode_result != CUDA_ERROR_NOT_INITIALIZED); diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp index 209a844..78ac598 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -51,8 +51,12 @@ template struct AuxTmaParams { using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic GmemStrides g_stride_; - using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static - static_assert(is_static::value); + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s) + // By default, TmaGmemBasis produced by construct_tma_gbasis is fully static. + // The user may construct a dynamic gbasis manually (e.g. to represent smem box with dynamic shape). + // In that case they will need to pass it around via other means. + // We avoid passing it as a data member to avoid ABI impact. + // static_assert(is_static::value); using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle static_assert(is_static::value); }; @@ -70,7 +74,7 @@ struct TMA_LOAD_Unpack { static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); - auto src_coord = src.data().coord_; + auto src_coord = src(Int<0>{}); void* dst_ptr = cute::raw_pointer_cast(dst.data()); #if 0 auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); @@ -156,6 +160,13 @@ struct Copy_Traits copy_unpack(Copy_Traits const& traits, Tensor const& src, Tensor & dst) = delete; + + // Construct with updated TMA descriptor only (no barrier change) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {*new_tma_desc, aux_params_}; + } }; // The executable SM90_TMA_LOAD with tma_desc and tma_mbar @@ -181,6 +192,13 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) : opargs_(desc, mbar, cache) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; // The prefetch for 
SM90_TMA_LOAD with tma_desc @@ -199,10 +217,22 @@ struct Copy_Traits tuple const opargs_; // Construct with any other Traits' TMA Desc - template + template + CUTE_HOST_DEVICE + Copy_Traits(OtherTraits const& traits) + : opargs_({traits.get_tma_descriptor()}) {} + + // Construct directly with a TMA descriptor pointer CUTE_HOST_DEVICE - Copy_Traits(Copy_Traits const& traits) - : opargs_({&traits.tma_desc_}) {} + Copy_Traits(TmaDescriptor const* desc) + : opargs_({desc}) {} + + // Build a new Prefetch traits with a different TMA descriptor pointer + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } template @@ -211,7 +241,7 @@ struct Copy_Traits Tensor const& src, Tensor & dst) { - auto src_coord = src.data().coord_; + auto src_coord = src(Int<0>{}); return detail::explode_tuple(detail::CallCOPY{}, traits.opargs_, tuple_seq{}, src_coord, tuple_seq{}); @@ -312,6 +342,13 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) : opargs_(desc, mbar, mask, hint) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -372,7 +409,7 @@ struct Copy_Traits void const* const desc_ptr = &(traits.tma_desc_); void const* const src_ptr = cute::raw_pointer_cast(src.data()); - auto dst_coord = dst.data().coord_; + auto dst_coord = dst(Int<0>{}); #if 0 auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", @@ -413,7 +450,7 @@ struct Copy_Traits void const* const desc_ptr = traits.tma_desc_; void const* const src_ptr = cute::raw_pointer_cast(src.data()); - auto dst_coord = dst.data().coord_; + auto dst_coord = dst(Int<0>{}); #if 0 auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); printf("THR 
(%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", @@ -498,7 +535,8 @@ struct Copy_Traits static_assert(is_smem::value, "Expected smem src for SM90_TMA_REDUCE_ADD"); //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor - traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + auto dst_coord = dst(Int<0>{}); + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst_coord, tuple_seq{}); } }; @@ -1018,6 +1056,17 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin smem_swizzle, tma_l2Promotion, tma_oobFill); + + int driver_version = 0; + cudaError_t driver_version_err = cudaDriverGetVersion(&driver_version); + assert(driver_version_err == cudaSuccess); + if (driver_version <= 13010) { + if (cute::bits_to_bytes( + cute::cosize(gtensor.layout()) * + cute::sizeof_bits::value) < 131072) { + reinterpret_cast(&tma_desc)[1] &= ~(1llu << 21); + } + } if (result != CUDA_SUCCESS) { std::cerr << "TMA Desc Addr: " << &tma_desc diff --git a/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index 9a44789..d938393 100644 --- a/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/3rd/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -48,13 +48,11 @@ TMA::SmemSwizzleBits get_tma_swizzle_bits(Swizzle) { if constexpr (M == 4) { - switch (B) { - default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. 
Unsupported layout swizzle."); - case 3: return TMA::SmemSwizzleBits::B128; - case 2: return TMA::SmemSwizzleBits::B64; - case 1: return TMA::SmemSwizzleBits::B32; - case 0: return TMA::SmemSwizzleBits::DISABLE; - } + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + if constexpr (B == 3) { return TMA::SmemSwizzleBits::B128; } + if constexpr (B == 2) { return TMA::SmemSwizzleBits::B64; } + if constexpr (B == 1) { return TMA::SmemSwizzleBits::B32; } + if constexpr (B == 0) { return TMA::SmemSwizzleBits::DISABLE; } } else if constexpr (M == 5 || M == 6) { diff --git a/3rd/cutlass/include/cute/atom/mma_atom.hpp b/3rd/cutlass/include/cute/atom/mma_atom.hpp index e2f9bdf..a66173d 100644 --- a/3rd/cutlass/include/cute/atom/mma_atom.hpp +++ b/3rd/cutlass/include/cute/atom/mma_atom.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -180,10 +180,10 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeB is a view (of the current tensor), forward the whole static_assert(is_same::value_type>::value - + || (sizeof_bits_v::value_type> == 8 && (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) - + , "Expecting ValTypeB type"); return make_tensor(static_cast(btensor)); } else { @@ -262,7 +262,7 @@ struct TiledMMA : MMA_Atom make_layout(size<1>(AtomShape_MNK{}))); auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) - // Transform the Atom mode from (M,K) to (Thr,Val) + // Transform the Atom mode from (M,N) to (Thr,Val) auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) // Tile the tensor for the C-threads @@ -340,7 +340,7 @@ struct TiledMMA : MMA_Atom make_layout(size<2>(AtomShape_MNK{}))); auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (M,K) to (Thr,Val) + // Transform the Atom mode from (N,K) to (Thr,Val) auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread @@ -394,55 +394,22 @@ struct TiledMMA : MMA_Atom return size(permutation_mnk()); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutC_MN() const - { - // (M,N) -> (M,N) - auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); - // (cthrid,val) -> (M,N) - auto layoutC_TV = thrfrg_C(ref_C); - // (M,N) -> (cthrid,frg) - auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); - - // cthrid = (v,m,n) -> thr_idx - auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{}); - - return cute::make_tuple(layoutC_MN, thrID_C); - } - CUTE_HOST_DEVICE constexpr auto get_layoutC_TV() const { // (M,N) -> (M,N) auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), 
tile_size_mnk<1>())); - // (cthrid,val) -> (M,N) - auto layoutC_TV = thrfrg_C(ref_C); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (M,N) - return layoutC_TV.compose(thridx_2_thrid, _); + return thrfrg_C(ref_C).compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutA_MK() const - { - // (M,K) -> (M,K) - auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); - // (athrid,val) -> (M,K) - auto layoutA_TV = thrfrg_A(ref_A); - // (M,K) -> (athrid,frg) - auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); - - // athrid = (v,m,k) -> thr_idx - auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_); - - return cute::make_tuple(layoutA_MK, thrID_A); - } CUTE_HOST_DEVICE constexpr auto @@ -458,29 +425,14 @@ struct TiledMMA : MMA_Atom _)); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (M,K) return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutB_NK() const - { - // (N,K) -> (N,K) - auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); - // (bthrid,val) -> (N,K) - auto layoutB_TV = thrfrg_B(ref_B); - // (N,K) -> (bthrid,frg) - auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); - - // bthrid = (v,n,k) -> thr_idx - auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_); - - return cute::make_tuple(layoutB_NK, thrID_B); - } - CUTE_HOST_DEVICE constexpr auto get_layoutB_TV() const @@ 
-495,7 +447,9 @@ struct TiledMMA : MMA_Atom _)); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (N,K) return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); @@ -733,376 +687,6 @@ print(ThrMMA const& thr_mma) print(static_cast(thr_mma)); } -// MMA Atom to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(MMA_Atom const& mma_atom, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - print_latex(make_tiled_mma(mma_atom)); -} - -// TiledMMA to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(TiledMMA const& mma, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - auto layout_and_thrid_C = mma.get_layoutC_MN(); - auto layoutC_MN = get<0>(layout_and_thrid_C); - auto thrID_C = get<1>(layout_and_thrid_C); - - auto layout_and_thrid_A = mma.get_layoutA_MK(); - auto layoutA_MK = get<0>(layout_and_thrid_A); - auto thrID_A = get<1>(layout_and_thrid_A); - - auto layout_and_thrid_B = mma.get_layoutB_NK(); - auto layoutB_NK = get<0>(layout_and_thrid_B); - auto thrID_B = get<1>(layout_and_thrid_B); - - print_latex_mma(layoutC_MN, thrID_C, - layoutA_MK, thrID_A, - layoutB_NK, thrID_B); -} - -// MNK MMA Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - 
assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - // Commented prints - printf("%% LayoutC: "); print(C); printf("\n"); - printf("%% ThrIDC : "); print(TC); printf("\n"); - printf("%% LayoutA: "); print(A); printf("\n"); - printf("%% ThrIDA : "); print(TA); printf("\n"); - printf("%% LayoutB: "); print(B); printf("\n"); - printf("%% ThrIDB : "); print(TB); printf("\n\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // C starting at 0,0 - for (int m = 0; m < size<0>(C); ++m) { - for (int n = 0; n < size<1>(C); ++n) { - int thrid = C(m,n) % size(TC); - int val_idx = C(m,n) / size(TC); - int thr_idx = TC(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - m, n, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(size<0>(C)), int(size<1>(C))); - - // A starting at 0,-size<1>(A)-1 - for (int m = 0; m < size<0>(A); ++m) { - for (int k = 0; k < size<1>(A); ++k) { - int thrid = A(m,k) % size(TA); - int val_idx = A(m,k) / size(TA); - int thr_idx = TA(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - m, k-1-size<1>(A), - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, int(-size<1>(A)-1), int(size<0>(A)), -1); - // A labels - for (int m = 0, k = -1; m < size<0>(A); ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); - } - for (int m = -1, k = 0; k < size<1>(A); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); - } - - // B starting at -size<1>(B)-1,0 - for (int n = 0; n < size<0>(B); ++n) { 
- for (int k = 0; k < size<1>(B); ++k) { - int thrid = B(n,k) % size(TB); - int val_idx = B(n,k) / size(TB); - int thr_idx = TB(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - k-1-size<1>(B), n, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - int(-size<1>(B)-1), 0, -1, int(size<0>(B))); - // B labels - for (int n = 0, k = -1; n < size<0>(B); ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); - } - for (int n = -1, k = 0; k < size<1>(B); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -// MNK MMA Layout to console printer -template -CUTE_HOST_DEVICE -void -print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - int a_width = size<1>(A) * 6 + 4; - - // Print out B (white-shifted) k-by-n - for (int k = 0; k < size<1>(B); ++k) { - // Header - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n"); - // Values - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); - printf("|\n"); - } - // Footer - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n\n"); - - // Print out A m-by-k and C m-by-n - for (int m = 0; m < size<0>(A); ++m) { - // Header - for (int k = 0; k < 
size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); - // Values - for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); - printf("| "); - for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); - printf("|\n"); - } - // Footer - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); -} - -// MNK MMA Layout to SVG -- 8-value color coded by thread -template -CUTE_HOST_DEVICE -void -print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175", - "255,175,175", "210,210,255", "210,255,210", - "255,255,210", "255,210,210"}; - - const int cell_width = 20; - const int cell_height = 20; - - const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width; - const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height; - - // header - printf("\n", - page_width, page_height); - - // C - int c_base_x = (size<1>(A) + 2) * cell_width; - int c_base_y = (size<1>(B) + 2) * cell_height; - for (int m = 0; m < cute::size<0>(C); ++m) { - for (int n = 0; n < cute::size<1>(C); ++n) { - - int thrid = C(m, n) % size(TC); - int val_idx = C(m, n) / size(TC); - int thr_idx = TC(thrid); - - int x = n * cell_width + c_base_x; - int y = m * cell_height + c_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, 
val_idx); - } - } - - // A - int a_base_x = cell_width; - int a_base_y = (size<1>(B) + 2) * cell_height; - for (int m = 0; m < size<0>(A); ++m) { - for (int k = 0; k < size<1>(A); ++k) { - int thrid = A(m, k) % size(TA); - int val_idx = A(m, k) / size(TA); - int thr_idx = TA(thrid); - - int x = k * cell_width + a_base_x; - int y = m * cell_height + a_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, val_idx); - } - } - - // B - int b_base_x = (size<1>(A) + 2) * cell_width; - int b_base_y = cell_height; - for (int n = 0; n < size<0>(B); ++n) { - for (int k = 0; k < size<1>(B); ++k) { - int thrid = B(n, k) % size(TB); - int val_idx = B(n, k) / size(TB); - int thr_idx = TB(thrid); - - int x = n * cell_width + b_base_x; - int y = k * cell_height + b_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, val_idx); - } - } - - // A labels - for (int m = 0; m < size<0>(A); ++m) { - int x = cell_width / 2; - int y = m * cell_height + cell_height / 2 + a_base_y; - printf("%d\n", - x, y, m); - } - for (int k = 0; k < size<1>(A); ++k) { - int x = cell_width + k * cell_width + cell_width / 2; - int y = -cell_height / 2 + a_base_y; - printf("%d\n", - x, y, k); - } - - // B labels - for (int n = 0; n < size<0>(B); ++n) { - int x = b_base_x + cell_width * n + cell_width / 2; - int y = cell_height / 2; - printf("%d\n", - x, y, n); - } - for (int k = 0; k < size<1>(B); ++k) { - int x = b_base_x - cell_width / 2; - int y = cell_height * (k + 1) + cell_height / 2; - printf("%d\n", - x, y, 
k); - } - - // footer - printf("\n"); -} - -template -CUTE_HOST_DEVICE -void -print_svg(MMA_Atom const &mma_atom) { - print_svg(make_tiled_mma(mma_atom)); -} - -template -CUTE_HOST_DEVICE -void -print_svg(TiledMMA const &mma) { - auto layout_and_thrid_C = mma.get_layoutC_MN(); - auto layoutC_MN = get<0>(layout_and_thrid_C); - auto thrID_C = get<1>(layout_and_thrid_C); - - auto layout_and_thrid_A = mma.get_layoutA_MK(); - auto layoutA_MK = get<0>(layout_and_thrid_A); - auto thrID_A = get<1>(layout_and_thrid_A); - - auto layout_and_thrid_B = mma.get_layoutB_NK(); - auto layoutB_NK = get<0>(layout_and_thrid_B); - auto thrID_B = get<1>(layout_and_thrid_B); - - print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B); -} - } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1114,7 +698,7 @@ print_svg(TiledMMA const &mma) { #include #include #include -#include +#include #include #include diff --git a/3rd/cutlass/include/cute/atom/mma_traits.hpp b/3rd/cutlass/include/cute/atom/mma_traits.hpp index de24b64..291d914 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -41,10 +41,10 @@ namespace cute /** * concept MMA_Traits * { - * using ValTypeD = // Logical A-value type - * using ValTypeA = // Logical B-value type - * using ValTypeB = // Logical C-value type - * using ValTypeC = // Logical D-value type (NOTE: Not used? 
Assumed == ValTypeD) + * using ValTypeD = // Logical D-value type + * using ValTypeA = // Logical A-value type + * using ValTypeB = // Logical B-value type + * using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD) * * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm100.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm100.hpp index 820dc10..c9b42a8 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm100.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm100.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2022 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -1353,6 +1353,186 @@ struct MMA_Traits scaleC) const { return {accumulate, idesc_}; } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } +}; + +// Special instantiation for interleaved complex (emulated) +template +struct MMA_Traits, cutlass::complex, float, + M, N, a_major, b_major, + ScaleC, a_neg, b_neg>> +{ + static_assert(cute::sizeof_bits_v == 16, "Only supports 16bit base types"); + using a_type = complex; + using b_type = complex; + using c_type = float; + + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr 
uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + ab_vtype, ab_vtype, float, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_TS_SCALED supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + 
+ // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } }; template scaleC) const { return {accumulate, idesc_}; } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } }; +// Special instantiation for 
interleaved complex (emulated) +template +struct MMA_Traits, cutlass::complex, float, + M, N, a_major, b_major, + ScaleC, a_neg, b_neg, c_sat>> +{ + static_assert(cute::sizeof_bits_v == 16, "Only supports 16bit base types"); + using a_type = complex; + using b_type = complex; + using c_type = float; + + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + ab_vtype, ab_vtype, float, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } +}; + + template +struct MMA_Traits> +{ + using a_type = complex; + using b_type = complex; + using c_type = float; + + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + 
Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + // MMA based interleaved complex GEMM calculates realAcc and imagAcc separately. + // This MMA_traits is used to calculate 1 of the GEMMs below : + // 1. realAcc = realA * realB + (-imagA) * imagB + // 2. imagAcc = imagA * realB + realA * imagB + // So it requires complex type operand A&B and float type operand Acc. + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN A from TMEM can't be transposed"); + static_assert(b_major == UMMA::Major::K, "SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN B from SMEM requires non-transpose"); + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + tfloat32_t, tfloat32_t, float, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN< + M, N, + a_major, b_major, + a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + template == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types"); + + using FrgTypeA = 
UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + 
constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } +}; + + +// Special instantiation for interleaved complex (emulated) +template +struct MMA_Traits, cutlass::complex, float, + M, N, a_major, b_major, + ScaleC, a_neg, b_neg>> +{ + static_assert(cute::sizeof_bits_v == 16, "Only supports 16bit base types"); + using a_type = complex; + using b_type = complex; + using c_type = float; + + 
using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + ab_vtype, ab_vtype, float, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } +}; + +template +struct 
MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_TS_SCALED supports 32bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -1822,6 +2408,7 @@ struct MMA_Traits::value; + constexpr static uint32_t ScalingFactor = ScaleC; using Shape_MNK = Shape,Int,Int>; using ThrID = Layout<_2>; @@ -1860,31 +2447,49 @@ struct MMA_Traits(traits.idesc_); - SM100_MMA_F16BF16_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_TF32_2x1SM_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; } }; + template -struct MMA_Traits +struct MMA_Traits> + ScaleC, a_neg, b_neg, c_sat>> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types"); - using FrgTypeA = UMMA::smem_desc; + using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; - // Size of instructions's K extent is always 256bits, convert to units of element + // Size of instructions' K extent is always 256 bits; convert to units of element constexpr static int K = 256 / cute::sizeof_bits::value; constexpr static uint32_t ScalingFactor = ScaleC; @@ -1901,7 +2506,7 @@ 
struct MMA_Traits(); + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); template const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - uint64_t desc_a = A[0]; + uint64_t tmem_a = raw_pointer_cast(A.data()); uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + ScaleC, a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } template CUTE_HOST_DEVICE constexpr - MMA_Traits> + MMA_Traits> with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { return {accumulate, idesc_}; } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } }; -template -struct MMA_Traits, cutlass::complex, float, M, N, a_major, b_major, ScaleC, a_neg, b_neg, c_sat>> { + static_assert(cute::sizeof_bits_v == 16, "Only supports 16bit base types"); + using a_type = complex; + using b_type = complex; + using c_type = float; + using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -1974,7 +2593,7 @@ struct MMA_Traits(); + ab_vtype, ab_vtype, float, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); template scaleC) const { return {accumulate, idesc_}; } + + template + 
CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(cute::integral_constant) const { + return {accumulate_, UMMA::make_instr_desc()}; + } }; + template +struct MMA_Traits> +{ + using a_type = complex; + using b_type = complex; + using c_type = float; + + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + // Interleaved complex GEMM calculates realAcc and imagAcc separately. + // This MMA_traits is used to calculate 1 of the GEMMs below : + // 1. realAcc = realA * realB + (-imagA) * imagB + // 2. imagAcc = imagA * realB + realA * imagB + // So it requires complex type operand A&B and float type operand Acc. 
+ static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN A from TMEM can't be transposed"); + static_assert(b_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN B from SMEM requires non-transpose"); + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + tfloat32_t, tfloat32_t, float, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN< + M, N, + a_major, b_major, + a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + template @@ -2639,10 +3344,10 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types"); static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || - (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), - "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ - or a multiple of 16 between 16 and 256 for M=128."); + static_assert(((b_major == UMMA::Major::K) && ((N % 8 == 0) && (8 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 16 == 0) && (16 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_SS N-mode size 
should be a multiple of 8 between 8 and 256 when B is K major. \ + SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 16 between 16 and 256 when B is MN major."); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_1sm; @@ -3051,14 +3756,16 @@ struct MMA_Traits, cute::integral_constant> { - using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(((b_major == UMMA::Major::K) && ((N % 16 == 0) && (16 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 32 == 0) && (32 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256 when B is K major. 
\ + SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256 when B is MN major."); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -3844,4 +4551,187 @@ struct MMA_Traits +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_1,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_2,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +namespace SM103 { + // Common mma_unpack for all MMA_Ops in cute::SM103 +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& zA, + Tensor const& zB, + Tensor const& C) + { + auto [A, next_A, SFA] = unzip_tensor(zA); + auto [B, next_B, SFB] = unzip_tensor(zB); + + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_next_a = next_A[0]; + uint64_t desc_b = B[0]; + uint64_t desc_next_b = next_B[0]; + + auto desc_a_temp = reinterpret_cast(desc_a); + auto desc_next_a_temp = reinterpret_cast(desc_next_a); + desc_a_temp.lbo_mode_ = 1; + desc_a_temp.leading_byte_offset_ = desc_next_a_temp.start_address_; + + auto desc_b_temp = reinterpret_cast(desc_b); + auto desc_next_b_temp = reinterpret_cast(desc_next_b); + desc_b_temp.lbo_mode_ = 1; + desc_b_temp.leading_byte_offset_ = desc_next_b_temp.start_address_; + + uint32_t tmem_c = raw_pointer_cast(D.data()); + 
UMMA::InstrDescriptorBlockScaled instr_desc = traits.idesc_; + instr_desc.k_size_ = 1; + auto tsfa_addr = raw_pointer_cast(SFA.data()); + auto tsfb_addr = raw_pointer_cast(SFB.data()); + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(instr_desc, tsfa_addr, tsfb_addr); + // print("a: "); print(A); print("\n"); + // print("b: "); print(B); print("\n"); + + MMA_Op::fma(reinterpret_cast(desc_a_temp), reinterpret_cast(desc_b_temp), tmem_c, uint32_t(traits.accumulate_), idesc, tsfa_addr, tsfb_addr); + } +} // end namespace SM103 + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + } // end namespace cute diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm120.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm120.hpp index e339980..3365f18 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm120.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm120.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp index b62c576..7e81c09 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm61.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm61.hpp index 6b12903..7ea2119 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm61.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm61.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm70.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm70.hpp index 0b5b530..91f2058 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm70.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm70.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm75.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm75.hpp index d60c65f..a4e0deb 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm75.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm75.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm80.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm80.hpp index f7d5d2f..357921e 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm80.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm80.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm89.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm89.hpp index 35ad436..50e6e33 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm89.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm89.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -67,8 +67,8 @@ struct MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; @@ -76,8 +76,8 @@ MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e5m2_t; using ValTypeB = float_e5m2_t; @@ -85,12 +85,48 @@ MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e5m2_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; }; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + } // end namespace cute diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm90.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm90.hpp index 0467dec..e2667ff 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm90.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm90.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - 
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp index e1c3bb4..4943214 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp index 3cab34d..bb4aad0 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp index 13ff07c..ba6981b 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp index fc28b8a..265e5cf 100644 --- a/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp +++ b/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/atom/partitioner.hpp b/3rd/cutlass/include/cute/atom/partitioner.hpp index 75a55cc..8aedb6a 100644 --- a/3rd/cutlass/include/cute/atom/partitioner.hpp +++ b/3rd/cutlass/include/cute/atom/partitioner.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,8 +31,9 @@ #pragma once +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/3rd/cutlass/include/cute/config.hpp b/3rd/cutlass/include/cute/config.hpp index 538472c..fc5b116 100644 --- a/3rd/cutlass/include/cute/config.hpp +++ b/3rd/cutlass/include/cute/config.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/container/alignment.hpp b/3rd/cutlass/include/cute/container/alignment.hpp index f285004..4497185 100644 --- a/3rd/cutlass/include/cute/container/alignment.hpp +++ b/3rd/cutlass/include/cute/container/alignment.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/container/array.hpp b/3rd/cutlass/include/cute/container/array.hpp index a431fc4..0e57859 100644 --- a/3rd/cutlass/include/cute/container/array.hpp +++ b/3rd/cutlass/include/cute/container/array.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -391,9 +391,9 @@ cute::array reverse(cute::array const& t) // // Specialize tuple-related functionality for cute::array // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif @@ -447,12 +447,18 @@ struct tuple_element> namespace std { +#if (__CUDACC_VER_MAJOR__ >= 13) + +#include + +#else #if defined(__CUDACC_RTC__) -template -struct tuple_size; + template + struct tuple_size; -template -struct tuple_element; + template + struct tuple_element; +#endif #endif template diff --git a/3rd/cutlass/include/cute/container/array_aligned.hpp b/3rd/cutlass/include/cute/container/array_aligned.hpp index 6491f72..031b960 100644 --- a/3rd/cutlass/include/cute/container/array_aligned.hpp +++ b/3rd/cutlass/include/cute/container/array_aligned.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/container/array_subbyte.hpp b/3rd/cutlass/include/cute/container/array_subbyte.hpp index 38da7ac..2e63471 100644 --- a/3rd/cutlass/include/cute/container/array_subbyte.hpp +++ b/3rd/cutlass/include/cute/container/array_subbyte.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -555,9 +555,9 @@ void fill(array_subbyte& a, T const& value) // // Specialize tuple-related functionality for cute::array_subbyte // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif @@ -617,12 +617,18 @@ struct tuple_element> namespace std { +#if (__CUDACC_VER_MAJOR__ >= 13) + +#include + +#else #if defined(__CUDACC_RTC__) -template -struct tuple_size; + template + struct tuple_size; -template -struct tuple_element; + template + struct tuple_element; +#endif #endif template diff --git a/3rd/cutlass/include/cute/container/bit_field.hpp b/3rd/cutlass/include/cute/container/bit_field.hpp index cecdaee..5dd1c03 100644 --- a/3rd/cutlass/include/cute/container/bit_field.hpp +++ b/3rd/cutlass/include/cute/container/bit_field.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/container/cuda_types.hpp b/3rd/cutlass/include/cute/container/cuda_types.hpp index 5615fde..41170ec 100644 --- a/3rd/cutlass/include/cute/container/cuda_types.hpp +++ b/3rd/cutlass/include/cute/container/cuda_types.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/container/tuple.hpp b/3rd/cutlass/include/cute/container/tuple.hpp index e3dd6d2..f7cfca9 100644 --- a/3rd/cutlass/include/cute/container/tuple.hpp +++ b/3rd/cutlass/include/cute/container/tuple.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -701,12 +701,18 @@ struct tuple_element> namespace std { +#if (__CUDACC_VER_MAJOR__ >= 13) + +#include + +#else #if defined(__CUDACC_RTC__) -template -struct tuple_size; + template + struct tuple_size; -template -struct tuple_element; + template + struct tuple_element; +#endif #endif template diff --git a/3rd/cutlass/include/cute/container/type_list.hpp b/3rd/cutlass/include/cute/container/type_list.hpp index dfffbe2..62ebf9d 100644 --- a/3rd/cutlass/include/cute/container/type_list.hpp +++ b/3rd/cutlass/include/cute/container/type_list.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -77,9 +77,9 @@ find(type_list const&) noexcept { // // Specialize tuple-related functionality for cute::type_list // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif @@ -103,12 +103,18 @@ struct tuple_element> namespace std { +#if (__CUDACC_VER_MAJOR__ >= 13) + +#include + +#else #if defined(__CUDACC_RTC__) -template -struct tuple_size; + template + struct tuple_size; -template -struct tuple_element; + template + struct tuple_element; +#endif #endif template diff --git a/3rd/cutlass/include/cute/int_tuple.hpp b/3rd/cutlass/include/cute/int_tuple.hpp index f9d7000..02d1fd8 100644 --- a/3rd/cutlass/include/cute/int_tuple.hpp +++ b/3rd/cutlass/include/cute/int_tuple.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/layout.hpp b/3rd/cutlass/include/cute/layout.hpp index 3f02a41..6b8b102 100644 --- a/3rd/cutlass/include/cute/layout.hpp +++ b/3rd/cutlass/include/cute/layout.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -834,7 +834,7 @@ coalesce_x(Layout const& layout) } else { return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); } - + CUTE_GCC_UNREACHABLE; } @@ -1474,49 +1474,33 @@ domain_distribute(ShapeA const& a, ShapeB const& b) // Kernel (Nullspace) of a Layout // -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -nullspace_seq(Stride const& stride, seq) -{ - if constexpr (NextI == rank_v) { - return seq{}; - } else - if constexpr (is_constant<0, decltype(get(stride))>::value) { - return detail::nullspace_seq(stride, seq{}); - } else { - return detail::nullspace_seq(stride, seq{}); - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - -// -// Build the nullspace of a layout -// @result A layout @a result such that -// size(@a result) == size(@a layout) / size(filter(@a layout)) -// @a layout(@a result(i)) == 0 for all i < size(@a result) -// - +/** Return a layout that represents the nullspace of @a layout + * @post @a layout(@a result(i)) == 0 for all i < size(@a result) + * @post nullspace(@a result) == Layout<_1,_0>{} + * @post size(@a result) == size(@a layout) / size(filter(@a layout)) + */ template 
CUTE_HOST_DEVICE constexpr auto nullspace(Layout const& layout) { - auto flat_layout = flatten(layout); + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); - [[maybe_unused]] auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + // Select all indices corresponding to stride-0s + auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } + CUTE_GCC_UNREACHABLE; + }); - if constexpr (iseq.size() == 0) { + if constexpr (tuple_size::value == 0) { return Layout<_1,_0>{}; // Empty case, nothing found } else { // Generate the corresponding new strides and construct - auto rstride = compact_major(flat_layout.shape()); - return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + auto flat_shape = flatten(layout.shape()); + auto rstride = compact_major(flat_shape); + return make_layout(unwrap(transform(iseq, [&](auto i) { return get(flat_shape); })), unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); } @@ -1944,185 +1928,4 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& } #endif -// Generic 2D Layout to console table -template -CUTE_HOST_DEVICE -void -print_layout(Layout const& layout) // (m,n) -> idx -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - int idx_width = num_digits(cosize(layout)) + 2; - const char* delim = "+-----------------------"; - - print(layout); print("\n"); - - // Column indices - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } - printf("\n"); - - // Print out A m-by-n - for (int m = 0; m < size<0>(layout); ++m) { - // Header - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); - // Values - printf("%2d ", m); // Row indices - for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); 
} - printf("|\n"); - } - // Footer - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); -} - -// Generic ThrVal 2D Layout to console table -template -CUTE_HOST_DEVICE -void -print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - print(layout); print("\n"); - print(thrid); print("\n"); - - // Print out m-by-n - for (int m = 0; m < size<0>(layout); ++m) { - // Header - for (int n = 0; n < size<1>(layout); ++n) printf("+------"); - printf("+\n"); - // Values - for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); - printf("|\n"); - } - // Footer - for (int n = 0; n < size<1>(layout); ++n) printf("+------"); - printf("+\n"); -} - -struct TikzColor_White { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - return "white"; - } -}; - -struct TikzColor_BWx8 { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", - "black!10", "black!50", "black!30", "black!70"}; - return color_map[idx % 8]; - } -}; - -struct TikzColor_TV { - CUTE_HOST_DEVICE char const* - operator()(int tid, int vid) const { - static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - return color_map[tid % 8]; - } -}; - -// Generic 2D Layout to LaTeX printer -template -CUTE_HOST_DEVICE -void -print_latex(LayoutA const& layout_a, // (m,n) -> idx - TikzColorFn color = {}) // lambda(idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); 
- auto layout = append<2>(layout_a, Layout<_1,_0>{}); - - // Commented print(layout) - printf("%% Layout: "); print(layout); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // Layout - for (int i = 0; i < size<0>(layout); ++i) { - for (int j = 0; j < size<1>(layout); ++j) { - int idx = layout(i,j); - printf("\\node[fill=%s] at (%d,%d) {%d};\n", - color(idx), i, j, idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", - int(size<0>(layout)), int(size<1>(layout))); - // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int i = -1, j = 0; j < size<1>(layout); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -// Generic ThrVal 2D Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(Layout const& layout, // (m,n) -> (tid,vid) - ThrID const& thr, // tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - // Commented prints - printf("%% Layout: "); print(layout); printf("\n"); - printf("%% ThrID : "); print(thr); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // Layout - for (int i = 0; i < size<0>(layout); ++i) { - for (int j = 0; j < size<1>(layout); ++j) { - int thrid = layout(i,j) % size(thr); - int val_idx = layout(i,j) / size(thr); - int thr_idx = thr(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - 
color(thr_idx, val_idx), - i, j, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", - int(size<0>(layout)), int(size<1>(layout))); - // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int j = 0, i = -1; j < size<1>(layout); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - } // end namespace cute diff --git a/3rd/cutlass/include/cute/layout_composed.hpp b/3rd/cutlass/include/cute/layout_composed.hpp index 6a96778..3a34c1e 100644 --- a/3rd/cutlass/include/cute/layout_composed.hpp +++ b/3rd/cutlass/include/cute/layout_composed.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/numeric/arithmetic_tuple.hpp b/3rd/cutlass/include/cute/numeric/arithmetic_tuple.hpp index 60a4ff4..01dde5f 100644 --- a/3rd/cutlass/include/cute/numeric/arithmetic_tuple.hpp +++ b/3rd/cutlass/include/cute/numeric/arithmetic_tuple.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,22 +42,15 @@ namespace cute { template -struct ArithmeticTuple : tuple -{ - template +struct ArithmeticTuple : public tuple { CUTE_HOST_DEVICE constexpr - ArithmeticTuple(ArithmeticTuple const& u) - : tuple(static_cast const&>(u)) {} + ArithmeticTuple() : tuple() {} - template CUTE_HOST_DEVICE constexpr - ArithmeticTuple(tuple const& u) - : tuple(u) {} + ArithmeticTuple(tuple const& t) : tuple(t) {} - template CUTE_HOST_DEVICE constexpr - ArithmeticTuple(U const&... u) - : tuple(u...) {} + ArithmeticTuple(T const&... t) : tuple(t...) {} }; template @@ -147,12 +140,12 @@ operator-(ArithmeticTuple const& t) { } // -// Special cases +// Special cases for C<0> // template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator+(C, ArithmeticTuple const& u) { static_assert(t == 0, "Arithmetic tuple op+ error!"); return u; @@ -160,7 +153,7 @@ operator+(C, ArithmeticTuple const& u) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator+(ArithmeticTuple const& t, C) { static_assert(u == 0, "Arithmetic tuple op+ error!"); return t; @@ -168,7 +161,7 @@ operator+(ArithmeticTuple const& t, C) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator-(C, ArithmeticTuple const& u) { static_assert(t == 0, "Arithmetic tuple op- error!"); return -u; @@ -176,7 +169,7 @@ operator-(C, ArithmeticTuple const& u) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator-(ArithmeticTuple const& t, C) { static_assert(u == 0, "Arithmetic tuple op- error!"); return t; @@ -212,27 +205,20 @@ struct ArithmeticTupleIterator } }; -template -CUTE_HOST_DEVICE constexpr -auto -make_inttuple_iter(Tuple const& t) { - return ArithmeticTupleIterator(as_arithmetic_tuple(t)); -} - -template +template CUTE_HOST_DEVICE constexpr auto -make_inttuple_iter(T0 const& t0, T1 const& 
t1, Ts const&... ts) { - return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); +make_inttuple_iter(Ts const&... ts) { + return ArithmeticTupleIterator(as_arithmetic_tuple(ts...)); } // // ArithmeticTuple "basis" elements -// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: // (_0,_0,...,T,_0,...) // with value T in the Nth mode -template +template struct ScaledBasis : private tuple { CUTE_HOST_DEVICE constexpr @@ -243,40 +229,61 @@ struct ScaledBasis : private tuple CUTE_HOST_DEVICE constexpr decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + // Deprecated: Get the first hierarchical mode in this basis. CUTE_HOST_DEVICE static constexpr - auto mode() { return Int{}; } + auto mode() { return get<0>(int_sequence{}); } }; +// Ensure flat representation +template +struct ScaledBasis, Ns...> : ScaledBasis {}; + template struct is_scaled_basis : false_type {}; -template -struct is_scaled_basis> : true_type {}; +template +struct is_scaled_basis> : true_type {}; -template -struct is_integral> : true_type {}; +template +struct is_integral> : true_type {}; -// Get the scalar T out of a ScaledBasis -template -CUTE_HOST_DEVICE constexpr auto -basis_value(SB const& e) +// Shortcuts +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = ScaledBasis,Ns...>; + +// Apply the Ns... 
pack to another Tuple +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(T const&, Tuple&& t) { - if constexpr (is_scaled_basis::value) { - return basis_value(e.value()); + return static_cast(t); +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(ScaledBasis const&, Tuple&& t) +{ + if constexpr (sizeof...(Ns) == 0) { + return static_cast(t); } else { - return e; + return get(static_cast(t)); } CUTE_GCC_UNREACHABLE; } -// Apply the N... pack to another Tuple -template +template CUTE_HOST_DEVICE decltype(auto) -basis_get(SB const& e, Tuple&& t) -{ - if constexpr (is_scaled_basis::value) { - return basis_get(e.value(), get(static_cast(t))); +basis_value(T const& e) { + if constexpr (is_scaled_basis::value) { + return e.value(); } else { - return static_cast(t); + return e; } CUTE_GCC_UNREACHABLE; } @@ -294,65 +301,34 @@ to_atuple_i(T const& t, seq) { // Turn a ScaledBases into a rank-N+1 ArithmeticTuple // with N prefix 0s: (_0,_0,...N...,_0,T) -template +template CUTE_HOST_DEVICE constexpr auto -as_arithmetic_tuple(ScaledBasis const& t) { - return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq{}); +as_arithmetic_tuple(ScaledBasis const& t) { + return t.value(); } -namespace detail { - -template -struct Basis; - -template <> -struct Basis<> { - using type = Int<1>; -}; - -template -struct Basis { - using type = ScaledBasis::type, N>; -}; - -} // end namespace detail - -// Shortcut for writing ScaledBasis, N0>, N1>, ...> -// E<> := _1 -// E<0> := (_1,_0,_0,...) -// E<1> := (_0,_1,_0,...) -// E<0,0> := ((_1,_0,_0,...),_0,_0,...) -// E<0,1> := ((_0,_1,_0,...),_0,_0,...) -// E<1,0> := (_0,(_1,_0,_0,...),_0,...) -// E<1,1> := (_0,(_0,_1,_0,...),_0,...) 
-template -using E = typename detail::Basis::type; +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis{t.value()}), make_seq{}); +} -template +template CUTE_HOST_DEVICE constexpr auto make_basis_like(Shape const& shape) { - if constexpr (is_integral::value) { - return Int<1>{}; - } else { - // Generate bases for each rank of shape + if constexpr (is_tuple::value) { + // Generate bases for each mode of shape return transform(tuple_seq{}, shape, [](auto I, auto si) { - // Generate bases for each rank of si and add an i on front - using I_type = decltype(I); - return transform_leaf(make_basis_like(si), [](auto e) { - // MSVC has trouble capturing variables as constexpr, - // so that they can be used as template arguments. - // This is exactly what the code needs to do with i, unfortunately. - // The work-around is to define i inside the inner lambda, - // by using just the type from the enclosing scope. 
- constexpr int i = I_type::value; - return ScaledBasis{}; - }); + // Generate bases for each si and add an i on end + return make_basis_like(si); }); + } else { + return E{}; } - CUTE_GCC_UNREACHABLE; } @@ -360,109 +336,124 @@ make_basis_like(Shape const& shape) // Arithmetic // -template +template CUTE_HOST_DEVICE constexpr auto -safe_div(ScaledBasis const& b, U const& u) +safe_div(ScaledBasis const& b, U const& u) { auto t = safe_div(b.value(), u); - return ScaledBasis{t}; + return ScaledBasis{t}; } -template +template CUTE_HOST_DEVICE constexpr auto -ceil_div(ScaledBasis const& b, U const& u) +ceil_div(ScaledBasis const& b, U const& u) { auto t = ceil_div(b.value(), u); - return ScaledBasis{t}; + return ScaledBasis{t}; } -template +template CUTE_HOST_DEVICE constexpr auto -abs(ScaledBasis const& e) +abs(ScaledBasis const& e) { auto t = abs(e.value()); - return ScaledBasis{t}; + return ScaledBasis{t}; } // Equality -template +template CUTE_HOST_DEVICE constexpr auto -operator==(ScaledBasis const& t, ScaledBasis const& u) { - return bool_constant{} && t.value() == u.value(); +operator==(ScaledBasis const& t, ScaledBasis const& u) { + if constexpr (sizeof...(Ns) == sizeof...(Ms)) { + return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value(); + } else { + return false_type{}; + } + CUTE_GCC_UNREACHABLE; } // Not equal to anything else -template +template CUTE_HOST_DEVICE constexpr false_type -operator==(ScaledBasis const&, U const&) { +operator==(ScaledBasis const&, U const&) { return {}; } -template +template CUTE_HOST_DEVICE constexpr false_type -operator==(T const&, ScaledBasis const&) { +operator==(T const&, ScaledBasis const&) { return {}; } // Multiplication -template +template CUTE_HOST_DEVICE constexpr auto -operator*(A const& a, ScaledBasis const& e) { +operator*(A const& a, ScaledBasis const& e) { auto r = a * e.value(); - return ScaledBasis{r}; + return ScaledBasis{r}; } -template +template CUTE_HOST_DEVICE constexpr auto 
-operator*(ScaledBasis const& e, B const& b) { +operator*(ScaledBasis const& e, B const& b) { auto r = e.value() * b; - return ScaledBasis{r}; + return ScaledBasis{r}; } // Addition -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, ScaledBasis const& u) { +operator+(ScaledBasis const& t, ScaledBasis const& u) { return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, ArithmeticTuple const& u) { +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { return as_arithmetic_tuple(t) + u; } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ArithmeticTuple const& t, ScaledBasis const& u) { +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { return t + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(C, ScaledBasis const& u) { - static_assert(t == 0, "ScaledBasis op+ error!"); - return u; +operator+(C, ScaledBasis const& u) { + if constexpr (sizeof...(Ms) == 0) { + return C{} + u.value(); + } else { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; + } + CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, C) { - static_assert(u == 0, "ScaledBasis op+ error!"); - return t; +operator+(ScaledBasis const& t, C) { + if constexpr (sizeof...(Ns) == 0) { + return t.value() + C{}; + } else { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; + } + CUTE_GCC_UNREACHABLE; } // @@ -475,10 +466,12 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) printf("ArithTuple"); print(iter.coord_); } -template -CUTE_HOST_DEVICE void print(ScaledBasis const& e) +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) { - print(e.value()); printf("@%d", N); + print(e.value()); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... 
= (void(printf("@%d", Ns)), 0)); } #if !defined(__CUDACC_RTC__) @@ -488,10 +481,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator -CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { - return os << e.value() << "@" << N; + os << e.value(); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(os << "@" << Ns),0)); + return os; } #endif @@ -517,12 +513,18 @@ struct tuple_element> namespace std { +#if (__CUDACC_VER_MAJOR__ >= 13) + +#include + +#else #if defined(__CUDACC_RTC__) -template -struct tuple_size; + template + struct tuple_size; -template -struct tuple_element; + template + struct tuple_element; +#endif #endif template diff --git a/3rd/cutlass/include/cute/numeric/complex.hpp b/3rd/cutlass/include/cute/numeric/complex.hpp index 3115d61..f9ed592 100644 --- a/3rd/cutlass/include/cute/numeric/complex.hpp +++ b/3rd/cutlass/include/cute/numeric/complex.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/numeric/int.hpp b/3rd/cutlass/include/cute/numeric/int.hpp index 485c07d..484e799 100644 --- a/3rd/cutlass/include/cute/numeric/int.hpp +++ b/3rd/cutlass/include/cute/numeric/int.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -29,9 +29,9 @@ * **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif @@ -47,8 +47,9 @@ namespace cute // Signed integers // -using int2_t = cutlass::int2b_t; -using int4_t = cutlass::int4b_t; +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using int6_t = cutlass::int6b_t; using CUTE_STL_NAMESPACE::int8_t; using CUTE_STL_NAMESPACE::int16_t; using CUTE_STL_NAMESPACE::int32_t; @@ -75,25 +76,29 @@ using int_byte_t = typename int_byte::type; // Unsigned integers // -using uint1_t = cutlass::uint1b_t; -using uint2_t = cutlass::uint2b_t; -using uint4_t = cutlass::uint4b_t; -using uint6_t = cutlass::uint6b_t; +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using uint6_t = cutlass::uint6b_t; using CUTE_STL_NAMESPACE::uint8_t; using CUTE_STL_NAMESPACE::uint16_t; using CUTE_STL_NAMESPACE::uint32_t; using CUTE_STL_NAMESPACE::uint64_t; using cutlass::uint128_t; +using cutlass::uint256_t; + template struct uint_bit; template <> struct uint_bit< 1> { using type = uint1_t; }; template <> struct uint_bit< 2> { using type = uint2_t; }; template <> struct uint_bit< 4> { using type = uint4_t; }; -template <> struct uint_bit< 6> { using type = uint6_t; }; +template <> struct uint_bit< 6> { using type = uint6_t; }; template <> struct uint_bit< 8> { using type = uint8_t; }; template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; template <> struct uint_bit< 64> { using type = uint64_t; }; template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; +template <> struct uint_bit<256> { using type = cutlass::uint256_t; }; + template using 
uint_bit_t = typename uint_bit::type; diff --git a/3rd/cutlass/include/cute/numeric/integer_sequence.hpp b/3rd/cutlass/include/cute/numeric/integer_sequence.hpp index 799e189..e08e693 100644 --- a/3rd/cutlass/include/cute/numeric/integer_sequence.hpp +++ b/3rd/cutlass/include/cute/numeric/integer_sequence.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/numeric/integral_constant.hpp b/3rd/cutlass/include/cute/numeric/integral_constant.hpp index 1d33361..56639ac 100644 --- a/3rd/cutlass/include/cute/numeric/integral_constant.hpp +++ b/3rd/cutlass/include/cute/numeric/integral_constant.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -516,7 +516,7 @@ constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits) // var has type cute::constant. 
// template -constexpr cute::constant operator "" _c() +constexpr cute::constant operator""_c() { static_assert((('0' <= digits && digits <= '9') && ...), "Expected 0 <= digit <= 9 for each digit of the integer."); diff --git a/3rd/cutlass/include/cute/numeric/integral_ratio.hpp b/3rd/cutlass/include/cute/numeric/integral_ratio.hpp index 0104c31..505b3d7 100644 --- a/3rd/cutlass/include/cute/numeric/integral_ratio.hpp +++ b/3rd/cutlass/include/cute/numeric/integral_ratio.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -225,6 +225,27 @@ operator==(C, R) { return {}; } +template +CUTE_HOST_DEVICE constexpr +bool_constant::num * R::den < R::num * R::den> +operator<(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num < c * R::den> +operator<(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::den < R::num> +operator<(C, R) { + return {}; +} + /////////////////////// // Special functions // /////////////////////// diff --git a/3rd/cutlass/include/cute/numeric/math.hpp b/3rd/cutlass/include/cute/numeric/math.hpp index 147458b..676db2a 100644 --- a/3rd/cutlass/include/cute/numeric/math.hpp +++ b/3rd/cutlass/include/cute/numeric/math.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/numeric/numeric_types.hpp b/3rd/cutlass/include/cute/numeric/numeric_types.hpp index 892ec70..c475d3c 100644 --- a/3rd/cutlass/include/cute/numeric/numeric_types.hpp +++ b/3rd/cutlass/include/cute/numeric/numeric_types.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,10 +38,19 @@ namespace cute { -template -struct sizeof_bits : public cutlass::sizeof_bits {}; +template +struct sizeof_bits : cutlass::sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; -// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int +// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int template static constexpr auto sizeof_bits_v = sizeof_bits::value; @@ -53,6 +62,23 @@ using cutlass::is_subbyte; template static constexpr auto is_subbyte_v = is_subbyte::value; +// +// Integral +// + +using cutlass::bin1_t; +using cutlass::uint1b_t; +using cutlass::int2b_t; +using cutlass::uint2b_t; +using cutlass::int4b_t; +using cutlass::uint4b_t; +using cutlass::int6b_t; +using cutlass::uint6b_t; + +// +// Floating Point +// + using cutlass::half_t; using cutlass::bfloat16_t; @@ -65,18 +91,12 @@ using cutlass::type_erased_dynamic_float8_t; using cutlass::float_e4m3_t; using cutlass::float_e5m2_t; -using cutlass::uint1b_t; -using cutlass::int2b_t; -using cutlass::uint2b_t; -using cutlass::int4b_t; -using cutlass::uint4b_t; -using cutlass::bin1_t; + using 
cutlass::float_ue4m3_t; using cutlass::float_ue8m0_t; -using cutlass::uint6b_t; using cutlass::float_e2m1_t; using cutlass::float_e2m3_t; using cutlass::float_e3m2_t; @@ -94,8 +114,6 @@ using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t; using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t; }; - - // // Print utility // @@ -112,7 +130,6 @@ print(bfloat16_t a) { printf("%f", static_cast(a)); } - CUTE_HOST_DEVICE void print(tfloat32_t a) { @@ -131,6 +148,15 @@ print(float_e5m2_t a) { printf("%f", static_cast(a)); } +template +CUTE_HOST_DEVICE +void +print(cutlass::float_exmy_base a) { + printf("%f", static_cast(a)); +} + +// Pretty Print utility + CUTE_HOST_DEVICE void pretty_print(bfloat16_t v) { printf("%*.2f", 8, float(v)); @@ -156,26 +182,11 @@ pretty_print(float_e5m2_t t) { printf("%*.2f", 8, static_cast(t)); } - -template < - cutlass::detail::FpEncoding Encoding, - class Derived -> -CUTE_HOST_DEVICE -void -print(cutlass::float_exmy_base a) { - printf("%f", static_cast(a)); -} - -template < - cutlass::detail::FpEncoding Encoding, - class Derived -> +template CUTE_HOST_DEVICE void pretty_print_float_exmy_base(cutlass::float_exmy_base t) { printf("%*.2f", 8, static_cast(t)); } - } // namespace cute diff --git a/3rd/cutlass/include/cute/numeric/real.hpp b/3rd/cutlass/include/cute/numeric/real.hpp index 0bc9555..e572787 100644 --- a/3rd/cutlass/include/cute/numeric/real.hpp +++ b/3rd/cutlass/include/cute/numeric/real.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/pointer.hpp b/3rd/cutlass/include/cute/pointer.hpp index 3c42fd2..8a4bce3 100644 --- a/3rd/cutlass/include/cute/pointer.hpp +++ b/3rd/cutlass/include/cute/pointer.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,9 +33,9 @@ #include // CUTE_HOST_DEVICE #include // cute::iter_adaptor #include -#include // cute::subbyte_iterator #include // cute::true_type, cute::false_type #include // sizeof_bits +#include // cute::subbyte_iterator namespace cute { @@ -51,11 +51,13 @@ namespace cute // Requires construction of a sparse_ptr that emulates access to the S logical elements. 
// -template +template CUTE_HOST_DEVICE constexpr auto -recast_ptr(void* ptr) +recast_ptr(T* ptr) { + using NewT = copy_cv_t; + if constexpr (is_sparse::value) { constexpr int sparsity = NewT::sparsity; NewT* p = reinterpret_cast(ptr); @@ -69,24 +71,6 @@ recast_ptr(void* ptr) CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -recast_ptr(void const* ptr) -{ - if constexpr (is_sparse::value) { - constexpr int sparsity = NewT::sparsity; - NewT const* p = reinterpret_cast(ptr); - return make_sparse_ptr(p); - } else - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } else { - return reinterpret_cast(ptr); - } - CUTE_GCC_UNREACHABLE; -} - // Disambiguate nullptr template CUTE_HOST_DEVICE constexpr diff --git a/3rd/cutlass/include/cute/pointer_base.hpp b/3rd/cutlass/include/cute/pointer_base.hpp index 740cc1b..3451d32 100644 --- a/3rd/cutlass/include/cute/pointer_base.hpp +++ b/3rd/cutlass/include/cute/pointer_base.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/pointer_flagged.hpp b/3rd/cutlass/include/cute/pointer_flagged.hpp index 7f20534..095391f 100644 --- a/3rd/cutlass/include/cute/pointer_flagged.hpp +++ b/3rd/cutlass/include/cute/pointer_flagged.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -167,23 +167,6 @@ downcast(ComposedLayout,Layout> const& // Display utilities // -// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -template -CUTE_HOST_DEVICE -void -print_layout(ComposedLayout,Layout> const& layout) -{ - print_layout(as_position_independent_swizzle_layout(layout)); -} - -template -CUTE_HOST_DEVICE -void -print_latex(ComposedLayout,Layout> const& layout) -{ - print_latex(as_position_independent_swizzle_layout(layout)); -} - template CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) { diff --git a/3rd/cutlass/include/cute/pointer_sparse.hpp b/3rd/cutlass/include/cute/pointer_sparse.hpp index 56c4c29..362b7f0 100644 --- a/3rd/cutlass/include/cute/pointer_sparse.hpp +++ b/3rd/cutlass/include/cute/pointer_sparse.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/pointer_swizzle.hpp b/3rd/cutlass/include/cute/pointer_swizzle.hpp index 5706cd7..bef2f3d 100644 --- a/3rd/cutlass/include/cute/pointer_swizzle.hpp +++ b/3rd/cutlass/include/cute/pointer_swizzle.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/stride.hpp b/3rd/cutlass/include/cute/stride.hpp index 629cdfd..1e4a871 100644 --- a/3rd/cutlass/include/cute/stride.hpp +++ b/3rd/cutlass/include/cute/stride.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/swizzle.hpp b/3rd/cutlass/include/cute/swizzle.hpp index 2ae7d09..52a1071 100644 --- a/3rd/cutlass/include/cute/swizzle.hpp +++ b/3rd/cutlass/include/cute/swizzle.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/swizzle_layout.hpp b/3rd/cutlass/include/cute/swizzle_layout.hpp index ef1ca18..7300d3f 100644 --- a/3rd/cutlass/include/cute/swizzle_layout.hpp +++ b/3rd/cutlass/include/cute/swizzle_layout.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -165,6 +165,15 @@ get_nonswizzle_portion(Layout const& slayout) return slayout; } +// Return the codomain size of a Swizzled ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout,Offset,LayoutB> const& layout) +{ + return cosize(layout.layout_b()); +} + // // Slice a Swizzled ComposedLayout // diff --git a/3rd/cutlass/include/cute/tensor.hpp b/3rd/cutlass/include/cute/tensor.hpp index 1ab62fd..e7a7060 100644 --- a/3rd/cutlass/include/cute/tensor.hpp +++ b/3rd/cutlass/include/cute/tensor.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -56,3 +56,9 @@ #include #include +// +// Utilities +// + +#include +#include diff --git a/3rd/cutlass/include/cute/tensor_impl.hpp b/3rd/cutlass/include/cute/tensor_impl.hpp index e65ad41..6a26fa1 100644 --- a/3rd/cutlass/include/cute/tensor_impl.hpp +++ b/3rd/cutlass/include/cute/tensor_impl.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -753,24 +753,30 @@ domain_offset(Coord const& coord, Tensor&& tensor) // -- doesn't check dynamic integer divisibility // -- doesn't check alignment -template +template CUTE_HOST_DEVICE constexpr auto recast(Tensor&& tensor) { - using OldType = typename remove_cvref_t::value_type; - auto old_layout = tensor.layout(); - auto new_layout = recast_layout(old_layout); + using OldType = typename remove_cvref_t::element_type; + using NewType = copy_cv_t; - // If this is an upcast of a normal Layout with static negative strides, then offset as well - if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { - auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); - auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); - auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); - - return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + if constexpr (is_same::value) { + return make_tensor(static_cast(tensor).data(), tensor.layout()); } else { - return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + auto old_layout = tensor.layout(); + auto new_layout = recast_layout(old_layout); + + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + } else { + return 
make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + } } CUTE_GCC_UNREACHABLE; @@ -1114,95 +1120,5 @@ CUTE_HOST_DEVICE void print(Tensor const& tensor) print(tensor.data()); print(" o "); print(tensor.layout()); } -template -CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) -{ - if (print_type) { - print(tensor); print(":\n"); - } - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - pretty_print(tensor(m)); - printf("\n"); - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - pretty_print(tensor(m,n)); - } - printf("\n"); - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor(tensor(_,_,0), false); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); - print_tensor(tensor(_,_,k), false); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor(tensor(_,_,_,0), false); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); - print_tensor(tensor(_,_,_,p), false); - } - } -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) -{ - int digits = 9; - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - os << std::setw(digits) << tensor(m) << std::endl; - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - os << std::setw(digits) << tensor(m,n); - } - os << std::endl; - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor_os(os, tensor(_,_,0)); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; - print_tensor_os(os, tensor(_,_,k)); - } - } else - if constexpr 
(Layout::rank == 4) - { - print_tensor_os(os, tensor(_,_,_,0)); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; - print_tensor_os(os, tensor(_,_,_,p)); - } - } - - return os; -} - -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) -{ - os << tensor.layout() << std::endl; - return print_tensor_os(os, tensor); -} -#endif // !defined(__CUDACC_RTC__) - } // end namespace cute diff --git a/3rd/cutlass/include/cute/tensor_zip.hpp b/3rd/cutlass/include/cute/tensor_zip.hpp index 279c405..b23a7c6 100644 --- a/3rd/cutlass/include/cute/tensor_zip.hpp +++ b/3rd/cutlass/include/cute/tensor_zip.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/underscore.hpp b/3rd/cutlass/include/cute/underscore.hpp index 8a83b86..6edd04a 100644 --- a/3rd/cutlass/include/cute/underscore.hpp +++ b/3rd/cutlass/include/cute/underscore.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/util/debug.hpp b/3rd/cutlass/include/cute/util/debug.hpp index 5e704b2..1095d9c 100644 --- a/3rd/cutlass/include/cute/util/debug.hpp +++ b/3rd/cutlass/include/cute/util/debug.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cute/util/print.hpp b/3rd/cutlass/include/cute/util/print.hpp index e6cc887..12d2763 100644 --- a/3rd/cutlass/include/cute/util/print.hpp +++ b/3rd/cutlass/include/cute/util/print.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -127,6 +127,18 @@ print(uint4b_t a) { printf("%d", int(a)); } +CUTE_HOST_DEVICE +void +print(int6b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(uint6b_t a) { + printf("%d", int(a)); +} + CUTE_HOST_DEVICE void print(bin1_t a) { @@ -217,6 +229,16 @@ pretty_print(uint4b_t a) { printf("%*d", 5, int(a)); } +CUTE_HOST_DEVICE void +pretty_print(int6b_t a) { + printf("%*d", 5, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint6b_t a) { + printf("%*d", 5, int(a)); +} + CUTE_HOST_DEVICE void pretty_print(bool v) { printf("%*d", 3, int(v)); diff --git a/3rd/cutlass/include/cute/util/print_latex.hpp b/3rd/cutlass/include/cute/util/print_latex.hpp new file mode 100644 index 0000000..96fa8ca --- /dev/null +++ b/3rd/cutlass/include/cute/util/print_latex.hpp @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +#include +#include + +namespace cute +{ + +/////////////////////////////////////// +// Common LaTeX TikZ Color utilities // +/////////////////////////////////////// + +struct TikzColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "white"; + } +}; + +struct TikzColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", + "black!10", "black!50", "black!30", "black!70"}; + return color_map[idx % 8]; + } +}; + +struct TikzColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + return 
color_map[tid % 8]; + } +}; + +///////////////////////////// +// Layout 2D to LaTeX TikZ // +///////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_latex(LayoutA const& layout_a, // (m,n) -> idx + TikzColorFn color = {}) // lambda(idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); + auto layout = append<2>(layout_a, Layout<_1,_0>{}); + + // Commented print(layout) + printf("%% Layout: "); print(layout); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(layout)); + + // Layout + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + int idx = layout(m,n); + printf("\\node[fill=%s] at (%d,%d) {%d};\n", + color(idx), m, n, idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(M), int(N)); + // Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout, + TikzColorFn color = {}) // lambda(idx) -> tikz color string) +{ + print_latex(as_position_independent_swizzle_layout(layout), color); +} + +/////////////////////////////// +// LayoutTV 2D to LaTeX TikZ // +/////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_latex_tv(LayoutTV const& layout_tv, // (t,v) -> m,n coord + Tile_MN const& tile_mn, // (M,N) + TikzColorFn color = {}) // (t,v) -> color +{ + CUTE_STATIC_ASSERT_V(rank(layout_tv) == Int<2>{}); + + // Commented prints + printf("%% Layout TV: "); print(layout_tv); printf("\n"); + // 
Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(tile_mn)); + Tensor filled = make_tensor(make_shape(M, N)); + clear(filled); + + // Layout + for (int tid = 0; tid < size<0>(layout_tv); ++tid) { + for (int vid = 0; vid < size<1>(layout_tv); ++vid) { + auto [m, n] = layout_tv(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", int(M), int(N)); + // Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int n = 0, m = -1; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +//////////////////////////// +// MMA Atom to LaTeX TikZ // +//////////////////////////// + +namespace detail { + +template +CUTE_HOST_DEVICE +void +print_latex_mma(LayoutC const& C, // (tid,vid) -> (m,n) coord + LayoutA const& A, // (tid,vid) -> (m,k) coord + LayoutB const& B, // (tid,vid) -> (n,k) coord + Tile_MNK const& tile_mnk, // (M,N,K) + TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + // Commented prints + printf("%% LayoutC: "); print(C); printf("\n"); + printf("%% LayoutA: "); print(A); printf("\n"); + printf("%% LayoutB: "); print(B); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every 
node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N, K] = product_each(shape(tile_mnk)); + Tensor filled = make_tensor(make_shape(M, N, K)); + clear(filled); + + // C starting at 0,0 + for (int tid = 0; tid < size<0>(C); ++tid) { + for (int vid = 0; vid < size<1>(C); ++vid) { + auto [m, n] = C(tid, vid); + if (not filled(m, n, 0)) { + filled(m, n, 0) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(M), int(N)); + + clear(filled); + + // A starting at 0,-K-1 + for (int tid = 0; tid < size<0>(A); ++tid) { + for (int vid = 0; vid < size<1>(A); ++vid) { + auto [m, k] = A(tid, vid); + if (not filled(m, 0, k)) { + filled(m, 0, k) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(k-K-1), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, -int(K)-1, int(M), -1); + // A labels + for (int m = 0, k = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m); + } + for (int m = -1, k = 0; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k); + } + + clear(filled); + + // B starting at -K-1,0 + for (int tid = 0; tid < size<0>(B); ++tid) { + for (int vid = 0; vid < size<1>(B); ++vid) { + auto [n, k] = B(tid, vid); + if (not filled(0, n, k)) { + filled(0, n, k) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(k)-int(K)-1, int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + -int(K)-1, 0, -1, int(N)); + // B labels + for (int n = 0, k = -1; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n); + } + for 
(int n = -1, k = 0; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace detail + +// MMA Atom to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(MMA_Atom const& mma_atom, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + print_latex(make_tiled_mma(mma_atom)); +} + +// TiledMMA to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(TiledMMA const& mma, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + auto tile_mnk = tile_shape(mma); + + Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); + Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); + + Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); + Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); + + Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); + Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); + + detail::print_latex_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); +} + +//////////////////////////// +// CopyAtom to LaTeX TikZ // +//////////////////////////// + +namespace detail { + +// Generic TV Layout to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex_copy(LayoutS_TV const& S, // (t,v) -> m,n coord + LayoutD_TV const& D, // (t,v) -> m,n coord + Tile_MN const& tile_mn, // (M,N) + TikzColorFn color = {}) // (t,v) -> color +{ + CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); + + // Commented prints + printf("%% Layout S TV: "); print(S); printf("\n"); + printf("%% Layout D TV: "); print(D); printf("\n"); + + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(tile_mn)); + Tensor filled 
= make_tensor(make_shape(M, N)); + clear(filled); + + // S starting at 0,0 + for (int tid = 0; tid < size<0>(S); ++tid) { + for (int vid = 0; vid < size<1>(S); ++vid) { + auto [m, n] = S(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(M), int(N)); + // S Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + + clear(filled); + + // D starting at 0,N+3 + for (int tid = 0; tid < size<0>(D); ++tid) { + for (int vid = 0; vid < size<1>(D); ++vid) { + auto [m, n] = D(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n) + int(N) + 3, + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(N) + 3, int(M), int(N) + int(N) + 3); + // D Labels + for (int m = 0, n = N; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), n); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace detail + +// TiledCopy to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(TiledCopy const& copy, + TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string +{ + auto tiler_mn = typename TiledCopy::Tiler_MN{}; + auto tile_mn = product_each(shape(logical_divide(make_layout(Shape<_1,_1>{}), tiler_mn))); // tile_shape + + Tensor refS = make_identity_tensor(tile_mn); + Tensor layoutS_TV = 
copy.tidfrg_S(refS)(_,_,Int<0>{}); + + Tensor refD = make_identity_tensor(tile_mn); + Tensor layoutD_TV = copy.tidfrg_D(refD)(_,_,Int<0>{}); + + detail::print_latex_copy(layoutS_TV, layoutD_TV, tile_mn, color); +} + +} // end namespace cute diff --git a/3rd/cutlass/include/cute/util/print_svg.hpp b/3rd/cutlass/include/cute/util/print_svg.hpp new file mode 100644 index 0000000..1bfedfc --- /dev/null +++ b/3rd/cutlass/include/cute/util/print_svg.hpp @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +#include +#include + +namespace cute +{ + +//////////////////////////////// +// Common SVG Color utilities // +//////////////////////////////// + +struct TSVGColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "255,255,255"; + } +}; + +struct TSVGColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"255,255,255", "230,230,230", "205,205,205", "180,180,180", + "155,155,155", "130,130,130", "105,105,105", "080,080,080"}; + return color_map[idx % 8]; + } +}; + +struct SVGColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"175,175,255", "175,255,175", "255,255,175", "255,175,175", + "210,210,255", "210,255,210", "255,255,210", "255,210,210"}; + return color_map[tid % 8]; + } +}; + +///////////////////// +// MMA Atom to SVG // +///////////////////// + +namespace detail { + +template +CUTE_HOST_DEVICE +void +print_svg_mma(LayoutC const& C, + LayoutA const& A, + LayoutB const& B, + Tile_MNK const& tile_mnk, + SVGColorFn color = {}) // lambda(tid,vid) -> SVG color string +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + auto 
[M, N, K] = product_each(shape(tile_mnk)); + + int cell_size = 20; + + int page_width = (K + N + 2) * cell_size; + int page_height = (K + M + 2) * cell_size; + + // Commented print + printf("\n"); + printf("\n"); + printf("\n"); + printf("\n"); + + // SVG Header + printf("\n", + page_width, page_height); + + Tensor filled = make_tensor(make_shape(M, N, K)); + clear(filled); + + // --- Draw C --- + for (int tid = 0; tid < size<0>(C); ++tid) { + for (int vid = 0; vid < size<1>(C); ++vid) { + auto [m, n] = C(tid, vid); + if (!filled(m, n, 0)) { + filled(m, n, 0) = true; + + int x = (n + K + 2) * cell_size; + int y = (m + K + 2) * cell_size; + + printf("\n", + x, y, cell_size, cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + clear(filled); + + // --- Draw A --- + for (int tid = 0; tid < size<0>(A); ++tid) { + for (int vid = 0; vid < size<1>(A); ++vid) { + auto [m, k] = A(tid, vid); + if (!filled(m, 0, k)) { + filled(m, 0, k) = true; + + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + + printf("\n", + x, y, cell_size, cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + // A labels + for (int m = 0, k = -1; m < M; ++m) { + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, m); + } + for (int m = -1, k = 0; k < K; ++k) { + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, k); + } + + clear(filled); + + // --- Draw B --- + for (int tid = 0; tid < size<0>(B); ++tid) { + for (int vid = 0; vid < size<1>(B); ++vid) { + auto [n, k] = B(tid, vid); + if (!filled(0, n, k)) { + filled(0, n, k) = true; + + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + + printf("\n", + x, y, cell_size, 
cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + // B labels + for (int n = 0, k = -1; n < N; ++n) { + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, n); + } + for (int n = -1, k = 0; k < K; ++k) { + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, k); + } + + // SVG footer + printf("\n"); +} + +} // end namespace detail + +// MMA Atom to SVG +template +CUTE_HOST_DEVICE +void +print_svg(MMA_Atom const& mma_atom, + SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string +{ + print_svg(make_tiled_mma(mma_atom)); +} + +// TiledMMA to SVG +template +CUTE_HOST_DEVICE +void +print_svg(TiledMMA const& mma, + SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string +{ + auto tile_mnk = tile_shape(mma); + + Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); + Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); + + Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); + Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); + + Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); + Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); + + detail::print_svg_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); +} + +} // end namespace cute diff --git a/3rd/cutlass/include/cute/util/print_tensor.hpp b/3rd/cutlass/include/cute/util/print_tensor.hpp new file mode 100644 index 0000000..aabd0d8 --- /dev/null +++ b/3rd/cutlass/include/cute/util/print_tensor.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +namespace cute +{ + +//////////////////////////////// +// Layout 2D to Console table // +//////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + int idx_width = num_digits(cosize(layout)) + 2; + + print(layout); print("\n"); + + // Column indices + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } + printf("\n"); + + // Print out A m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { + printf("+"); + for (int i = 0; i < idx_width; ++i) { + printf("-"); + } + } + printf("+\n"); + // Values + printf("%2d ", m); // Row indices + for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } + printf("|\n"); + } + // Footer + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { + printf("+"); + for (int i = 0; i < idx_width; ++i) { + printf("-"); + } + } + printf("+\n"); +} + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_layout(ComposedLayout,Layout> const& layout) +{ + print_layout(as_position_independent_swizzle_layout(layout)); +} + +//////////////////////////////// +// Tensor 1D,2D,3D,4D Console // +//////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_tensor(Tensor const& tensor, bool print_type = true) +{ + if (print_type) { + print(tensor); print(":\n"); + } + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + pretty_print(tensor(m)); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + 
pretty_print(tensor(m,n)); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0), false); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k), false); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0), false); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p), false); + } + } +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST +std::ostream& +print_tensor_os(std::ostream& os, Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST +std::ostream& +operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/3rd/cutlass/include/cute/util/type_traits.hpp b/3rd/cutlass/include/cute/util/type_traits.hpp index ee361c7..5ac263c 100644 --- 
a/3rd/cutlass/include/cute/util/type_traits.hpp +++ b/3rd/cutlass/include/cute/util/type_traits.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -29,13 +29,13 @@ * **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include -#include -#include -#include -#include +#include CUDA_STD_HEADER(type_traits) +#include CUDA_STD_HEADER(utility) +#include CUDA_STD_HEADER(cstddef) +#include CUDA_STD_HEADER(cstdint) +#include CUDA_STD_HEADER(limits) #else #include #include // tuple_size, tuple_element @@ -92,6 +92,29 @@ using CUTE_STL_NAMESPACE::remove_const_t; using CUTE_STL_NAMESPACE::remove_cv_t; using CUTE_STL_NAMESPACE::remove_reference_t; +template +struct copy_cv { + using type = Dst; +}; + +template +struct copy_cv { + using type = Dst const; +}; + +template +struct copy_cv { + using type = Dst volatile; +}; + +template +struct copy_cv { + using type = Dst const volatile; +}; + +template +using copy_cv_t = typename copy_cv::type; + using CUTE_STL_NAMESPACE::extent; using CUTE_STL_NAMESPACE::remove_extent; diff --git a/3rd/cutlass/include/cutlass/aligned_buffer.h b/3rd/cutlass/include/cutlass/aligned_buffer.h index 8468f54..7b2c767 100644 --- a/3rd/cutlass/include/cutlass/aligned_buffer.h +++ b/3rd/cutlass/include/cutlass/aligned_buffer.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/arch.h b/3rd/cutlass/include/cutlass/arch/arch.h index c9c636a..faaab62 100644 --- a/3rd/cutlass/include/cutlass/arch/arch.h +++ b/3rd/cutlass/include/cutlass/arch/arch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -109,6 +109,10 @@ struct Sm120 { static int const kMinComputeCapability = 120; }; +struct Sm103 { + static int const kMinComputeCapability = 103; +}; + /// Triggers a breakpoint on the device CUTLASS_DEVICE void device_breakpoint() { diff --git a/3rd/cutlass/include/cutlass/arch/barrier.h b/3rd/cutlass/include/cutlass/arch/barrier.h index 3d5ec10..015de3e 100644 --- a/3rd/cutlass/include/cutlass/arch/barrier.h +++ b/3rd/cutlass/include/cutlass/arch/barrier.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,6 +33,7 @@ */ #pragma once +#include "cutlass/cutlass.h" #include #include @@ -46,11 +47,13 @@ #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif -#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif @@ -281,20 +284,20 @@ class NamedBarrier { CUTLASS_DEVICE static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED - asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads) : "memory"); cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } CUTLASS_DEVICE static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED - asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads) : "memory"); cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -302,9 +305,9 @@ class NamedBarrier { static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { 
#if CUDA_BARRIER_ENABLED cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads) : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -312,9 +315,9 @@ class NamedBarrier { static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads) : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -386,6 +389,7 @@ struct ClusterBarrier { // CUTLASS_HOST_DEVICE static void init(ValueType const* smem_ptr, uint32_t arrive_count) { + CUTLASS_ASSERT(arrive_count != 0 && "Arrive count must be non-zero"); #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( @@ -393,10 +397,11 @@ struct ClusterBarrier { "mbarrier.init.shared::cta.b64 [%1], %0; \n" "}" : - : "r"(arrive_count), "r"(smem_addr)); + : "r"(arrive_count), "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -418,10 +423,11 @@ struct ClusterBarrier { "DONE: \n\t" "}" : - : "r"(smem_addr), "r"(phase), "r"(ticks)); + : "r"(smem_addr), "r"(phase), "r"(ticks) + : "memory"); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -441,11 +447,12 @@ struct ClusterBarrier { "selp.b32 %0, 1, 0, P1; \n\t" "}" : "=r"(waitComplete) - : "r"(smem_addr), "r"(phase), 
"r"(pred)); + : "r"(smem_addr), "r"(phase), "r"(pred) + : "memory"); return static_cast(waitComplete); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif return 0; } @@ -464,11 +471,12 @@ struct ClusterBarrier { "selp.b32 %0, 1, 0, P1; \n\t" "}" : "=r"(waitComplete) - : "r"(smem_addr), "r"(phase)); + : "r"(smem_addr), "r"(phase) + : "memory"); return static_cast(waitComplete); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif return 0; } @@ -486,12 +494,13 @@ struct ClusterBarrier { "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" "}" : - : "r"(smem_addr), "r"(cta_id)); + : "r"(smem_addr), "r"(cta_id) + : "memory"); } cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -505,10 +514,11 @@ struct ClusterBarrier { "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" "}" : - : "r"(smem_addr)); + : "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -521,9 +531,10 @@ struct ClusterBarrier { "mbarrier.inval.shared::cta.b64 [%0]; \n\t" "}" : - : "r"(smem_addr)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + : "r"(smem_addr) + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -582,10 +593,11 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" "}" : - : "r"(transaction_bytes), "r"(smem_addr)); + : "r"(transaction_bytes), "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(__LINE__, smem_addr, transaction_bytes); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + 
CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -604,9 +616,10 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" "}" : - : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes) + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -620,10 +633,11 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" "}" : - : "r"(transaction_bytes), "r"(smem_addr)); + : "r"(transaction_bytes), "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, smem_addr, transaction_bytes); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -641,10 +655,11 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" "}" : - : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); + : "r"(transaction_bytes), "r"(smem_addr), "r"(pred) + : "memory"); cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction(__LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -701,9 +716,10 @@ void fence_barrier_init() { "{\n\t" "fence.mbarrier_init.release.cluster; \n" "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + :: + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -716,9 +732,25 @@ void fence_view_async_shared() { "{\n\t" "fence.proxy.async.shared::cta; \n" "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + :: + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif +} + +CUTLASS_DEVICE +void fence_view_shared() { +#if CUDA_BARRIER_ENABLED + 
cutlass::arch::synclog_emit_fence_view_shared(__LINE__); + asm volatile ( + "{\n\t" + "fence.release.sync_restrict::shared::cta.cluster; \n" + "}" + :: + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -732,10 +764,11 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) { "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" "}" : - : "r"(smem_addr)); + : "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -749,10 +782,11 @@ void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) { "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t" "}" : - : "r"(smem_addr)); + : "r"(smem_addr) + : "memory"); cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -766,10 +800,11 @@ void umma_arrive(uint64_t const* smem_ptr) { if (cute::elect_one_sync()) { asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" : - :"r"(bar_intptr)); + :"r"(bar_intptr) + : "memory"); } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -781,10 +816,11 @@ void umma_arrive_2x1SM(uint64_t const* smem_ptr) { if (cute::elect_one_sync()) { asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" : - :"r"(bar_intptr)); + :"r"(bar_intptr) + : "memory"); } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -799,10 +835,11 @@ void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) { "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" "}" : - :"r"(bar_intptr), "h"(cta_mask)); + :"r"(bar_intptr), "h"(cta_mask) + : "memory"); } -#elif 
defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -817,10 +854,11 @@ void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) { "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" "}" : - :"r"(bar_intptr), "h"(cta_mask)); + :"r"(bar_intptr), "h"(cta_mask) + : "memory"); } -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -837,9 +875,10 @@ void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" "}" : - :"r"(bar_intptr), "r"(uint32_t(cta_mask))); -#elif defined(__CUDA_ARCH__) - CUTLASS_NOT_IMPLEMENTED(); + :"r"(bar_intptr), "r"(uint32_t(cta_mask)) + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -856,7 +895,8 @@ void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" "}" : - :"r"(bar_intptr), "r"(uint32_t(cta_mask))); + :"r"(bar_intptr), "r"(uint32_t(cta_mask)) + : "memory"); #else CUTLASS_NOT_IMPLEMENTED(); #endif @@ -872,10 +912,11 @@ void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) { "mbarrier.arrive.shared::cluster.b64 _, [%0];\n\t" "}" : - : "r"(bar_intptr)); + : "r"(bar_intptr) + : "memory"); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -885,9 +926,10 @@ CUTE_DEVICE static void fence_view_async_tmem_load() { "{\n\t" "tcgen05.wait::ld.sync.aligned; \n" "}" - ::); -#elif defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + :: + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } @@ -897,9 +939,10 @@ CUTE_DEVICE static void fence_view_async_tmem_store() { "{\n\t" "tcgen05.wait::st.sync.aligned; \n" "}" - ::); -#elif 
defined(__CUDA_ARCH__) - asm volatile ("brkpt;\n" ::); + :: + : "memory"); +#else + CUTLASS_NOT_IMPLEMENTED(); #endif } diff --git a/3rd/cutlass/include/cutlass/arch/cache_operation.h b/3rd/cutlass/include/cutlass/arch/cache_operation.h index 5128ee0..0109c03 100644 --- a/3rd/cutlass/include/cutlass/arch/cache_operation.h +++ b/3rd/cutlass/include/cutlass/arch/cache_operation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/config.h b/3rd/cutlass/include/cutlass/arch/config.h index 60be8d7..995d82e 100644 --- a/3rd/cutlass/include/cutlass/arch/config.h +++ b/3rd/cutlass/include/cutlass/arch/config.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -128,6 +128,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +// SM110 and SM110a only on 13.0 and above +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM110_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM110_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1100) + #define CUTLASS_ARCH_MMA_SM110_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM110_ALL)) + #define CUTLASS_ARCH_MMA_SM110A_ENABLED 1 + #endif + + // SM110f + #if (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM110F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) && CUDA_ARCH_FAMILY(1100)) + #define CUTLASS_ARCH_MMA_SM110F_ENABLED CUTLASS_ARCH_MMA_SM110F_SUPPORTED + #endif + #endif +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// // SM120 and SM120a @@ -151,10 +172,56 @@ #endif #endif +// SM103 and SM103a +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM103_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM103_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1030) + #define CUTLASS_ARCH_MMA_SM103_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM103_ALL)) + #define CUTLASS_ARCH_MMA_SM103A_ENABLED 1 + #endif + + // SM103f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM103F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) && CUDA_ARCH_FAMILY(1030)) + #define CUTLASS_ARCH_MMA_SM103F_ENABLED 
CUTLASS_ARCH_MMA_SM103F_SUPPORTED + #endif + #endif +#endif + +// SM121 and SM121a +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM121_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM121_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1210) + #define CUTLASS_ARCH_MMA_SM121_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) &&\ + (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL(1210))) + #define CUTLASS_ARCH_MMA_SM121A_ENABLED 1 + #endif + + // SM121f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM121F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM121F_ENABLED) && CUDA_ARCH_FAMILY(1210)) + #define CUTLASS_ARCH_MMA_SM121F_ENABLED CUTLASS_ARCH_MMA_SM121F_SUPPORTED + #endif + #endif +#endif + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUTLASS_ARCH_CLC_ENABLED #endif diff --git a/3rd/cutlass/include/cutlass/arch/grid_dependency_control.h b/3rd/cutlass/include/cutlass/arch/grid_dependency_control.h index f1e0200..63a7714 100644 --- a/3rd/cutlass/include/cutlass/arch/grid_dependency_control.h +++ b/3rd/cutlass/include/cutlass/arch/grid_dependency_control.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - 
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -62,8 +62,14 @@ (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ (__CUDA_ARCH__ == 1010 &&\ (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ + (__CUDA_ARCH__ == 1100 &&\ + (defined(__CUDA_ARCH_FEAT_SM110_ALL) || CUDA_ARCH_FAMILY(1100))) || \ + (__CUDA_ARCH__ == 1030 &&\ + (defined(__CUDA_ARCH_FEAT_SM103_ALL) || CUDA_ARCH_FAMILY(1030))) || \ (__CUDA_ARCH__ == 1200 &&\ - (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))))) + (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))) || \ + (__CUDA_ARCH__ == 1210 &&\ + (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210))))) #define CUTLASS_GDC_ENABLED #endif #endif diff --git a/3rd/cutlass/include/cutlass/arch/memory.h b/3rd/cutlass/include/cutlass/arch/memory.h index 0fb47b1..07cb3e7 100644 --- a/3rd/cutlass/include/cutlass/arch/memory.h +++ b/3rd/cutlass/include/cutlass/arch/memory.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/memory_sm75.h b/3rd/cutlass/include/cutlass/arch/memory_sm75.h index 040f707..7b62069 100644 --- a/3rd/cutlass/include/cutlass/arch/memory_sm75.h +++ b/3rd/cutlass/include/cutlass/arch/memory_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/memory_sm80.h b/3rd/cutlass/include/cutlass/arch/memory_sm80.h index 4e81293..c3fbdd7 100644 --- a/3rd/cutlass/include/cutlass/arch/memory_sm80.h +++ b/3rd/cutlass/include/cutlass/arch/memory_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -40,6 +40,7 @@ #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/synclog.hpp" #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #define CUDA_CP_ASYNC_ACTIVATED 1 diff --git a/3rd/cutlass/include/cutlass/arch/mma.h b/3rd/cutlass/include/cutlass/arch/mma.h index 40c8200..e70e1b7 100644 --- a/3rd/cutlass/include/cutlass/arch/mma.h +++ b/3rd/cutlass/include/cutlass/arch/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm100.h b/3rd/cutlass/include/cutlass/arch/mma_sm100.h new file mode 100644 index 0000000..92af907 --- /dev/null +++ b/3rd/cutlass/include/cutlass/arch/mma_sm100.h @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif + +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/config.h" +#include "cute/arch/simd_sm100.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass{ +namespace arch { + + +/// Matrix multiply-add operation +template < + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[0] + c[i]; + } + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + float2 result; + cute::fma(result, make_float2(a[0], a[1]), make_float2(b[0], b[0]), make_float2(c[0], c[1])); + d[0] = result.x; + d[1] = result.y; + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm50.h b/3rd/cutlass/include/cutlass/arch/mma_sm50.h index 1701158..908ae3b 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm50.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm60.h b/3rd/cutlass/include/cutlass/arch/mma_sm60.h index 31ef2b6..a6c2d19 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm60.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm61.h b/3rd/cutlass/include/cutlass/arch/mma_sm61.h index b780335..8dde6e3 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm61.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm70.h b/3rd/cutlass/include/cutlass/arch/mma_sm70.h index e4889a2..1680649 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm70.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -32,8 +32,10 @@ \brief Matrix multiply */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm75.h b/3rd/cutlass/include/cutlass/arch/mma_sm75.h index 120b116..3b5e51d 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm75.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,8 +33,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "cutlass/arch/wmma.h" diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm80.h b/3rd/cutlass/include/cutlass/arch/mma_sm80.h index d89974f..48536b9 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm80.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,10 +33,11 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif + #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm89.h b/3rd/cutlass/include/cutlass/arch/mma_sm89.h index a4a8b1c..493442a 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm89.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm89.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,10 +34,11 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif + #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" diff --git a/3rd/cutlass/include/cutlass/arch/mma_sm90.h b/3rd/cutlass/include/cutlass/arch/mma_sm90.h index b1314a5..71aa4b9 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sm90.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sm90.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,8 +33,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "mma.h" #include "cutlass/layout/matrix.h" @@ -222,7 +224,7 @@ struct Mma< asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) - : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]), + : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]), "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); diff --git a/3rd/cutlass/include/cutlass/arch/mma_sparse_sm80.h b/3rd/cutlass/include/cutlass/arch/mma_sparse_sm80.h index 187ccc1..c39a569 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sparse_sm80.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sparse_sm80.h @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,8 +34,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/3rd/cutlass/include/cutlass/arch/mma_sparse_sm89.h b/3rd/cutlass/include/cutlass/arch/mma_sparse_sm89.h index 27c40dc..00a0a34 100644 --- a/3rd/cutlass/include/cutlass/arch/mma_sparse_sm89.h +++ b/3rd/cutlass/include/cutlass/arch/mma_sparse_sm89.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,8 +34,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/3rd/cutlass/include/cutlass/arch/reg_reconfig.h b/3rd/cutlass/include/cutlass/arch/reg_reconfig.h index a65ee32..7e1461b 100644 --- a/3rd/cutlass/include/cutlass/arch/reg_reconfig.h +++ b/3rd/cutlass/include/cutlass/arch/reg_reconfig.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,13 +36,20 @@ #pragma once #include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif #ifndef CUDA_CTA_RECONFIG_ACTIVATED #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \ || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \ || (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \ + || (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \ || (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \ + || (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \ ) #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif @@ -50,7 +57,9 @@ #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ + || (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \ || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ + || (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \ ) #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif diff --git a/3rd/cutlass/include/cutlass/arch/simd.h b/3rd/cutlass/include/cutlass/arch/simd.h index a1dc7df..e78095c 100644 --- a/3rd/cutlass/include/cutlass/arch/simd.h +++ b/3rd/cutlass/include/cutlass/arch/simd.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/simd_sm60.h b/3rd/cutlass/include/cutlass/arch/simd_sm60.h index 59f38d6..d454036 100644 --- a/3rd/cutlass/include/cutlass/arch/simd_sm60.h +++ b/3rd/cutlass/include/cutlass/arch/simd_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/simd_sm61.h b/3rd/cutlass/include/cutlass/arch/simd_sm61.h index 46c2266..048c567 100644 --- a/3rd/cutlass/include/cutlass/arch/simd_sm61.h +++ b/3rd/cutlass/include/cutlass/arch/simd_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/arch/synclog.hpp b/3rd/cutlass/include/cutlass/arch/synclog.hpp index b981983..ffa7bae 100644 --- a/3rd/cutlass/include/cutlass/arch/synclog.hpp +++ b/3rd/cutlass/include/cutlass/arch/synclog.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,9 +35,9 @@ #pragma once #include "cutlass/detail/helper_macros.hpp" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif @@ -120,44 +120,34 @@ constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + constexpr bool synclog_enable_cluster_barrier_wait = true; constexpr uint32_t synclog_header_cluster_barrier_wait = 6; -constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_barrier_test_wait = true; constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; -constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5; - +constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3; constexpr bool synclog_enable_cluster_barrier_try_wait = true; constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; -constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; -constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5; - +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3; constexpr bool synclog_enable_cluster_barrier_arrive = true; constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; -constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1; constexpr bool 
synclog_enable_cluster_barrier_invalidate = true; constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; -constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1; constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 6; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4; constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; -constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; -constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 6; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4; constexpr bool synclog_enable_fence_barrier_init = true; 
constexpr uint32_t synclog_header_fence_barrier_init = 16; constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; @@ -166,6 +156,10 @@ constexpr bool synclog_enable_fence_view_async_shared = true; constexpr uint32_t synclog_header_fence_view_async_shared = 17; constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; +constexpr bool synclog_enable_fence_view_shared = true; +constexpr uint32_t synclog_header_fence_view_shared = 39; +constexpr uint32_t synclog_length_fence_view_shared = synclog_length_prefix + 0; + constexpr bool synclog_enable_cp_async_wait = true; constexpr uint32_t synclog_header_cp_async_wait = 18; constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; @@ -228,12 +222,11 @@ constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; constexpr bool synclog_enable_cpasync_barrier_arrive = true; constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; -constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1; CUTLASS_DEVICE bool synclog_condition_emit() { #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - return threadIdx.x%NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; #else return 0; @@ -272,17 +265,6 @@ void synclog_print_prefix(char const* header, uint32_t at) { #endif } -CUTLASS_DEVICE -uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { - uint64_t bits = 0; - asm volatile ( - "mbarrier.inval.shared::cta.b64 [%1];\n" - "ld.shared::cta.b64 %0, [%1];\n" - : "=l"(bits) : "r"(smem_addr) - ); - return bits; -} - CUTLASS_DEVICE void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { CUTLASS_UNUSED(hi); @@ -429,14 +411,11 
@@ void synclog_emit_cluster_barrier_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -453,15 +432,12 @@ void synclog_emit_cluster_barrier_test_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_test_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; to[synclog_length_prefix + 2] = pred; - to[synclog_length_prefix + 3] = bits; - to[synclog_length_prefix + 4] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -478,14 +454,11 @@ void synclog_emit_cluster_barrier_try_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_try_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -502,15 +475,12 @@ void 
synclog_emit_cluster_barrier_arrive_cluster( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = cta_id; to[synclog_length_prefix + 2] = pred; - to[synclog_length_prefix + 3] = bits; - to[synclog_length_prefix + 4] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -526,13 +496,10 @@ void synclog_emit_cluster_barrier_arrive( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_arrive) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -546,13 +513,10 @@ void synclog_emit_cluster_barrier_invalidate( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_invalidate) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -567,14 +531,11 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( 
#if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = transaction_bytes; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -592,7 +553,6 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); @@ -600,8 +560,6 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( to[synclog_length_prefix + 1] = transaction_bytes; to[synclog_length_prefix + 2] = cta_id; to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = bits; - to[synclog_length_prefix + 5] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -619,14 +577,11 @@ void synclog_emit_cluster_transaction_barrier_expect_transaction( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); if (to == nullptr) 
return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = transaction_bytes; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -644,7 +599,6 @@ void synclog_emit_cluster_transaction_barrier_complete_transaction( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); @@ -652,8 +606,6 @@ void synclog_emit_cluster_transaction_barrier_complete_transaction( to[synclog_length_prefix + 1] = dst_cta_id; to[synclog_length_prefix + 2] = transaction_bytes; to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = bits; - to[synclog_length_prefix + 5] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -689,6 +641,19 @@ void synclog_emit_fence_view_async_shared(uint32_t line) { #endif // defined(CUTLASS_ENABLE_SYNCLOG) } +CUTLASS_DEVICE +void synclog_emit_fence_view_shared(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_view_shared) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_view_shared); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_view_shared, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + CUTLASS_DEVICE void synclog_emit_cp_async_wait( uint32_t line, @@ -977,13 +942,10 @@ void synclog_emit_cpasync_barrier_arrive( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr 
(!synclog_enable_cpasync_barrier_arrive) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -1054,7 +1016,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_wait) { synclog_print_prefix("cluster_barrier_wait", at); at += synclog_length_cluster_barrier_wait; - printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1062,7 +1024,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_test_wait) { synclog_print_prefix("cluster_barrier_test_wait", at); at += synclog_length_cluster_barrier_test_wait; - printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1070,7 +1032,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_try_wait) { synclog_print_prefix("cluster_barrier_try_wait", at); at += synclog_length_cluster_barrier_try_wait; - printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1078,7 +1040,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_arrive_cluster) { synclog_print_prefix("cluster_barrier_arrive_cluster", at); at += synclog_length_cluster_barrier_arrive_cluster; - printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u cta_id=%u pred=%u\n", 
synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1086,7 +1048,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_arrive) { synclog_print_prefix("cluster_barrier_arrive", at); at += synclog_length_cluster_barrier_arrive; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1094,7 +1056,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_invalidate) { synclog_print_prefix("cluster_barrier_invalidate", at); at += synclog_length_cluster_barrier_invalidate; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1102,7 +1064,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; - printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1110,7 +1072,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; - printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1118,7 +1080,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); at += 
synclog_length_cluster_transaction_barrier_expect_transaction; - printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1126,7 +1088,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); at += synclog_length_cluster_transaction_barrier_complete_transaction; - printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1146,6 +1108,14 @@ void synclog_print() { continue; } } + if constexpr (synclog_enable_fence_view_shared) { + if (header == synclog_header_fence_view_shared) { + synclog_print_prefix("fence_view_shared", at); + at += synclog_length_fence_view_shared; + printf("\n"); + continue; + } + } if constexpr (synclog_enable_cp_async_wait) { if (header == synclog_header_cp_async_wait) { synclog_print_prefix("cp_async_wait", at); @@ -1283,7 +1253,7 @@ void synclog_print() { if (header == synclog_header_cpasync_barrier_arrive) { synclog_print_prefix("cpasync_barrier_arrive", at); at += synclog_length_cpasync_barrier_arrive; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1302,6 +1272,7 @@ void synclog_print() { //////////////////////////////////////////////////////////////////////////////////////////////////// + #if defined(CUTLASS_ENABLE_SYNCLOG) #undef __syncthreads #define __syncthreads() do {\ @@ -1318,6 +1289,7 @@ void synclog_print() { } while (0) #endif // defined(CUTLASS_ENABLE_SYNCLOG) + //////////////////////////////////////////////////////////////////////////////////////////////////// } 
// namespace arch diff --git a/3rd/cutlass/include/cutlass/arch/wmma.h b/3rd/cutlass/include/cutlass/arch/wmma.h index 9cb9c04..69bba57 100644 --- a/3rd/cutlass/include/cutlass/arch/wmma.h +++ b/3rd/cutlass/include/cutlass/arch/wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -177,7 +177,7 @@ struct WmmaToCutlassDataType<__nv_bfloat16> { ///////////////////////////////////////////////////////////////////////////////////////////////// // WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks -// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]), +// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]), // and native wmma size (Shape) ///////////////////////////////////////////////////////////////////////////////////////////////// template < diff --git a/3rd/cutlass/include/cutlass/arch/wmma_sm70.h b/3rd/cutlass/include/cutlass/arch/wmma_sm70.h index 99d8148..05efaf6 100644 --- a/3rd/cutlass/include/cutlass/arch/wmma_sm70.h +++ b/3rd/cutlass/include/cutlass/arch/wmma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,8 +33,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -123,7 +125,7 @@ struct Wmma< nvcuda::wmma::mma_sync(D, A, B, C); } #else - static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond"); + static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond"); #endif }; diff --git a/3rd/cutlass/include/cutlass/arch/wmma_sm72.h b/3rd/cutlass/include/cutlass/arch/wmma_sm72.h index 3c488c7..995f825 100644 --- a/3rd/cutlass/include/cutlass/arch/wmma_sm72.h +++ b/3rd/cutlass/include/cutlass/arch/wmma_sm72.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,8 +33,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -117,7 +119,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); #endif }; @@ -197,7 +199,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); #endif }; diff --git a/3rd/cutlass/include/cutlass/arch/wmma_sm75.h b/3rd/cutlass/include/cutlass/arch/wmma_sm75.h index d49e8ca..f67a519 100644 --- a/3rd/cutlass/include/cutlass/arch/wmma_sm75.h +++ b/3rd/cutlass/include/cutlass/arch/wmma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,8 +33,10 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -115,7 +117,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); #endif }; @@ -194,7 +196,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); #endif }; diff --git a/3rd/cutlass/include/cutlass/array.h b/3rd/cutlass/include/cutlass/array.h index ce33110..0ee2591 100644 --- a/3rd/cutlass/include/cutlass/array.h +++ b/3rd/cutlass/include/cutlass/array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -718,6 +718,24 @@ struct maximum_absolute_value_reduction, PropogateNaN> { } }; +template +struct maximum_absolute_value_zero_mantissa_reduction, PropagateNaN> { + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, cutlass::Array const& rhs) const { + + T result = scalar; + maximum_absolute_value_zero_mantissa_reduction scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; + template struct scale> { T const scaling_factor_; diff --git a/3rd/cutlass/include/cutlass/array_planar_complex.h b/3rd/cutlass/include/cutlass/array_planar_complex.h index 0bd9d0d..6e5ff13 100644 --- a/3rd/cutlass/include/cutlass/array_planar_complex.h +++ b/3rd/cutlass/include/cutlass/array_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/array_subbyte.h b/3rd/cutlass/include/cutlass/array_subbyte.h index 6a61379..0026810 100644 --- a/3rd/cutlass/include/cutlass/array_subbyte.h +++ b/3rd/cutlass/include/cutlass/array_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -118,7 +118,7 @@ struct Array { // result[0] = xxx; // ``` // - // Will leads to compiler warning on use of unintialized member variable. Although we know + // Will leads to compiler warning on use of uninitialized member variable. Although we know // this read of uninitialized member variable is harmeless. #if defined(__clang__) diff --git a/3rd/cutlass/include/cutlass/barrier.h b/3rd/cutlass/include/cutlass/barrier.h index 8919e99..b7db168 100644 --- a/3rd/cutlass/include/cutlass/barrier.h +++ b/3rd/cutlass/include/cutlass/barrier.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/bfloat16.h b/3rd/cutlass/include/cutlass/bfloat16.h index 5e2f40b..8411935 100644 --- a/3rd/cutlass/include/cutlass/bfloat16.h +++ b/3rd/cutlass/include/cutlass/bfloat16.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -362,7 +362,7 @@ struct numeric_limits { /// Returns smallest finite value CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x3c00); } /// Returns smallest finite value CUTLASS_HOST_DEVICE @@ -431,7 +431,7 @@ struct numeric_limits { /// Returns smallest finite value CUTLASS_HOST_DEVICE - static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x3c00); } /// Returns smallest finite value CUTLASS_HOST_DEVICE @@ -667,12 +667,12 @@ bfloat16_t operator--(bfloat16_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(long double x) { +cutlass::bfloat16_t operator""_bf16(long double x) { return cutlass::bfloat16_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { +cutlass::bfloat16_t operator""_bf16(unsigned long long int x) { return cutlass::bfloat16_t(int(x)); } diff --git a/3rd/cutlass/include/cutlass/blas3.h b/3rd/cutlass/include/cutlass/blas3.h index 8788f18..8ec5e88 100644 --- a/3rd/cutlass/include/cutlass/blas3.h +++ b/3rd/cutlass/include/cutlass/blas3.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/blas3_types.h b/3rd/cutlass/include/cutlass/blas3_types.h index e47002b..95b9a79 100644 --- a/3rd/cutlass/include/cutlass/blas3_types.h +++ b/3rd/cutlass/include/cutlass/blas3_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/block_striped.h b/3rd/cutlass/include/cutlass/block_striped.h index 93665c6..d8b9a64 100644 --- a/3rd/cutlass/include/cutlass/block_striped.h +++ b/3rd/cutlass/include/cutlass/block_striped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/cluster_launch.hpp b/3rd/cutlass/include/cutlass/cluster_launch.hpp index 3a77711..5f26e16 100644 --- a/3rd/cutlass/include/cutlass/cluster_launch.hpp +++ b/3rd/cutlass/include/cutlass/cluster_launch.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,9 +39,10 @@ #include "cutlass/cutlass.h" #include "cutlass/trace.h" #include +#include "cutlass/arch/synclog.hpp" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #include diff --git a/3rd/cutlass/include/cutlass/complex.h b/3rd/cutlass/include/cutlass/complex.h index 723f1e3..adfe103 100644 --- a/3rd/cutlass/include/cutlass/complex.h +++ b/3rd/cutlass/include/cutlass/complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,14 +34,12 @@ #include #include - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif - -#include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/platform/platform.h" #include "cutlass/real.h" @@ -815,6 +813,37 @@ struct atomic_add> { } }; +// Maximal exponent reduction for zero-mantissa scaling factors: complex number uses its largest cartesian norm not abs +template +struct maximum_cartesian_norm_zero_mantissa_reduction { + using T = typename TC::value_type; + + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, cutlass::complex const &rhs) const { + maximum_absolute_value_zero_mantissa_reduction red_op; + + return red_op(red_op(lhs, rhs.real()), rhs.imag()); + } +}; + +template +struct maximum_cartesian_norm_zero_mantissa_reduction, PropagateNaN> { + using T = typename TC::value_type; + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, cutlass::Array const& rhs) const { + + T result = scalar; + maximum_cartesian_norm_zero_mantissa_reduction 
scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; ////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/constants.h b/3rd/cutlass/include/cutlass/constants.h index f5df017..d3efb24 100644 --- a/3rd/cutlass/include/cutlass/constants.h +++ b/3rd/cutlass/include/cutlass/constants.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_common.inl b/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_common.inl index b502466..4256898 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_common.inl +++ b/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_umma_builder.inl b/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_umma_builder.inl index 9a9d4cb..475f1be 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/conv/collective/builders/sm100_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl b/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl index ddab1f7..e3eb7c9 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl +++ b/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl b/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl index c298ffb..a5a1541 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl +++ b/3rd/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/collective_builder.hpp b/3rd/cutlass/include/cutlass/conv/collective/collective_builder.hpp index e032f95..fe2394c 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/collective_builder.hpp +++ b/3rd/cutlass/include/cutlass/conv/collective/collective_builder.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/collective_conv.hpp b/3rd/cutlass/include/cutlass/conv/collective/collective_conv.hpp index f0bb596..281a1d0 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/collective_conv.hpp +++ b/3rd/cutlass/include/cutlass/conv/collective/collective_conv.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/detail.hpp b/3rd/cutlass/include/cutlass/conv/collective/detail.hpp index af541a9..d643b05 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/detail.hpp +++ b/3rd/cutlass/include/cutlass/conv/collective/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index d3c541c..49dad67 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/3rd/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 11eefed..db6b6ab 100644 --- a/3rd/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/conv2d_problem_size.h b/3rd/cutlass/include/cutlass/conv/conv2d_problem_size.h index fbef858..870d90a 100644 --- a/3rd/cutlass/include/cutlass/conv/conv2d_problem_size.h +++ b/3rd/cutlass/include/cutlass/conv/conv2d_problem_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/conv3d_problem_size.h b/3rd/cutlass/include/cutlass/conv/conv3d_problem_size.h index 48bf056..ae2b956 100644 --- a/3rd/cutlass/include/cutlass/conv/conv3d_problem_size.h +++ b/3rd/cutlass/include/cutlass/conv/conv3d_problem_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/convnd_problem_shape.hpp b/3rd/cutlass/include/cutlass/conv/convnd_problem_shape.hpp index 3c31c21..93e2c89 100644 --- a/3rd/cutlass/include/cutlass/conv/convnd_problem_shape.hpp +++ b/3rd/cutlass/include/cutlass/conv/convnd_problem_shape.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/convolution.h b/3rd/cutlass/include/cutlass/conv/convolution.h index a3cc98b..37548c4 100644 --- a/3rd/cutlass/include/cutlass/conv/convolution.h +++ b/3rd/cutlass/include/cutlass/conv/convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/detail.hpp b/3rd/cutlass/include/cutlass/conv/detail.hpp index 0802921..04b32a1 100644 --- a/3rd/cutlass/include/cutlass/conv/detail.hpp +++ b/3rd/cutlass/include/cutlass/conv/detail.hpp @@ -1,6 +1,6 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp b/3rd/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp index d60469f..90016d0 100644 --- a/3rd/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/3rd/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/device/direct_convolution.h b/3rd/cutlass/include/cutlass/conv/device/direct_convolution.h index 387574b..6eaf65e 100644 --- a/3rd/cutlass/include/cutlass/conv/device/direct_convolution.h +++ b/3rd/cutlass/include/cutlass/conv/device/direct_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h b/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h index a9aae87..7a1982d 100644 --- a/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h index efd3dcb..1d42aeb 100644 --- a/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +++ b/3rd/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/dispatch_policy.hpp b/3rd/cutlass/include/cutlass/conv/dispatch_policy.hpp index d569cb1..4eb1baf 100644 --- a/3rd/cutlass/include/cutlass/conv/dispatch_policy.hpp +++ b/3rd/cutlass/include/cutlass/conv/dispatch_policy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/conv_universal.hpp b/3rd/cutlass/include/cutlass/conv/kernel/conv_universal.hpp index af804df..94ab5b9 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/conv_universal.hpp +++ b/3rd/cutlass/include/cutlass/conv/kernel/conv_universal.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d.h index f9647a5..dd883d3 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h index 27a96a5..353a6ab 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h index 77e4c5d..24c00f3 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -1056,7 +1056,7 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and /// multistage pipeline. template < typename ElementA, @@ -1184,7 +1184,7 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and // multistage pipeline with interleaved layout. 
template < typename ElementA, diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 107a1be..61f1118 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -215,7 +215,7 @@ struct DefaultConv2dFpropFusion < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and /// multistage pipeline. template < typename ElementA, diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h index ccc7515..c887804 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index b7fca98..601fb8e 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index 5c2c7ff..50d71cf 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h index 99e353d..b690bc8 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h index d55d453..1af6ba8 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h index 83b680e..1123aae 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h index 309924c..838bbd9 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h index 4b6709f..37dd5b3 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h index 024fb82..27eedf4 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -217,7 +217,7 @@ struct DefaultConv3dFpropFusion < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dFprop specialzation for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and /// multistage pipeline. 
template < typename ElementA, diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h index 2fb12c2..2a0790f 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h index 6b50d20..5288acd 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d.h b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d.h index a58046f..7d5a0cf 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h index e62187e..50cb000 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d.h b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d.h index cb7ca07..5edafcb 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h index e25c8b2..505064a 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h b/3rd/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h index ba70813..2e69261 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/direct_convolution.h b/3rd/cutlass/include/cutlass/conv/kernel/direct_convolution.h index 8c04988..99428a9 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/direct_convolution.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/direct_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h index d3fa0e9..d3c8c73 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h index 5451c17..ac8f1ca 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 071854c..ebc30e6 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h index 0113473..44c7559 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index 1e810e3..3b05609 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/3rd/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index 0874d8f..494ffe7 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -276,13 +276,26 @@ class ConvUniversal< static constexpr int MaxClusterSize = 16; implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; - implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Early return if cluster shape validation failed to avoid division by zero below + if (not cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback)) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid dynamic cluster shape\n"); + return false; + } } - if constexpr (is_grouped_wgrad) { - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); - auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); + auto cluster_shape_fallback = 
cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + // implicit gemm B tile can be small for conv, ensure multicast smem offsets are 128B aligned + int multicast_b_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape)) * sizeof_bits_v; + int multicast_b_fallback_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape_fallback)) * sizeof_bits_v; + implementable &= multicast_b_bits % (128*8) == 0 && multicast_b_fallback_bits % (128*8) == 0; + if (not implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: multicast size too large for B tile\n"); + return false; + } + + if constexpr (is_grouped_wgrad) { implementable &= size<0>(cluster_shape) == 1 && size<0>(cluster_shape_fallback) == 1; if (!implementable) { diff --git a/3rd/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index 2c02a45..b620978 100644 --- a/3rd/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/thread/depthwise_mma.h b/3rd/cutlass/include/cutlass/conv/thread/depthwise_mma.h index 41eaba2..c48fc00 100644 --- a/3rd/cutlass/include/cutlass/conv/thread/depthwise_mma.h +++ b/3rd/cutlass/include/cutlass/conv/thread/depthwise_mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index 2da2b73..7d7578a 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 8a5e60b..67eae7a 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index b33645c..f4ea891 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 638c660..8d460c1 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index e4eb011..9336d96 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h index c608ce5..6adc39b 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h index ed0e38c..cd343d6 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 1a5c33e..aeccdde 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index ed200ed..c24d4d3 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h index f208c9a..2fd3568 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h index 2dc2151..9a29aa8 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 9b12fbe..87e3aac 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_params.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_params.h index 8a3828f..7d1e114 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_params.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 13bd29b..c6ae119 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index b5a2407..2c4cd71 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index 5619727..af8b8ba 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index ea48bc6..011192d 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index 8e5048f..6d14b5f 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index d996003..0eaea99 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index a269b18..f4680ce 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 700c3d1..5ace6db 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 69915ba..c342795 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index 5a888e0..c56b6f2 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index 057023c..3431abb 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 4a40d37..397cf1d 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index b4e7db3..763c30f 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_params.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_params.h index 941f4e1..b64706b 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_params.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index 97cad0a..a2fc835 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 7e5475f..51e4b1d 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index cbe4998..639f38f 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index 6c2f2e5..5aa4960 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h index f5cd2a7..3d4af80 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h index 012e306..cde6d42 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h index b8ae9b9..9dc7dda 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h index 846f1f3..435adab 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h index 1035fda..20ec802 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h index 30d13e9..191ba89 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h index 44dafcb..4753785 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h index 9e3cc41..f444d62 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h index 482a52f..506da3b 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index 6c9c479..2912e20 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h index 45e2794..12b3af6 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h index 3be08c1..1dc5f3e 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h index dac6423..3eb0202 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h index e9844be..92f4034 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h b/3rd/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h index 0c5aed6..516ffb1 100644 --- a/3rd/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h +++ b/3rd/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h b/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h index b7af2e3..b52ff08 100644 --- a/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h +++ b/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h index 47fd1e0..664a1c4 100644 --- a/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h b/3rd/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h index 6cb3935..c0bd791 100644 --- a/3rd/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h +++ b/3rd/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/coord.h b/3rd/cutlass/include/cutlass/coord.h index c0199e1..044541d 100644 --- a/3rd/cutlass/include/cutlass/coord.h +++ b/3rd/cutlass/include/cutlass/coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,15 +33,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include "cutlass/cutlass.h" - namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -104,6 +102,7 @@ struct Coord { template CUTLASS_HOST_DEVICE Coord(Coord other) { + static_assert(kRank == R); for (int i = 0; i < kRank; ++i) { idx[i] = other[i]; } @@ -418,7 +417,7 @@ Coord operator/(Coord coord, Index s) { // //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Helper to make a 2-element coordinate +/// Helper to make a 1-element coordinate template CUTLASS_HOST_DEVICE Coord<1, T> make_Coord(T _0) { diff --git a/3rd/cutlass/include/cutlass/core_io.h b/3rd/cutlass/include/cutlass/core_io.h index 046b306..1dd61fd 100644 --- a/3rd/cutlass/include/cutlass/core_io.h +++ b/3rd/cutlass/include/cutlass/core_io.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/cuda_host_adapter.hpp b/3rd/cutlass/include/cutlass/cuda_host_adapter.hpp index 98e7789..8cf41c1 100644 --- a/3rd/cutlass/include/cutlass/cuda_host_adapter.hpp +++ b/3rd/cutlass/include/cutlass/cuda_host_adapter.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief Interface betweeen a CUTLASS device-wide operator and CUDA. + \brief Interface between a CUTLASS device-wide operator and CUDA. */ #pragma once @@ -392,7 +392,7 @@ struct CudaHostAdapter { /** * Fills a buffer in Global Memory with a byte sequence copied from host memory. - * This function can be overriden to dispatch to the appropriate cuMemsetD*Async API + * This function can be overridden to dispatch to the appropriate cuMemsetD*Async API */ virtual Status memsetDeviceImpl( void* destination, ///< Device memory pointer to be filled diff --git a/3rd/cutlass/include/cutlass/cutlass.h b/3rd/cutlass/include/cutlass/cutlass.h index ed81aec..0c7cb20 100644 --- a/3rd/cutlass/include/cutlass/cutlass.h +++ b/3rd/cutlass/include/cutlass/cutlass.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,9 +35,10 @@ #pragma once -#include "cutlass/arch/synclog.hpp" #include "cutlass/detail/helper_macros.hpp" +#define CUDA_STD_HEADER(header) + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/3rd/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp b/3rd/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp index c05498c..047637c 100644 --- a/3rd/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp +++ b/3rd/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -32,7 +32,7 @@ /*! \file - \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA + \brief Blockwise Scale configs specific for Blockwise/Groupwise MMA */ #pragma once @@ -41,6 +41,7 @@ #include "cute/int_tuple.hpp" #include "cute/atom/mma_traits_sm100.hpp" +#include "cute/arch/mma_sm90.hpp" namespace cutlass::detail{ @@ -270,8 +271,13 @@ struct RuntimeBlockwiseScaleConfig { }; // Sm90 only supports MN major for SFA and SFB for now -template -using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; +template +using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig< + SFVecSizeM, + SFVecSizeN, + SFVecSizeK, + majorSFA == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K, + majorSFB == cute::GMMA::Major::MN ? 
UMMA::Major::MN : UMMA::Major::K>; template using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; diff --git a/3rd/cutlass/include/cutlass/detail/cluster.hpp b/3rd/cutlass/include/cutlass/detail/cluster.hpp index d35765a..f9c9f3d 100644 --- a/3rd/cutlass/include/cutlass/detail/cluster.hpp +++ b/3rd/cutlass/include/cutlass/detail/cluster.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/collective.hpp b/3rd/cutlass/include/cutlass/detail/collective.hpp index 01085c5..53f25e8 100644 --- a/3rd/cutlass/include/cutlass/detail/collective.hpp +++ b/3rd/cutlass/include/cutlass/detail/collective.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp b/3rd/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp index aed30be..7c5ee0b 100644 --- a/3rd/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/3rd/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -88,7 +88,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -136,7 +136,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -184,7 +184,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -250,7 +250,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl< // Specialization for INT8 -> BF16 with [3120] value order template <> struct LayoutAwareConvertImpl< - cutlass::int8_t, + int8_t, cutlass::bfloat16_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> @@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl< cute::Layout<_4> >& dst) { - static_assert(cute::is_same_v && + static_assert(cute::is_same_v && cute::is_same_v); - using SrcArray = cutlass::Array; + using SrcArray = cutlass::Array; using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; @@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl< // Specialization for INT8 -> FP16 with [3120] value order template <> struct LayoutAwareConvertImpl< - cutlass::int8_t, + int8_t, cutlass::half_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> @@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl< cute::Layout<_4> >& dst) { - static_assert(cute::is_same_v && + static_assert(cute::is_same_v && cute::is_same_v); - using SrcArray = cutlass::Array; + using SrcArray = cutlass::Array; using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; @@ -477,28 +477,36 @@ void LayoutAwareConvert( Tensor dst_vm = coalesce(dst); Layout src_layout 
= src_vm.layout(); Layout dst_layout = dst_vm.layout(); - LayoutAwareConvertImpl::convert(src_vm, dst_vm); } + } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { + namespace detail { + enum class ConversionMode { + DirectConvert, // A * B + ConvertAndScale, // (scale * A) * B + ConvertAndScaleWithZero // (scale * A + zeros) * B + }; + } // namespace detail +} //namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass::gemm::collective::detail { template static constexpr CUTLASS_HOST_DEVICE auto get_logical_ptr(PointerType const* ptr) { - if constexpr (cute::sizeof_bits_v < 8) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } template static constexpr @@ -530,8 +538,8 @@ auto get_gmem_layout(Shape const& shape, Stride const& stride) { template struct MixedInputUtils { private: + using ConversionMode = cutlass::detail::ConversionMode; using KernelSchedule = typename Collective::KernelSchedule; - using ConversionMode = typename Collective::ConversionMode; using SmemLayoutA = typename Collective::SmemLayoutA; using SmemLayoutB = typename Collective::SmemLayoutB; using SmemLayoutScale = typename Collective::SmemLayoutScale; @@ -551,10 +559,10 @@ struct MixedInputUtils { elements_per_smem_scale() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return 0; - } + } else if constexpr (ModeHasScales) { return cute::cosize_v; - } + } else { static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); } @@ -565,10 +573,10 @@ struct MixedInputUtils { if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale ) { return 0; - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { return cute::cosize_v; 
- } + } else { static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); } @@ -611,6 +619,32 @@ struct MixedInputUtils { } } + static constexpr uint32_t + compute_tma_transaction_bytes_extra_transform() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + /// Utilities to copy A and extra inputs from smem to RF template (tiled_copy_and_views); auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); @@ -649,13 +683,60 @@ struct MixedInputUtils { } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } } } + /// (Designed for separate transform pipeline in Blackwell) + /// Utilities to copy extra inputs from smem to RF + template + CUTLASS_DEVICE + static void copy_scale_zeros_for_transform( + cute::tuple 
& partitioned_transform_extra_info, + int load2transform_consumer_index) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); + auto&& scales = cute::get<1>(partitioned_transform_extra_info); + using ScaleType = decltype(scales); + auto tSrS = make_tensor(scales.data(), scales.layout()); + auto tSsS = cute::get<2>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto&& zeros = cute::get<3>(partitioned_transform_extra_info); + using ZeroType = decltype(zeros); + auto tZrZ = make_tensor(zeros.data(), zeros.layout()); + auto tZsZ = cute::get<4>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); + + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + // The core converter uses a lookup table to converts i4 -> 8 bit value. 
template && dst, Tensor const& scales_neg, Tensor const& scales_pos) { - + lookup_table_convert(src, dst, scales_neg, scales_pos); } template ; using DstArray = cutlass::Array; @@ -699,7 +780,7 @@ struct MixedInputUtils { // Determines if to get from the signed or unsigned candidates static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 asm volatile( "{\n" " lop3.b32 %0, %1, %2, %3, %4;\n" \ @@ -743,13 +824,13 @@ struct MixedInputUtils { static_check_scale(flatten(Layout{})); } template CUTLASS_DEVICE static void dequantize_A_kblock( - Tensor const& tCrA_load, + Tensor const& tCrA_load, Tensor& tCrA_mma, cute::tuple& partitioned_extra_info, int const k_block) { @@ -764,7 +845,7 @@ struct MixedInputUtils { Tensor src = tCrA_load(_, _, k_block); Tensor dst = tCrA_mma(_, _, k_block); - + CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory"); // try to make the size of the first mode equal to 32bit @@ -778,10 +859,11 @@ struct MixedInputUtils { for (int i = 0; i < size<1>(dst_vm); ++i) { LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); } - } + } else if constexpr (UseScaleLookupTable) { constexpr int num_elements = decltype(size(src))::value; - static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); + static_assert(is_same_v || is_same_v, + "Lookup table supports int4b_t (Two's Complement) and float_e2m1_t (E2M1/FP4) quant types."); static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); @@ -804,15 +886,31 @@ struct MixedInputUtils { { auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); - constexpr uint32_t immLut = (0xf0 
& 0xcc) ^ 0xaa; - asm volatile( - "{\n" - " lop3 .b32 %0, %2, %4, %5, %6;\n" \ - " xor .b32 %1, %3, %5; \n" \ - "}\n" - : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) - : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) - ); + + // Accept CUTLASS pseudo-FP as well + if constexpr (cutlass::platform::is_floating_point::value || + cute::is_same_v) { + // E2M1 (FP4): Sign-magnitude encoding - simple sign flip with two XORs + asm volatile( + "{\n" + " xor .b32 %0, %2, %4;\n" \ + " xor .b32 %1, %3, %4;\n" \ + "}\n" + : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0x80808080) + ); + } else { + // INT4: Two's complement encoding - reorder and sign flip with lop3 + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ + "}\n" + : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); + } } } CUTLASS_PRAGMA_UNROLL @@ -856,7 +954,7 @@ struct MixedInputUtils { CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int{})); - + if constexpr (is_same_v) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { @@ -885,6 +983,114 @@ struct MixedInputUtils { } } + /// (Designed for separate transform pipeline in Blackwell) + /// Utilities to dequantize A. 
+ template + CUTLASS_DEVICE + static void dequantize_A_kblock_for_transform( + Tensor const& tArA, + Tensor& tArACompute, + cute::tuple const& partitioned_extra_info, + int const k_block) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto src = tArA(_, _, _, k_block); + auto dst = tArACompute(_, _, _, k_block); + constexpr int num_elements = decltype(size(src))::value; + + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int DstElementsPerReg = 32 / sizeof_bits_v; + using RegArray = cutlass::AlignedArray; + + auto src_arr = recast(src); + auto dst_arr = recast(dst); + + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); + + cute::transform(src_arr, dst_arr, Converter::convert); + + if constexpr (ModeHasScales) { + + auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block); + + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) { + + using ScaleArray = cutlass::Array; + auto scale_arr = recast(filter_zeros(scales)); + + if constexpr (is_same_v){ + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& scale_reg = cute::recast(scales_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hmul2(bf16x2_val, + reinterpret_cast(scale_reg[ii])); + } + } + } + else{ + 
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{}); + } + } + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Do Nothing + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + + auto const& zeros = cute::get<3>(partitioned_extra_info)(_,_,_,k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) { + using ZeroArray = cutlass::Array; + auto zero_arr = recast(filter_zeros(zeros)); + + if constexpr (is_same_v) { + Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& zero_reg = cute::recast(zeros_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hadd2(bf16x2_val, + reinterpret_cast(zero_reg[ii])); + } + } + } + else{ + cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{}); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } +} + + /// Utilities for any additional inputs inside of the TMA load template < class Params, @@ -897,39 +1103,39 @@ struct MixedInputUtils { cute::tuple const& load_inputs, TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, - int const m_coord, + int const m_coord, int const l_coord) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return cute::make_tuple(); - } + } else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gS_mkl = get<2>(load_inputs); auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); 
// (BLK_M,BLK_K,k) - Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSgS = block_tma_s.partition_S(gS); Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tSgS, tSsS); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gZ_mkl = get<3>(load_inputs); auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZgZ = block_tma_z.partition_S(gZ); Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) - return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } } @@ -938,7 +1144,7 @@ struct MixedInputUtils { class ThreadMma, class TensorStorage > - CUTLASS_DEVICE + CUTLASS_DEVICE static auto partition_extra_mma_info( ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { @@ -950,8 +1156,8 @@ struct MixedInputUtils { else if constexpr (UseScaleLookupTable) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS_neg = 
make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); - Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); @@ -960,7 +1166,7 @@ struct MixedInputUtils { else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -968,13 +1174,54 @@ struct MixedInputUtils { else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + template < + class TiledMma, + class TiledCopy, + class TensorStorage + > + CUTLASS_DEVICE + static auto partition_extra_transform_info( + TiledMma const& tiled_mma, + 
TiledCopy const& smem_tiled_copy_S, + TensorStorage& shared_storage) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(threadIdx.x % 128); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tSsS = smem_thr_copy_S.partition_S(sS); + Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tZsZ = smem_thr_copy_S.partition_S(sZ); + Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } @@ -996,18 +1243,18 @@ struct MixedInputUtils { auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tCrZ_copy_view = 
smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } diff --git a/3rd/cutlass/include/cutlass/detail/collective/moe_stride_utils.hpp b/3rd/cutlass/include/cutlass/detail/collective/moe_stride_utils.hpp new file mode 100644 index 0000000..e0a55b5 --- /dev/null +++ b/3rd/cutlass/include/cutlass/detail/collective/moe_stride_utils.hpp @@ -0,0 +1,99 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_internal_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_internal_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_internal_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_internal_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp b/3rd/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp new file mode 100644 index 0000000..19a4d4a --- /dev/null +++ b/3rd/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Kernel type definitions specific for SM103 BlockScaled MMA +*/ + +#pragma once + +namespace cutlass::sm103::detail { + +enum class KernelPrefetchType { + TmaPrefetch, // TMA Prefetch (is the default version) + Disable // Disable Prefetch +}; + +} // namespace cutlass::sm103::detail diff --git a/3rd/cutlass/include/cutlass/detail/dependent_false.hpp b/3rd/cutlass/include/cutlass/detail/dependent_false.hpp index d2dd6a1..9b6f06a 100644 --- a/3rd/cutlass/include/cutlass/detail/dependent_false.hpp +++ b/3rd/cutlass/include/cutlass/detail/dependent_false.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/helper_macros.hpp b/3rd/cutlass/include/cutlass/detail/helper_macros.hpp index cf9b803..65c235c 100644 --- a/3rd/cutlass/include/cutlass/detail/helper_macros.hpp +++ b/3rd/cutlass/include/cutlass/detail/helper_macros.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/layout.hpp b/3rd/cutlass/include/cutlass/detail/layout.hpp index e1c1bd6..3d07220 100644 --- a/3rd/cutlass/include/cutlass/detail/layout.hpp +++ b/3rd/cutlass/include/cutlass/detail/layout.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/3rd/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp index 84de1c7..d512ca7 100644 --- a/3rd/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp +++ b/3rd/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/mma.hpp b/3rd/cutlass/include/cutlass/detail/mma.hpp index b4cbd38..3be9a8e 100644 --- a/3rd/cutlass/include/cutlass/detail/mma.hpp +++ b/3rd/cutlass/include/cutlass/detail/mma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp b/3rd/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp index e4f20cb..e44dd88 100644 --- a/3rd/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp +++ b/3rd/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -88,9 +88,14 @@ struct Sm1xxBlockScaledConfig { CUTE_HOST_DEVICE static constexpr auto tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + if constexpr (rank(ProblemShape{}) == 3) { + auto [M, N, K] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(M,K), Step<_2,_1>{}); + } + else { + auto [M, N, K, L] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } } // The following function is provided for user fill dynamic problem size to the layout_SFB. 
@@ -98,9 +103,14 @@ struct Sm1xxBlockScaledConfig { CUTE_HOST_DEVICE static constexpr auto tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + if constexpr (rank(ProblemShape{}) == 3) { + auto [M, N, K] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(N,K), Step<_2,_1>{}); + } + else { + auto [M, N, K, L] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } } template diff --git a/3rd/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp b/3rd/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp new file mode 100644 index 0000000..6c88ece --- /dev/null +++ b/3rd/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm100MixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, int32_t>, Shape, int32_t>, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layout_scale(CtaShape_MN_K cta_shape_mn_k) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + + int constexpr size_MN = cute::get<0>(CtaShape_MN_K{}); + int constexpr size_K = cute::get<1>(CtaShape_MN_K{}); + + int constexpr SmemSizeMN = (SFVecSizeMN < size_MN) + ? 
SFVecSizeMN + : size_MN; + + int constexpr SmemSizeK = (SFVecSizeK < size_K) + ? SFVecSizeK + : size_K; + + int constexpr div_MN = cute::ceil_div(size_MN, SmemSizeMN); + int constexpr div_K = cute::ceil_div(size_K, SmemSizeK); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int{})); + } + else { + return make_stride(make_stride(_0{}, Int{}), make_stride(_0{}, _1{})); + } + }(); + + return make_layout( + make_shape(make_shape(Int{}, Int{}), + make_shape(Int{}, Int{})), + strides + ); + } + + + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ScaledInputDim scale_input_dims) { + const auto scale_input_dims_MNKL = append<3>(scale_input_dims, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = scale_input_dims_MNKL; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, SFVecSizeMN))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = scale_input_dims_MNKL; + auto mk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(MN, SFVecSizeMN)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +template +struct RuntimeMixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, Shape, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + // The following function is provided for user fill dynamic 
problem size to the layout_S. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<3>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + if constexpr (majorScale == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, sfmn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfmn, cute::ceil_div(MN, sfmn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/3rd/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp b/3rd/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp index f12bac1..1192210 100644 --- a/3rd/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp +++ b/3rd/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp b/3rd/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp new file mode 100644 index 0000000..800d019 --- /dev/null +++ b/3rd/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Blocked Scale configs specific for SM103 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm103BlockScaledBasicChunk { + + using Blk_MN = _128; + using Blk_SF = _4; + + using SfKMajorAtom = Layout< Shape< Shape< _8, _4, _4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; + using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_8, _4, _4>>, + Stride, Stride<_16,_128, _4>>>; + using SfAtom = cute::conditional_t; +}; + +template +struct Sm103BlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. 
+ // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm103BlkScaledChunk = Sm103BlockScaledBasicChunk; + using Blk_MN = typename Sm103BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm103BlkScaledChunk::Blk_SF; + using SfAtom = typename Sm103BlkScaledChunk::SfAtom; + + using LayoutSF = decltype(tile_to_shape(SfAtom{}, make_shape(int(0),int(0),int(0)),Step<_2,_1,_3>{})); + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSF{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template < class ProblemShape, class LayoutSFA = LayoutSF> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { + if constexpr (rank(ProblemShape{}) == 3) { + auto [M, N, K] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(M,K), Step<_2,_1>{}); + } + else { + auto [M, N, K, L] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { + if constexpr (rank(ProblemShape{}) == 3) { + auto [M, N, K] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(N,K), Step<_2,_1>{}); + } + else { + auto [M, N, K, L] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/3rd/cutlass/include/cutlass/device_kernel.h b/3rd/cutlass/include/cutlass/device_kernel.h index 40e19a3..890e3dc 100644 --- a/3rd/cutlass/include/cutlass/device_kernel.h +++ b/3rd/cutlass/include/cutlass/device_kernel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,6 +35,7 @@ #pragma once #include // CUTLASS_HOST_DEVICE +#include // cutlass::arch::synclog_* #include // uint64_t // __grid_constant__ was introduced in CUDA 11.7. diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm100_builder.inl index b710216..e5fafd5 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -50,9 +50,10 @@ #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/operations.hpp" // detail::is_sfd_epilogue_v #include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif @@ -581,7 +582,13 @@ sm100_sparse_get_tma_dispatch_policy() { * Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued, * subject to the constraint of the provided per-warp tmem subpartition shape **/ -template +template< + class GmemStrideTypeD, + class ElementAccumulator, + class ElementD, + class TmemShape_MN, + bool IsBlockScaleSupported +> constexpr auto sm100_get_tmem_load_op() { using namespace cute; @@ -957,6 +964,190 @@ struct CallbacksBuilder< >; }; +// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. +template< + class OpClass, + class CtaTileShape_MNK, + class EpilogueTileType, + class TmemWarpShape_MN, + class ElementC_, + class GmemStrideTypeC, + class ElementD, + class GmemStrideTypeD, + bool IsPerColScaleSupported +> +static constexpr auto +sm100_dense_compute_tile_shape_or_override() { + using namespace cute; + static_assert(!cute::is_same_v && !cute::is_same_v); + + constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; + + if constexpr (is_same_v && + is_same_v && + size<1>(CtaTileShape_MNK{}) == 256) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int DpFull = 32; + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + // Note: + // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. + // This is a general workable epi_tile_N which does not promise best perf. 
+ return make_tile(Int{}, Int<128>{}); + } + else if constexpr (is_same_v) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int CtaN = size<1>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int WarpN = size<1>(TmemWarpShape_MN{}); + constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); + + constexpr int DpFull = 32; // tmem datapaths in 1 subpartition + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf + // Epilogues w/o residual load are less sensitive to smem allocation + // Target a fixed amount of compute per epilogue iteration + if (DisableSource) { + if (MaxBits == 4) { + // Make epilogue tile larger to reduce the epilogue iterations. + // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + constexpr int ComputeElts = 8192; + return ComputeElts / M; + } + constexpr int ComputeElts = 4096; + return ComputeElts / M; + } + // Epilogues w/ residual load are more sensitive to smem allocation + // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + else { + if (MaxBits == 32) { + return (CtaM > 64 && CtaN <= 128) ? 16 : 32; + } + // Per-column scaling is high register pressure, reduce tile to prevent spills + else if (IsPerColScaleSupported) { + return 32; + } + else if (MaxBits == 16) { + return (CtaN <= 128) ? 32 : 64; + } + else { + return 64; + } + } + }(); + constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 
128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_tmp = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); + constexpr int N = CtaN % N_tmp == 0 ? N_tmp : CtaN; + static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); + + // stride by tmem warp layout and return a by-mode tiler + auto tile_m = Layout>{}; + auto tile_n = Layout,Int< WarpN>>, + Stride,Int>>{}; + + return make_tile(tile_m, coalesce(tile_n)); + } + else { + static_assert(cute::is_tuple::value && not is_layout::value, + "EpilogueTile must be a cute::Tile or cute::Shape"); + + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_tmem_warps() { + if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { + return Shape<_2,_2>{}; + } + else { + return Shape<_4,_1>{}; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_cta_tile_shape() { + if constexpr (Is2SmMma) { // 2x1 threadblock shape + auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{}; + auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode + return make_shape(cta_tile_m, mma_tile_n, mma_tile_k); + } + else { // 1x1 threadblock shape + return MmaTileShape_MNK{}; + } +} + +template< + class EpilogueScheduleType, + class ElementC_, + class ElementD, + int EpiTiles, + int FragmentSize +> +static constexpr auto +sm100_dense_dispatch_policy() { + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = sizeof_bits_v > 8; + // TMA store delay performs worse with residual loads + constexpr bool DelayTmaStore = 
is_void_v; + + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (is_base_of_v || + is_base_of_v) { + return Sm100PtrArrayNoSmemWarpSpecialized{}; + } + else if constexpr (is_base_of_v || is_base_of_v) { + return Sm100NoSmemWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + return Sm100PtrArrayPlanarComplexNoSmemWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + constexpr bool ReuseSmem_ = (sizeof_bits_v == sizeof_bits_v); // limited smem reuse support for planar complex for now + constexpr int StagesC_ = ReuseSmem_ ? cute::max(cute::min(EpiTiles, 4), StagesD+1) : cute::min(EpiTiles, 4); + constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs + return Sm100PtrArrayPlanarComplexTmaWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + constexpr bool ReuseSmem_ = (sizeof_bits_v == sizeof_bits_v); // limited smem reuse support for planar complex for now + constexpr int StagesC_ = ReuseSmem_ ? cute::max(cute::min(EpiTiles, 4), StagesD+1) : cute::min(EpiTiles, 4); + return Sm100PlanarComplexTmaWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs + return Sm100PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm100TmaWarpSpecialized{}; + } +} + // Helper for building TMA warp-specialized collective epilogues, specialized by // the fusion operation performed and the dispatch policy to use. 
template < @@ -1016,17 +1207,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); // Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. static constexpr auto @@ -1040,84 +1221,10 @@ private: ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, Schedule, FusionOp>(); } - else if constexpr (is_same_v && - is_same_v && - size<1>(CtaTileShape_MNK{}) == 256) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int DpFull = 32; - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - // Note: - // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. - // This is a general workable epi_tile_N which does not promise best perf. - return make_tile(Int{}, Int<128>{}); - } - else if constexpr (is_same_v) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int CtaN = size<1>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int WarpN = size<1>(TmemWarpShape_MN{}); - constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); - - constexpr int DpFull = 32; // tmem datapaths in 1 subpartition - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf - // Epilogues w/o residual load are less sensitive to smem allocation - // Target a fixed amount of compute per epilogue iteration - if (DisableSource) { - if (MaxBits == 4) { - // Make epilogue tile larger to reduce the epilogue iterations. - // 64 is the experimental value. 
It will minimize epilogue iterations but keep the number of A/B buffers the same. - constexpr int ComputeElts = 8192; - return ComputeElts / M; - } - constexpr int ComputeElts = 4096; - return ComputeElts / M; - } - // Epilogues w/ residual load are more sensitive to smem allocation - // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize - else { - if (MaxBits == 32) { - return (CtaM > 64 && CtaN <= 128) ? 16 : 32; - } - // Per-column scaling is high register pressure, reduce tile to prevent spills - else if (FusionOp::IsPerColScaleSupported) { - return 32; - } - else if (MaxBits == 16) { - return (CtaN <= 128) ? 32 : 64; - } - else { - return 64; - } - } - }(); - constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); - static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); - - // stride by tmem warp layout and return a by-mode tiler - auto tile_m = Layout>{}; - auto tile_n = Layout,Int< WarpN>>, - Stride,Int>>{}; - - return make_tile(tile_m, coalesce(tile_n)); - } else { - static_assert(cute::is_tuple::value && not is_layout::value, - "EpilogueTile must be a cute::Tile or cute::Shape"); - - EpilogueTileType epi_tile; - constexpr int M = size<0>(shape(epi_tile)); - constexpr int N = size<1>(shape(epi_tile)); - static_assert(N % 8 == 0, "Unsupported tile shape"); - - return epi_tile; + return sm100_dense_compute_tile_shape_or_override< + OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN, + ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp::IsPerColScaleSupported>(); } } using EpilogueTile_MN 
= decltype(epilogue_tile()); @@ -1128,35 +1235,33 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr auto dispatch_policy() { - // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation - constexpr bool ReuseSmem = sizeof_bits_v > 8; - // TMA store delay performs worse with residual loads - constexpr bool DelayTmaStore = is_void_v; - - constexpr int StagesD = cute::min(EpiTiles, 2); - constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) - : cute::min(EpiTiles, 4); - if constexpr (is_same_v || is_same_v) { return detail::sparse::sm100_sparse_get_tma_dispatch_policy(); } - else if constexpr (is_same_v || - is_same_v) { - constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs - return Sm100PtrArrayTmaWarpSpecialized{}; - } else { - return Sm100TmaWarpSpecialized{}; + return detail::sm100_dense_dispatch_policy(); } } static constexpr auto fusion_callbacks() { + if constexpr (is_same_v || + is_same_v || + is_same_v || + is_same_v) { + static_assert(IsDefaultFusionOp::value, "unsupported schedule + fusion"); + constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + return thread::LinearCombinationPlanarComplex< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, FusionOp::RoundStyle, ScaleType>({}); + } + else { return typename CallbacksBuilder< decltype(dispatch_policy()), @@ -1227,6 +1332,87 @@ public: >; }; +template< + class OpClass, + class MmaTileShape_MNK, + class EpilogueTileType, + class ElementAccumulator_, + class ElementC, + class ElementD, + class Schedule, + class GmemStrideTypeC, + class GmemStrideTypeD, + bool IsPerColScaleSupported, + bool IsBlockScaleSupported +> +struct Sm100EpilogueDescriptor { + using ElementAccumulator = ElementAccumulator_; + + static constexpr bool Is2SmMma = is_base_of_v || is_base_of_v; + using CtaTileShape_MNK = decltype(sm100_cta_tile_shape()); + using TileShape = CtaTileShape_MNK; + + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); + + using EpilogueTile = decltype( + sm100_dense_compute_tile_shape_or_override() + ); + + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; + + using DispatchPolicy = decltype(sm100_dense_dispatch_policy()); + + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + IsBlockScaleSupported + >()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; 
+ using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + } // namespace detail /////////////////////////////////////////////////////////////////////////////// @@ -1272,6 +1458,17 @@ private: static constexpr bool Is1SmMma = is_base_of_v; static constexpr bool Is2SmMma = is_base_of_v; + static constexpr bool IsInterleavedComplex = is_complex::value; + static constexpr bool IsFastF32Schedule = is_same_v || + is_same_v || + is_same_v || + is_same_v; + static constexpr bool IsBlockwiseSchedule = is_same_v || + is_same_v || + is_same_v || + is_same_v; + // Transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. 
+ static constexpr bool IsTransformSchedule = IsInterleavedComplex || IsFastF32Schedule || IsBlockwiseSchedule; static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); @@ -1296,17 +1493,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); static constexpr auto epilogue_tile() { @@ -1315,7 +1502,9 @@ private: static_assert(is_tuple_v, "Shape or Tile"); return EpilogueTileType{}; } - else if constexpr (is_same_v) { // perf specialized case + else if constexpr (is_same_v || not IsTransformSchedule) { + // Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels + // to avoid register spilling. 
constexpr int EpiM = size<0>(CtaTileShape_MNK{}); constexpr int EpiN = cute::min(_64{}, size<1>(CtaTileShape_MNK{})); return Shape, Int>{}; @@ -1328,30 +1517,36 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup; - static constexpr auto - dispatch_policy() { - if constexpr (is_same_v || - is_same_v) { - return Sm100PtrArrayNoSmemWarpSpecialized{}; - } - else { - return Sm100NoSmemWarpSpecialized{}; - } - } - using DispatchPolicy = decltype(dispatch_policy()); + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + + using DispatchPolicy = decltype(detail::sm100_dense_dispatch_policy()); static constexpr auto fusion_callbacks() { constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; - if constexpr (IsDefaultFusionOp::value && not is_same_v) { + if constexpr (IsDefaultFusionOp::value &&\ + not is_same_v && \ + (IsTransformSchedule || \ + is_same_v || \ + is_same_v) + ) { // Legacy codepath using thread::LinearCombination, do not expect this to be stable return thread::LinearCombination< ElementD, 1, ElementAccumulator, ElementCompute, ScaleType, FusionOp::RoundStyle, ElementC>({}); } + else if constexpr (is_same_v || + is_same_v) { + static_assert(IsDefaultFusionOp::value, "unsupported schedule + fusion"); + return thread::LinearCombinationPlanarComplex< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, FusionOp::RoundStyle, ScaleType>({}); + } else { return typename detail::CallbacksBuilder< DispatchPolicy, @@ -1519,6 +1714,106 @@ public: >::CollectiveOp; }; +template < + class MmaTileShape_MNK, + class ClusterShape_MNK, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassSimt, + MmaTileShape_MNK, + ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v || + cute::is_same_v >> { + using CtaTileShape_MNK = MmaTileShape_MNK; // cluster MMA not supported + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using ThreadOp = cute::conditional_t< + IsDefaultFusionOp::value, + thread::LinearCombination< + ElementD, AlignmentD, ElementAccumulator, ElementCompute, + ScaleType, FloatRoundStyle::round_to_nearest, ElementC> + , + thread::LinearCombinationBiasElementwise< + ElementC, ElementAccumulator, ElementCompute, ElementD, ElementD, AlignmentD, + typename FusionOp::ActivationFn, cutlass::plus, + false, typename FusionOp::ElementBias> + >; + static_assert(not (cute::is_same_v && not IsDefaultFusionOp::value), "unsupported schedule + fusion"); + + using WarpShape_MNK = decltype(cutlass::gemm::collective::detail::sm100_simt_f32_warp_shape_mnk_selector()); + static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp; + static constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{}); + static constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{}); + + // For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] accumulator value layouts. + // Then totally each warp will hold [32 x 16] accumulator value layouts. + // We separate the whole epilogue calculation to multi steps, + // each step will calculate 1x [32 x 16] for each warp to reduce register pressure (mainly for C register allocation for beta 1!= 0 case). + // So EpiTileM = WarpShape_M * 32 and EpiTileN = WarpShape_N * 16. 
+ using EpiTileM = Int; + using EpiTileN = Int; + + using SmemLayout = cute::conditional_t(GmemStrideTypeD{}), + cute::Layout, cute::Stride<_1, EpiTileM>>, + cute::Layout, cute::Stride>>; + + using CopyAtomR2S = Copy_Atom, ElementAccumulator>; + + using CopyAtomS2R = Copy_Atom>, ElementAccumulator>; + + using TiledCopyS2R = decltype( + cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + CopyAtomS2R, ThreadCount, AlignmentD, GmemStrideTypeD, EpiTileM, EpiTileN>()); + + using Schedule = cute::conditional_t, + EpilogueSimtVectorized, + EpilogueScheduleType>; + using CopyAtomR2G = Copy_Atom>, ElementD>; + using CollectiveOp = cutlass::epilogue::collective::Epilogue< + GmemStrideTypeC, + GmemStrideTypeD, + ThreadOp, + SmemLayout, + CopyAtomR2S, + TiledCopyS2R, + CopyAtomR2G, + Schedule>; +}; + /////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::collective diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm103_builder.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm103_builder.inl new file mode 100644 index 0000000..00a01aa --- /dev/null +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm103_builder.inl @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/layout.hpp" // cute::Shape +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cutlass/arch/mma.h" // cutlass::arch::OpClassTensorOp, cutlass::OpClassSparseTensorOp +#include "cute/atom/copy_traits_sm100.hpp" +#include "cute/atom/mma_traits_sm100.hpp" +#include "cute/util/type_traits.hpp" // cute::is_same_v + +#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::collective { +// Alias to sm100 builder +template < + class OpClass, + class MmaTileShape_MNK, // Static MMA tile shape + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm103, + OpClass, + MmaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp +> +{ + using CollectiveOp = typename CollectiveBuilder< + arch::Sm100, + OpClass, + MmaTileShape_MNK, + ClusterShape_MNK, 
+ EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp + >::CollectiveOp; +}; + +} // namespace cutlass::epilogue::collective diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_builder.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_builder.inl index 80e84e9..200d79f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_builder.inl +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,10 +36,10 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/collective/builders/sm90_common.inl" #include "cutlass/epilogue/collective/builders/sm120_common.inl" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_common.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_common.inl index 5b8779d..8bd1f75 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_common.inl +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm120_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl index a94d945..4a3bf90 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -45,9 +45,9 @@ #include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl index 4c259aa..2fdd775 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl +++ b/3rd/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp index bb55c96..7ba13d9 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -120,6 +120,7 @@ struct CallbacksBuilder< #include "builders/sm90_builder.inl" #include "builders/sm100_builder.inl" +#include "builders/sm103_builder.inl" #include "builders/sm120_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp index 918017e..9b745a4 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp index ed34bc1..884ba9a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp index 3cab46d..855fbf6 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/detail.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/detail.hpp index 94e43ba..05078bb 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/detail.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -40,6 +40,7 @@ #include "cute/tensor.hpp" #include "cute/numeric/numeric_types.hpp" #include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_desc.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -205,6 +206,16 @@ struct IsThreadEpilogueOpWithPerChannelScaling +struct IsThreadEpilogueOpWithResidualAdd { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithResidualAdd > { + static constexpr bool value = ThreadEpilogueOp::IsResidualSupported; +}; + template struct IsThreadEpilogueOpWithActivation { static constexpr bool value = false; @@ -305,13 +316,37 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { return false; } + template CUTLASS_DEVICE auto load_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + return cute::make_tuple( + tensormaps_init(params, 
shared_tensormaps, sm_count, sm_idx, 0) + ); + } + + template + CUTLASS_DEVICE auto + tensormaps_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx = 0) { + // In the async tensormap update kernels, we will use operator[] to index the return value to locate the correct tensormap. + // In other kernels, we will use return value as tensormap pointer directly. + struct { + CUTLASS_DEVICE operator cute::TmaDescriptor *() const { + return reinterpret_cast(0); + } + CUTLASS_DEVICE auto operator [] (int) const { + return reinterpret_cast(0); + } + } ret; + return ret; } template< @@ -367,14 +402,17 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { return load_pipe_producer_state; } + template CUTLASS_DEVICE auto store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, - [[maybe_unused]] int32_t warp_group_idx) { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx = 0) { + return cute::make_tuple( + tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0) + ); } template< @@ -485,6 +523,7 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { // Dummy methods to perform different parts of TMA/Tensormap modifications template CUTLASS_DEVICE void @@ -494,15 +533,17 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape, [[maybe_unused]] int32_t next_batch, - [[maybe_unused]] int32_t warp_group_idx) { } + [[maybe_unused]] int32_t warp_group_idx = 0 + ) { } - template + template 
CUTLASS_DEVICE void tensormaps_cp_fence_release( [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] int32_t warp_group_idx) { } + [[maybe_unused]] int32_t warp_group_idx = 0 + ) { } template CUTLASS_DEVICE @@ -527,6 +568,10 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; + // Epilog assumes a max scheduler pipe count to calculate the number of asynchronous tma update buffer they need. + // In these epilogues, we don't need to update tensormaps at all. Setting this to INT_MAX. + constexpr static uint32_t NumMaxSchedulerPipelineStageCount = INT_MAX; + template CUTLASS_HOST_DEVICE static constexpr int @@ -554,13 +599,37 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { // ctor inheritance using EpilogueOp::EpilogueOp; + template CUTLASS_DEVICE auto - load_init( + tensormaps_init( [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx = 0) const { + // In the async tensormap update kernels, we will use operator[] to index the return value to locate the correct tensormap. + // In other kernels, we will use return value as tensormap pointer directly. 
+ struct { + CUTLASS_DEVICE operator cute::TmaDescriptor *() const { + return reinterpret_cast(0); + } + CUTLASS_DEVICE auto operator [] (int) const { + return reinterpret_cast(0); + } + } ret; + return ret; + } + + template + CUTLASS_DEVICE auto + load_init( + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + return cute::make_tuple( + tensormaps_init(params, shared_tensormap, sm_count, sm_idx, 0) + ); } template< @@ -623,13 +692,16 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { { } + template CUTLASS_DEVICE auto store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + return cute::make_tuple( + tensormaps_init(params, shared_tensormap, sm_count, sm_idx, 0) + ); } template< @@ -664,7 +736,9 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { // Wait for mma warp to fill tmem buffer with accumulator results acc_pipeline.consumer_wait(acc_pipe_consumer_state); - auto [acc_state_next] = (*this).template operator()( + auto [acc_state_next, load_state_next] = (*this).template operator()( + load_pipeline, + load_pipe_consumer_state, acc_pipeline, acc_pipe_consumer_state, problem_shape_mnkl, @@ -674,10 +748,9 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { shared_tensors); // Let mma warp know tmem buffer is consumed and empty - ++load_pipe_consumer_state; ++store_pipe_producer_state; - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + return cute::make_tuple(load_state_next, store_pipe_producer_state, acc_state_next); } // FastF32 API @@ -693,18 +766,18 @@ 
class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { > CUTLASS_DEVICE auto store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TiledCopyT2R tiled_t2r) + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r) { (*this)( problem_shape_mnkl, @@ -730,19 +803,19 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { > CUTLASS_DEVICE auto store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TensorMap tensormap, - TiledCopyT2R tiled_t2r) { + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TensorMap tensormap, + TiledCopyT2R tiled_t2r) { (*this)( problem_shape_mnkl, cta_tile_mnk, @@ -755,6 +828,7 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { template< bool ReuseTmem = false, + bool 
WaitForInflightTmaRequests = true, class AccumulatorPipeline, class AccumulatorPipelineState, class ProblemShapeMNKL, @@ -783,11 +857,10 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { TensorStorage& shared_tensors, TensorMap tensormap ) - { - // Wait for mma warp to fill tmem buffer with accumulator results - acc_pipeline.consumer_wait(acc_pipe_consumer_state); - - auto [acc_state_next] = (*this).template operator()( + { + auto [acc_state_next, load_state_next] = (*this).template operator()( + load_pipeline, + load_pipe_consumer_state, acc_pipeline, acc_pipe_consumer_state, problem_shape_mnkl, @@ -797,10 +870,9 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { shared_tensors); // Let mma warp know tmem buffer is consumed and empty - ++load_pipe_consumer_state; ++store_pipe_producer_state; - return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + return cute::make_tuple(load_state_next, store_pipe_producer_state, acc_state_next); } template @@ -815,8 +887,7 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { } // Dummy methods to perform different parts of TMA/Tensormap modifications - - template + template CUTLASS_DEVICE void tensormaps_perform_update( @@ -824,14 +895,16 @@ class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { [[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShape problem_shape, - [[maybe_unused]] int32_t next_batch) { } + [[maybe_unused]] int32_t next_batch + ) { } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release( [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] cute::TmaDescriptor const* tensormap) { } + [[maybe_unused]] cute::TmaDescriptor const* tensormap + ) { } template CUTLASS_DEVICE diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp 
b/3rd/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index d32dd6a..f661cbd 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp index 80eea5e..176e7ad 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -60,7 +60,9 @@ template < class ElementD_, class StrideD_, class ThreadEpilogueOp_, - class CopyOpT2R_ + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ > class CollectiveEpilogue< Sm100PtrArrayNoSmem, @@ -70,7 +72,10 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_ + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> > { public: // @@ -92,6 +97,10 @@ class CollectiveEpilogue< using StrideD = StrideD_; using InternalStrideD = cute::remove_pointer_t; using CopyOpT2R = CopyOpT2R_; + using AlignmentC = AlignmentC_; + using AlignmentD = AlignmentD_; + + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -136,7 +145,7 @@ class CollectiveEpilogue< template static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) { + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { return 0; } @@ -160,6 +169,8 @@ class CollectiveEpilogue< template< bool ReuseTmem = false, + class LoadPipeline, + class LoadPipelineState, class AccumulatorPipeline, class AccumulatorPipelineState, class ProblemShapeMNKL, @@ -169,6 +180,8 @@ class CollectiveEpilogue< > CUTLASS_DEVICE auto operator()( + [[maybe_unused]]LoadPipeline load_pipeline, + [[maybe_unused]]LoadPipelineState load_pipe_consumer_state, AccumulatorPipeline acc_pipeline, AccumulatorPipelineState acc_pipe_consumer_state, ProblemShapeMNKL problem_shape_mnkl, @@ -185,11 +198,15 @@ class CollectiveEpilogue< static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); + auto [M, N, K, L] = 
problem_shape_mnkl; // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + bool is_accumulator_needed = K > 0; + + if (is_accumulator_needed) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + } // Batches are managed by using appropriate pointers to C and D matrices auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); @@ -232,24 +249,56 @@ class CollectiveEpilogue< } // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); - auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + Tensor 
coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; + constexpr int VD = cute::min(AlignmentD{}, size(mclD)); + Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); + Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); + + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); + Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tDpD); t++) { + tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); + } + + constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; + constexpr int VC = cute::min(AlignmentC{}, size(mclC)); + + Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); + Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); + Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tCpC); t++) { + tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); + } + Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); + Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); // Detect interleaved complex fp32 kernels - Tensor accs = accumulators; + [[maybe_unused]] Tensor accs = accumulators; using ElementTmem = typename decltype(accs)::value_type; constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; @@ -279,33 +328,40 @@ class CollectiveEpilogue< else { Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) - - copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + if (is_accumulator_needed) { + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + } else { + fill(tTR_rAcc, 0); + } + 
} + if (is_accumulator_needed) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; } // 2. Apply element-wise operation and store to gmem // source is needed if (epilogue_op.is_source_needed()) { + copy_if(tCpC, tTR_rC_src, tTR_rC_dst); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rAcc(i), tTR_gC(i)); - } + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } // source is not needed, avoid load else { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rAcc(i)); - } + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; - return cute::make_tuple(acc_pipe_consumer_state); + + return cute::make_tuple(acc_pipe_consumer_state, load_pipe_consumer_state); } // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. 
@@ -354,10 +410,33 @@ class CollectiveEpilogue< if (epilogue_op.is_source_needed()) { ptr_C_l = params.ptr_C[l_coord]; } + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) @@ -397,6 +476,447 @@ class CollectiveEpilogue< Params const& params; }; +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + // Required by the gemm::kernel + using DispatchPolicy = Sm100PtrArrayNoSmem; + using ElementC = ElementC_; + using ElementD = ElementD_; + using GmemElementC = 
cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using StrideC = StrideC_; + using StrideD = StrideD_; + using InternalStrideC = cute::remove_pointer_t; + using InternalStrideD = cute::remove_pointer_t; + using EpilogueTile = EpilogueTile_; + using CopyOpT2R = CopyOpT2R_; + using FusionCallbacks = FusionCallbacks_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + +private: + constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported + || is_same_v; // alloc reduction buffer for custom EVTs + constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + +public: + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + array_aligned buffer; + }; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + : fusion_callbacks(params_.thread, shared_tensors.thread) + , smem_buffer_ptr(shared_tensors.buffer.data()) + , params(params_) {}; + +protected: + FusionCallbacks fusion_callbacks; + uint8_t* 
smem_buffer_ptr; + Params const& params; + +public: + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + + template< + bool ReuseTmem = false, + class LoadPipeline, + class LoadPipelineState, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + [[maybe_unused]]LoadPipeline load_pipeline, + [[maybe_unused]]LoadPipelineState load_pipe_consumer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + [[maybe_unused]] SharedStorage& + ) { + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + // Wait for mma warp to fill 
tmem buffer with accumulator results + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_mnkl; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + bool is_accumulator_needed = K > 0; + + if (is_accumulator_needed) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + } + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_mnk); + auto cta_coord_mnk = cute::make_coord(m_coord, n_coord, k_coord, cute::Int<0>{}); + + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (is_C_load_needed) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (is_C_load_needed) { + ptr_C_l = params.ptr_C[l_coord]; + } + + int thread_idx = threadIdx.x % ThreadCount; + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // 
(EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); + + // Construct the EVT consumer callbacks + auto residue_cD = make_coord(M,N) - cD(_0{}); + auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); + Tensor cD_ = make_coord_tensor(cD.layout()); + Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); + constexpr bool RefSrc = false; + + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); + + Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mC, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D[l_coord])), problem_shape_mnl, stride_d); + + Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + // Register Tensor + Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); + + Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); + Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + coord_cCD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; + constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); + + auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); + constexpr auto mclC = 
decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; + constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); + + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); + + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + int(0), + EpilogueTile{}, + tiled_t2r, + cD_, + residue_cD, + tTR_cD_, + residue_tTR_cD, + tCrC, + thread_idx + }; + + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks, bool is_accumulator_needed) CUTLASS_LAMBDA_FUNC_INLINE { + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. + synchronize(); + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. 
+ [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + + // Lambda to process a single epilogue tile + auto process_tile = [&](int epi_m, int epi_n, int iter_m, int iter_n) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_last_iteration = iter_m == NumEpiSubtilesM-1 && iter_n == NumEpiSubtilesN-1; + bool do_acc_release = is_last_iteration; + + // Adjust release condition for tmem reuse + if constexpr (ReuseTmem) { + do_acc_release = iter_m == NumEpiSubtilesM-1 && iter_n == 0; // Release on first N iteration + } + + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); + cst_callbacks.begin_loop(epi_m, epi_n); + + if constexpr (not cute::is_void_v) { + if (is_C_load_needed) { + using CVecType = uint_bit_t>; + + if constexpr (!is_same_v) { + Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + Tensor tTR_rC_frg = recast(coalesce(tCrC)); + Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); + copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } + } + } + + // Copy accumulator tile from tmem to register + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); + + if 
(is_accumulator_needed) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } + else { + fill(tTR_rAcc, 0); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release && is_accumulator_needed) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + Tensor reduction_buffer = make_tensor( + raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); + + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); + + cst_callbacks.end_loop(epi_m, epi_n); + + using VecType = uint_bit_t>; + if constexpr (!is_same_v) { + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); + copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + }; + + // Use static iteration with appropriate ordering + // When ReuseTmem is true and reverse_epi_n is true, we need reverse N iteration + auto n_seq = cute::make_int_sequence{}; + auto m_seq = cute::make_int_sequence{}; + + if constexpr (UnrollEpiLoop) { + // Fully unrolled static iteration + cute::for_each(n_seq, [&](auto I_N) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_n = I_N; + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + 
cute::for_each(m_seq, [&](auto I_M) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_m = I_M; + process_tile(iter_m, epi_n, iter_m, iter_n); + }); + }); + } else { + // Runtime loop with pragma unroll(1) + #pragma unroll 1 + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + #pragma unroll 1 + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + process_tile(iter_m, epi_n, iter_m, iter_n); + } + } + } + + cst_callbacks.end(); + }; + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks, is_accumulator_needed); + return cute::make_tuple(acc_pipe_consumer_state, load_pipe_consumer_state); + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // For sm100 kernels requiring warp specialized epilogues @@ -430,7 +950,10 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_>> + CopyOpT2R_, + AlignmentC, + AlignmentD, + void>> { public: // ctor inheritance @@ -442,7 +965,10 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_>>::Sm100TmaWarpSpecializedAdapter; + CopyOpT2R_, + AlignmentC, + AlignmentD, + void>>::Sm100TmaWarpSpecializedAdapter; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_nosmem.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_nosmem.hpp new file mode 100644 index 0000000..d3cb715 --- /dev/null +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_nosmem.hpp @@ -0,0 +1,345 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by Ptr-Array Planar Complex Gemm epilogue. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_ +> +class CollectiveEpilogue< + Sm100PtrArrayPlanarComplexNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PtrArrayPlanarComplexNoSmem; + using EpilogueTile = EpilogueTile_; + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = typename ThreadEpilogueOp::ElementScalar; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + constexpr static int FragmentSize = ThreadEpilogueOp::kCount; + + struct SharedStorage { + struct TensorStorage { }; + struct 
TensorMapStorage { }; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. + constexpr static int NumAccumulatorMtxs = 2; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C_real = nullptr; + StrideC dC_real{}; + ElementC const** ptr_C_imag = nullptr; + StrideC dC_imag{}; + ElementD** ptr_D_real = nullptr; + StrideD dD_real{}; + ElementD** ptr_D_imag = nullptr; + StrideD dD_imag{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&) { + + using namespace cute; + using X = Underscore; + + 
static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + const int32_t mock_l_coord = 0; + + auto problem_shape_mnl = make_shape(M,N,mock_L); + auto cta_coord_mnl = make_shape(m_coord, n_coord, mock_l_coord); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC_real = make_tensor(make_gmem_ptr(params.ptr_C_real[l_coord]), problem_shape_mnl, append<3>(params.dC_real,_0{})); // (M,N,L) + Tensor mC_imag = make_tensor(make_gmem_ptr(params.ptr_C_imag[l_coord]), problem_shape_mnl, append<3>(params.dC_imag,_0{})); // (M,N,L) + + Tensor mD_real = make_tensor(make_gmem_ptr(params.ptr_D_real[l_coord]), problem_shape_mnl, append<3>(params.dD_real,_0{})); // (M,N,L) + Tensor mD_imag = make_tensor(make_gmem_ptr(params.ptr_D_imag[l_coord]), problem_shape_mnl, append<3>(params.dD_imag,_0{})); // (M,N,L) + + Tensor gC_real = local_tile(mC_real, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gC_imag = local_tile(mC_imag, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + Tensor gD_real = local_tile(mD_real, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD_imag = local_tile(mD_imag, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); + auto thread_t2r = 
tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC_real = thread_t2r.partition_D(gC_real); // (T2R,T2R_M,T2R_N) + Tensor tTR_gC_imag = thread_t2r.partition_D(gC_imag); // (T2R,T2R_M,T2R_N) + + Tensor tTR_gD_real = thread_t2r.partition_D(gD_real); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD_imag = thread_t2r.partition_D(gD_imag); // (T2R,T2R_M,T2R_N) + + Tensor tTR_rAcc = make_tensor(append(shape(tTR_gD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + Tensor tTR_rD = make_tensor(append(shape(tTR_gD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // 1. Load accumulators into register from tmem + auto accumulators_real = accumulators(_,_,_,0); + auto accumulators_imag = accumulators(_,_,_,1); + Tensor tAcc_real = accumulators_real(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_imag = accumulators_imag(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tTR_tAcc_real = thread_t2r.partition_S(tAcc_real); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_tAcc_imag = thread_t2r.partition_S(tAcc_imag); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + // tmem -> rmem + copy(tiled_t2r, tTR_tAcc_real, tTR_rAcc(_,_,_,0)); + copy(tiled_t2r, tTR_tAcc_imag, tTR_rAcc(_,_,_,1)); + + // 2. 
Apply element-wise operation and store to gmem + ThreadEpilogueOp epilogue_op{params.thread}; + // source is needed + if (epilogue_op.is_source_needed()) { + Tensor tTR_rC = make_tensor(append(shape(tTR_gC_real), Int{})); // (T2R,T2R_M,T2R_N,2) + Tensor tTR_rC_frg = recast>(coalesce(tTR_rC)); // (EPI_V) + + auto tTR_rC_real = tTR_rC(_,_,_,0); + auto tTR_rC_imag = tTR_rC(_,_,_,1); + + for( int i = 0; i < size(tTR_gC_real); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_rC_real(i) = tTR_gC_real(i); + tTR_rC_imag(i) = tTR_gC_imag(i); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i), tTR_rC_frg(i)); + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i)); + } + } + + auto tTR_rD_real = tTR_rD(_,_,_,0); + auto tTR_rD_imag = tTR_rD(_,_,_,1); + + for( int i = 0; i < size(tTR_gD_real); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD_real(i) = tTR_rD_real(i); + tTR_gD_imag(i) = tTR_rD_imag(i); + } + } + } + +protected: + Params const& params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For sm100 kernels requiring warp specialized epilogues +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC, + class AlignmentD +> +class CollectiveEpilogue< + Sm100PtrArrayPlanarComplexNoSmemWarpSpecialized, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC, + AlignmentD +> : public detail::Sm100TmaWarpSpecializedAdapter> +{ +public: + // ctor inheritance + using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + + + + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp new file mode 100644 index 0000000..8c0626a --- /dev/null +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp @@ -0,0 +1,1161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing elementwise operations used by Ptr-Array Planar Complex Gemm epilogues. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K, optional: Tile_L) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + 
Sm100PtrArrayPlanarComplexTmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + using DispatchPolicy = Sm100PtrArrayPlanarComplexTmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + + // Epilog assumes a max scheduler pipe count to calculate the number of asynchronous tma update buffer they need. 
+ constexpr static uint32_t NumMaxSchedulerPipelineStageCount = 8; + +private: + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type,ElementD,ElementC>>::type; // prevents void ref breakages + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + // Multi-buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accesses). + constexpr static uint32_t NumTmaDescriptorsPerSm = NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ?
StagesC : StagesD)) + 2; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + +public : + struct TensorStorageWithC { + alignas(SmemAlignmentC) cute::ArrayEngine> smem_C_real; + alignas(SmemAlignmentC) cute::ArrayEngine> smem_C_imag; + + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_real; + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_imag; + }; + + struct TensorStorageWithoutC { + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_real; + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_imag; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + 2 * ((size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8); + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + using TensorStorage = + cute::conditional_t; + TensorStorage tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_C_real; + cute::TmaDescriptor 
smem_tensormap_C_imag; + + cute::TmaDescriptor smem_tensormap_D_real; + cute::TmaDescriptor smem_tensormap_D_imag; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. + constexpr static int NumAccumulatorMtxs = 2; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C_real = nullptr; + StrideC dC_real{}; + ElementC const** ptr_C_imag = nullptr; + StrideC dC_imag{}; + ElementD** ptr_D_real = nullptr; + StrideD dD_real{}; + ElementD** ptr_D_imag = nullptr; + StrideD dD_imag{}; + }; + + // Device side epilogue params + struct Params { + using TensorShapeC = decltype(repeat_like(append<3>(StrideC{}, _1{}), int32_t(0))); + using TensorShapeD = decltype(repeat_like(append<3>(StrideD{}, _1{}), int32_t(0))); + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor( + make_gmem_ptr(static_cast,ElementD,ElementC> const*>(nullptr)), + TensorShapeC{}, + append<3>(StrideC{}, _0{})), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + TensorShapeD{}, + append<3>(StrideD{}, _0{})), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + + typename ThreadEpilogueOp::Params thread{}; + TMA_C tma_load_c_real; + TMA_C tma_load_c_imag; + TMA_D tma_store_d_real; + TMA_D tma_store_d_imag; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C_real; + ElementC const** ptr_C_imag; + ElementD** ptr_D_real; + ElementD** ptr_D_imag; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + 
ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + // Tensor shapes for Ptr-Array are initialized correctly here. + auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + typename Params::TMA_C tma_load_c_real{}; + typename Params::TMA_C tma_load_c_imag{}; + if constexpr (not cute::is_void_v) { + // Tensor pointers will be fixed before the first access + ElementC const* ptr_C_real_first_batch = nullptr; + ElementC const* ptr_C_imag_first_batch = nullptr; + Tensor tensor_c_real = make_tensor(ptr_C_real_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC_real, _0{}))); + Tensor tensor_c_imag = make_tensor(ptr_C_imag_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC_imag, _0{}))); + + tma_load_c_real = make_tma_copy(CopyOpG2S{}, tensor_c_real, take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); + tma_load_c_imag = make_tma_copy(CopyOpG2S{}, tensor_c_imag, take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); + } + + // Tensor pointers will be fixed before the first access + ElementD* ptr_D_real_first_batch = nullptr; + ElementD* ptr_D_imag_first_batch = nullptr; + Tensor tensor_d_real = make_tensor(ptr_D_real_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD_real, _0{}))); + Tensor tensor_d_imag = make_tensor(ptr_D_imag_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD_imag, _0{}))); + + typename Params::TMA_D tma_store_d_real = + make_tma_copy(CopyOpS2G{}, tensor_d_real, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + typename Params::TMA_D tma_store_d_imag = + make_tma_copy(CopyOpS2G{}, tensor_d_imag, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + + return { + args.thread, + tma_load_c_real, + tma_load_c_imag, + tma_store_d_real, + tma_store_d_imag, + static_cast(workspace), + args.ptr_C_real, + args.ptr_C_imag, + args.ptr_D_real, + args.ptr_D_imag + }; + } + + 
template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumTensors = cute::is_void_v ? 2 : 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + return (NumTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_d = cutlass::detail::get_output_alignment_bits(); + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_d / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_c = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_c / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless 
(void-C).\n"); + } + + return implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK cta_tile_mnk) { + // Compute number of epilogue subtiles + constexpr int epi_m = size<0>(cta_tile_mnk) / size<0>(EpilogueTile{}); + constexpr int epi_n = size<1>(cta_tile_mnk) / size<1>(EpilogueTile{}); + + return epi_m * epi_n; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage&) + : params(params_), epilogue_op(params_.thread) {} + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return epilogue_op.is_source_needed(); + } + + template + CUTLASS_DEVICE auto + load_init( + Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + if constexpr (IsTmaAsyncUpdate) { + // Async update kernels will fetch the tensormap directly from tensormaps_init. 
+ return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = true; + auto load_tensormaps = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + return cute::make_tuple(load_tensormaps); + } + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + cute::tuple, bool> load_tensormaps_info, + bool reverse_epi_n = false) { + using namespace cute; + + // Check to see if tensormaps have been replaced in gmem + if (get<1>(load_tensormaps_info) /* did_batch_change */) { + tensormaps_fence_acquire(get<0>(load_tensormaps_info)); + } + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + // Tile residue + auto m_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_real_mn = params.tma_load_c_real.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mC_imag_mn = params.tma_load_c_imag.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + + Tensor 
mC_real = coalesce(mC_real_mn, take<0,2>(cta_tile_mnk)); + Tensor mC_imag = coalesce(mC_imag_mn, take<0,2>(cta_tile_mnk)); + + Tensor gC_real = local_tile(mC_real, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor gC_imag = local_tile(mC_imag, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC_real = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_real.begin(); + } + else { + return shared_tensors.smem_D_real.begin(); + } + }(); + auto ptr_sC_imag = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_imag.begin(); + } + else { + return shared_tensors.smem_D_imag.begin(); + } + }(); + + Tensor gC_real_epi = flat_divide(gC_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gC_imag_epi = flat_divide(gC_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor sC_real_epi = make_tensor(make_smem_ptr(ptr_sC_real), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sC_imag_epi = make_tensor(make_smem_ptr(ptr_sC_imag), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s_real = params.tma_load_c_real.get_slice(Int<0>{}); + ThrCopy thrblk_g2s_imag = params.tma_load_c_imag.get_slice(Int<0>{}); + + Tensor bGS_gC_real = thrblk_g2s_real.partition_S(gC_real_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_gC_imag = thrblk_g2s_imag.partition_S(gC_imag_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + Tensor bGS_sC_real = thrblk_g2s_real.partition_D(sC_real_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + Tensor bGS_sC_imag = thrblk_g2s_imag.partition_D(sC_imag_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Acquire the lock for the first stage + uint64_t* tma_barrier = 
load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_real_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_real_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_real_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Execute the TMA load for C + if (issue_tma_load) { + copy(params.tma_load_c_real.with(get<0>(get<0>(load_tensormaps_info)), *tma_barrier, mcast_mask), + bGS_gC_real(_,_,_,epi_m,epi_n), bGS_sC_real(_,_,_,load_pipe_producer_state.index())); + copy(params.tma_load_c_imag.with(get<1>(get<0>(load_tensormaps_info)), *tma_barrier, mcast_mask), + bGS_gC_imag(_,_,_,epi_m,epi_n), bGS_sC_imag(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + template + CUTLASS_DEVICE auto + store_init( + Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + if constexpr (IsTmaAsyncUpdate) { + return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = 
false; + cute::tuple store_tensormaps = {nullptr, nullptr}; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + // Only the first epilogue warp needs to perform TMA related operations + if (warp_idx == 0) { + store_tensormaps = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + } + return cute::make_tuple(store_tensormaps); + } + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors, + cute::tuple, bool> store_tensormap_info + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + //static_assert(rank(accumulators) == 4, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx 
/ NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // Check to see if tensormaps have been replaced in gmem + // Only the first epilogue warp needs to perform TMA related operations + if (get<1>(store_tensormap_info) /* did_batch_change */ && warp_idx == 0) { + tensormaps_fence_acquire(get<0>(store_tensormap_info)); + } + + auto accumulators_real = accumulators(_,_,_,0); + auto accumulators_imag = accumulators(_,_,_,1); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_real_mn = params.tma_store_d_real.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mD_imag_mn = params.tma_store_d_imag.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + + Tensor mD_real = coalesce(mD_real_mn, take<0,2>(cta_tile_mnk)); + Tensor mD_imag = coalesce(mD_imag_mn, take<0,2>(cta_tile_mnk)); + + Tensor gD_real = local_tile(mD_real, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor gD_imag = local_tile(mD_imag, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc_real = accumulators_real(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_imag = accumulators_imag(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_real_epi = flat_divide(tAcc_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tAcc_imag_epi = flat_divide(tAcc_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor gD_real_epi = flat_divide(gD_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_imag_epi = flat_divide(gD_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC_real = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_real.begin(); + } + else { + return 
shared_tensors.smem_D_real.begin(); + } + }(); + auto ptr_sC_imag = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_imag.begin(); + } + else { + return shared_tensors.smem_D_imag.begin(); + } + }(); + + auto ptr_sD_real = shared_tensors.smem_D_real.begin(); + auto ptr_sD_imag = shared_tensors.smem_D_imag.begin(); + + Tensor sC_real_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC_real), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sC_imag_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC_imag), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + Tensor sD_real_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD_real), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + Tensor sD_imag_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD_imag), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_real_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc_real = thread_t2r.partition_S(tAcc_real_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD_real = thread_t2r.partition_D(sD_real_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAcc_imag = thread_t2r.partition_S(tAcc_imag_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD_imag = thread_t2r.partition_D(sD_imag_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rAcc = make_tensor(append(shape(tTR_sD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + Tensor tTR_rD = make_tensor(append(shape(tTR_sD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = 
recast>(coalesce(tTR_rD)); // (EPI_V) + + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC_real = thread_s2r.partition_S(sC_real_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Tensor tSR_sC_imag = thread_s2r.partition_S(sC_imag_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD(_,_,_,_0{})).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v + && decltype(max_common_vector(tSR_rC_layout, tSR_sC_real.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(append(shape(tTR_sD_real), _2{})); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + Tensor tTR_rC_frg = recast>(tTR_rC); // (EPI_V) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD_real = thread_r2s.partition_D(sD_real_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_sD_imag = thread_r2s.partition_D(sD_imag_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d_real.get_slice(Int<0>{}); + Tensor bSG_sD_real = thrblk_s2g.partition_S(sD_real_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD_real = thrblk_s2g.partition_D(gD_real_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Tensor bSG_sD_imag = 
thrblk_s2g.partition_S(sD_imag_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD_imag = thrblk_s2g.partition_D(gD_imag_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // Coordinate tensors and residue for tile quantization + auto m_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + return cute::max(0, c_m); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + return cute::max(0, c_n); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + Tensor cD = make_identity_tensor(take<0,2>(cta_tile_mnk)); + Tensor tTR_cD = thread_t2r.partition_D(flat_divide(cD, EpilogueTile{})); + + bool is_source_needed = epilogue_op.is_source_needed(); + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_real_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. 
+ // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d_real.with(get<0>(get<0>(store_tensormap_info))), bSG_sD_real(_,_,_,store_pipe_producer_state.index()), bSG_gD_real(_,_,_,epi_m,epi_n)); + copy(params.tma_store_d_imag.with(get<1>(get<0>(store_tensormap_info))), bSG_sD_imag(_,_,_,store_pipe_producer_state.index()), bSG_gD_imag(_,_,_,epi_m,epi_n)); + } + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_source_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + 
} + }; + + // + // BEGIN EPILOGUE + // + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_source_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_real_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_real_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_real_epi)-1 && iter_n == size<3>(gD_real_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_real_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_real_epi)-1 && iter_n == 0; + } + + if (is_source_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + // Copy source tile from smem to register // residual smem -> reg + copy(tiled_s2r, tSR_sC_real(_,_,_,load_wait_state.index()), tSR_rC(_,_,_,0)); + copy(tiled_s2r, tSR_sC_imag(_,_,_,load_wait_state.index()), tSR_rC(_,_,_,1)); + } + + if (is_source_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_real_mn = 
tTR_tAcc_real(_,_,_,epi_m,epi_n); + Tensor tTR_tAcc_imag_mn = tTR_tAcc_imag(_,_,_,epi_m,epi_n); + + // Compute tmem load predication if necessary + if constexpr (predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_real_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { // acc tmem -> reg + copy(tiled_t2r, tTR_tAcc_real_mn, tTR_rAcc(_,_,_,0)); + copy(tiled_t2r, tTR_tAcc_imag_mn, tTR_rAcc(_,_,_,1)); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rD_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i), tTR_rC_frg(i)); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rD_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i)); + } + } + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Copy output tile from register to smem + bool issue_smem_store = issue_tmem_load; + if (issue_smem_store) { // after scale, reg -> smem + copy(tiled_r2s, tRS_rD(_,_,_,0), tRS_sD_real(_,_,_,store_pipe_producer_state.index())); + copy(tiled_r2s, tRS_rD(_,_,_,1), tRS_sD_imag(_,_,_,store_pipe_producer_state.index())); + } + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + if (is_source_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = 
load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (epilogue_op.is_source_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init(Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx, + bool const is_leader_warp = true) const { + // Define a local struct that provides simple array indexing for TMA descriptors + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_real; + cute::TmaDescriptor* tma_desc_imag; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* desc_real, cute::TmaDescriptor* desc_imag) + : tma_desc_real(desc_real), tma_desc_imag(desc_imag) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) const { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_real + idx, tma_desc_imag + idx); + } + }; + + cute::TmaDescriptor* 
tma_desc_real = nullptr; + cute::TmaDescriptor* tma_desc_imag = nullptr; + cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + + if (!is_leader_warp) { + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc_real, tma_desc_imag}; + } else { + return cute::make_tuple(tma_desc_real, tma_desc_imag); + } + } + if constexpr (IsLoad) { + if constexpr (not cute::is_void_v) { + tma_desc_real = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + tma_desc_imag = &gmem_tensormap[(sm_idx + sm_count) * NumTmaDescriptorsPerSm]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pC_real_tensormap = make_tensor(params.tma_load_c_real.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_real_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C_real), Int<1>{}, Int<1>{}); + Tensor pC_imag_tensormap = make_tensor(params.tma_load_c_imag.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_imag_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C_imag), Int<1>{}, Int<1>{}); + + copy(recast(pC_real_tensormap), recast(sC_real_tensormap)); + copy(recast(pC_imag_tensormap), recast(sC_imag_tensormap)); + } + __syncwarp(); + } + } else { + int const offset_Ddesc = cute::is_void_v ? 
0 : (2 * sm_count); + tma_desc_real = &gmem_tensormap[(sm_idx + offset_Ddesc) * NumTmaDescriptorsPerSm]; + tma_desc_imag = &gmem_tensormap[(sm_idx + offset_Ddesc + sm_count) * NumTmaDescriptorsPerSm]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pD_real_tensormap = make_tensor(params.tma_store_d_real.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_real_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D_real), Int<1>{}, Int<1>{}); + Tensor pD_imag_tensormap = make_tensor(params.tma_store_d_imag.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_imag_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D_imag), Int<1>{}, Int<1>{}); + + copy(recast(pD_real_tensormap), recast(sD_real_tensormap)); + copy(recast(pD_imag_tensormap), recast(sD_imag_tensormap)); + } + __syncwarp(); + } + + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc_real, tma_desc_imag}; + } else { + return cute::make_tuple(tma_desc_real, tma_desc_imag); + } + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& params, + int32_t next_batch) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if (not cute::is_void_v) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C_real, + params.ptr_C_real[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C_imag, + params.ptr_C_imag[next_batch]); + } + } else { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D_real, + params.ptr_D_real[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D_imag, + params.ptr_D_imag[next_batch]); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + 
TensorMapStorage& shared_tensormap, + Params const& params, + cute::tuple const& tensormaps, + [[maybe_unused]] ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, params, next_batch); + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormap, + tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormap, + cute::tuple const& tensormaps + ) { + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem if we're not using async update. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. 
+ auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; + + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if constexpr (not cute::is_void_v) { + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } + tma_descriptor_cp_fence_release(get<0>(tensormaps), shared_tensormap.smem_tensormap_C_real); + tma_descriptor_cp_fence_release(get<1>(tensormaps), shared_tensormap.smem_tensormap_C_imag); + } + } else { + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } + tma_descriptor_cp_fence_release(get<0>(tensormaps), shared_tensormap.smem_tensormap_D_real); + tma_descriptor_cp_fence_release(get<1>(tensormaps), shared_tensormap.smem_tensormap_D_imag); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& tensormaps) { + if constexpr (IsLoad) { + if constexpr (not cute::is_void_v) { + cute::tma_descriptor_fence_acquire(get<0>(tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(tensormaps)); + } + } else { + cute::tma_descriptor_fence_acquire(get<0>(tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(tensormaps)); + } + } + +private: + Params const& params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 1f0a915..9ad99b1 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ 
b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -45,6 +45,7 @@ #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" #include "cutlass/detail/layout.hpp" +#include "cutlass/detail/collective/moe_stride_utils.hpp" #include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -128,6 +129,9 @@ class CollectiveEpilogue< static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + // Epilog assumes a max scheduler pipe count to calculate the number of asynchronous tma update buffer they need. + constexpr static uint32_t NumMaxSchedulerPipelineStageCount = 8; + private: constexpr static bool is_source_supported = not cute::is_void_v; @@ -177,6 +181,11 @@ class CollectiveEpilogue< // TMA store delay only benefits with loop unrolling constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? 
StagesC : StagesD)) + 2; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -346,7 +355,8 @@ class CollectiveEpilogue< constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count) + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm) + + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); } template @@ -456,16 +466,22 @@ class CollectiveEpilogue< return fusion_callbacks.is_producer_load_needed(); } + template CUTLASS_DEVICE auto load_init( Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = true; - auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); - return cute::make_tuple(load_tensormap); + if constexpr (IsTmaAsyncUpdate) { + // Async update kernels will fetch the tensormap directly from tensormaps_init. 
+ return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = true; + auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + return cute::make_tuple(load_tensormap); + } } template< @@ -581,22 +597,27 @@ class CollectiveEpilogue< load_pipeline.producer_tail(load_pipe_producer_state); } + template CUTLASS_DEVICE auto store_init( Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = false; - cute::TmaDescriptor* store_tensormap = nullptr; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - // Only the first epilogue warp needs to perform TMA related operations - if (warp_idx == 0) { - store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + if constexpr (IsTmaAsyncUpdate) { + return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = false; + cute::TmaDescriptor* store_tensormap = nullptr; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + // Only the first epilogue warp needs to perform TMA related operations + if (warp_idx == 0) { + store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + } + return cute::make_tuple(store_tensormap); } - return cute::make_tuple(store_tensormap); } template< @@ -646,6 +667,7 @@ class CollectiveEpilogue< int thread_idx = threadIdx.x % ThreadCount; int warp_idx = thread_idx / NumThreadsPerWarp; [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + bool is_accumulator_needed = K > 0; // Check to see if tensormaps have been replaced in gmem // Only the first epilogue warp needs to perform TMA related operations @@ -804,7 +826,7 @@ class CollectiveEpilogue< static_assert(not (DelayTmaStore and 
ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); // The Epilogue Loop - auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + auto epi_loop_fn = [&] (auto& cst_callbacks, bool is_accumulator_needed) CUTLASS_LAMBDA_FUNC_INLINE { bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); @@ -909,7 +931,7 @@ class CollectiveEpilogue< ++load_wait_state; } - if (is_first_iteration) { + if (is_first_iteration && is_accumulator_needed) { // Wait for mma warp to fill tmem buffer with accumulator results acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); } @@ -926,11 +948,15 @@ class CollectiveEpilogue< // Copy accumulator tile from tmem to register if (issue_tmem_load) { - copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + if (is_accumulator_needed) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } else { + fill(tTR_rAcc, 0); + } } // After the last tmem load, signal that tmem buffer is consumed and empty - if (do_acc_release) { + if (do_acc_release && is_accumulator_needed) { cutlass::arch::fence_view_async_tmem_load(); acc_pipeline.consumer_release(acc_pipe_consumer_state); ++acc_pipe_consumer_state; @@ -1001,7 +1027,7 @@ class CollectiveEpilogue< // BEGIN EPILOGUE // auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - epi_loop_fn(cst_callbacks); + epi_loop_fn(cst_callbacks, is_accumulator_needed); return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } @@ -1343,17 +1369,44 @@ class CollectiveEpilogue< // Methods to perform different parts of TMA/Tensormap modifications // - template + template CUTLASS_DEVICE auto tensormaps_init(Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, - int32_t const sm_idx) const { + int32_t const sm_idx, + bool const is_leader_warp = true) const { + + // Define a 
local struct that provides simple array indexing for TMA descriptors + struct TensorMapArray { + cute::TmaDescriptor* tma_desc; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* desc) : tma_desc(desc) {} + + CUTLASS_DEVICE + cute::TmaDescriptor* + operator[](int32_t idx) const { + return tma_desc + (idx % NumTmaDescriptorsPerSm); + } + }; + cute::TmaDescriptor* tma_desc = nullptr; cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + + if (!is_leader_warp) { + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc}; + } else { + return tma_desc; + } + } + if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc = &gmem_tensormap[sm_idx]; + if constexpr (is_source_supported) { + tma_desc = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); @@ -1364,7 +1417,7 @@ class CollectiveEpilogue< } } else if constexpr (is_destination_supported) { int const offset_Ddesc = cute::is_void_v ? 
0 : sm_count; - tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + tma_desc = &gmem_tensormap[(sm_idx + offset_Ddesc) * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); @@ -1374,7 +1427,11 @@ class CollectiveEpilogue< __syncwarp(); } - return tma_desc; + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc}; + } else { + return tma_desc; + } } // Replace address for the global tensor (to be done by single thread) @@ -1417,26 +1474,40 @@ class CollectiveEpilogue< if constexpr (IsLoad) { if constexpr (is_source_supported) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), InternalStrideC{})); if (params.dC != nullptr) { - ElementC const* ptr_C = nullptr; - Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; - } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, - prob_shape, - prob_stride); + tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + } + else { + auto internal_shape_c = make_shape(static_cast(M), static_cast(N), 1); + InternalStrideC stride_c = make_internal_packed_stride(InternalStrideC{}, internal_shape_c); + tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), stride_c)); + } + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; } + 
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } } else if constexpr (is_destination_supported) { ElementD const* ptr_D = nullptr; - Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); - + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), InternalStrideD{})); + if (params.dD != nullptr) { + tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + } + else { + auto internal_shape_d = make_shape(static_cast(M), static_cast(N), 1); + InternalStrideD stride_d = make_internal_packed_stride(InternalStrideD{}, internal_shape_d); + tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), stride_d)); + } cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, prob_shape, prob_stride); // Convert strides to byte strides @@ -1445,21 +1516,23 @@ class CollectiveEpilogue< } cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D, - prob_shape, - prob_stride); + prob_shape, + prob_stride); } } // The entire warp must call this function collectively (that is, the instructions are aligned) - template + template CUTLASS_DEVICE void tensormaps_perform_update( TensorMapStorage& shared_tensormap, Params const& params, - cute::TmaDescriptor const* tensormap, + cute::TmaDescriptor* tensormap, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { + __syncwarp(); if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormap, params, next_batch); @@ -1474,16 +1547,21 @@ class CollectiveEpilogue< // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormap, tensormap); + tensormaps_cp_fence_release( + shared_tensormap, + tensormap + ); } - 
template + template CUTLASS_DEVICE void tensormaps_cp_fence_release( TensorMapStorage& shared_tensormap, - cute::TmaDescriptor const* tensormap) { - // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + cute::TmaDescriptor* tensormap + ) { + + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem if we're not using async update. // This operation only happens when the group/batch changes between consecutive tiles. // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { @@ -1492,14 +1570,19 @@ class CollectiveEpilogue< cute::tma_desc_wait_group(); } }; + // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc_wait_all_fn(); + if constexpr (is_source_supported) { + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); } } else if constexpr (is_destination_supported) { - tma_desc_wait_all_fn(); + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); } } @@ -1509,7 +1592,7 @@ class CollectiveEpilogue< void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if (is_source_supported) { + if constexpr (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } else if constexpr (is_destination_supported) { diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index f5e8fb5..c8944f1 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,6 +39,8 @@ #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/detail/helper_macros.hpp" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/conv/detail.hpp" #include "cute/tensor.hpp" #include "cute/numeric/numeric_types.hpp" @@ -133,6 +135,7 @@ class CollectiveEpilogue< constexpr static int ThreadCount = 128; constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + constexpr static bool isSourceNeeded = not cute::is_void_v; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; constexpr static uint32_t TmaTransactionBytes = 0; @@ -173,12 +176,27 @@ class CollectiveEpilogue< return cutlass::Status::kSuccess; } + template + static bool + can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + template static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - return true; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); + if constexpr (isSourceNeeded) { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + return implementable; } // @@ -196,6 +214,8 @@ 
class CollectiveEpilogue< public: template< bool ReuseTmem = false, + class LoadPipeline, + class LoadPipelineState, class AccumulatorPipeline, class AccumulatorPipelineState, class ProblemShapeMNKL, @@ -205,6 +225,8 @@ class CollectiveEpilogue< > CUTLASS_DEVICE auto operator()( + [[maybe_unused]]LoadPipeline load_pipeline, + [[maybe_unused]]LoadPipelineState load_pipe_consumer_state, AccumulatorPipeline acc_pipeline, AccumulatorPipelineState acc_pipe_consumer_state, ProblemShapeMNKL problem_shape_mnkl, @@ -325,8 +347,7 @@ class CollectiveEpilogue< copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } // source is not needed, avoid load - else - { + else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); @@ -335,7 +356,7 @@ class CollectiveEpilogue< copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } - return cute::make_tuple(acc_pipe_consumer_state); + return cute::make_tuple(acc_pipe_consumer_state, load_pipe_consumer_state); } @@ -554,6 +575,8 @@ class CollectiveEpilogue< template< bool ReuseTmem = false, + class LoadPipeline, + class LoadPipelineState, class AccumulatorPipeline, class AccumulatorPipelineState, class ProblemShapeMNKL, @@ -563,14 +586,16 @@ class CollectiveEpilogue< > CUTLASS_DEVICE auto operator()( + [[maybe_unused]]LoadPipeline load_pipeline, + [[maybe_unused]]LoadPipelineState load_pipe_consumer_state, AccumulatorPipeline acc_pipeline, AccumulatorPipelineState acc_pipe_consumer_state, ProblemShapeMNKL problem_shape_mnkl, CtaTileMNK cta_tile_mnk, CtaCoordMNKL cta_coord_mnkl, cute::Tensor accumulators, - [[maybe_unused]]SharedStorage& - ) { + [[maybe_unused]] SharedStorage& + ) { using ElementAccumulator = typename AccEngine::value_type; using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; @@ -700,10 +725,20 @@ class CollectiveEpilogue< if (is_C_load_needed) { using CVecType = uint_bit_t>; + if 
constexpr (!is_same_v) { Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); Tensor tTR_rC_frg = recast(coalesce(tCrC)); Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } } } @@ -735,10 +770,21 @@ class CollectiveEpilogue< cst_callbacks.end_loop(epi_m, epi_n); using VecType = uint_bit_t>; + if constexpr (!is_same_v) { Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + } // for epi_m } // for epi_n @@ -750,7 +796,7 @@ class CollectiveEpilogue< // auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); epi_loop_fn(cst_callbacks); - return cute::make_tuple(acc_pipe_consumer_state); + return cute::make_tuple(acc_pipe_consumer_state, load_pipe_consumer_state); } }; diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_planar_complex_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_planar_complex_tma_warpspecialized.hpp new file mode 100644 index 0000000..e9b8d18 --- /dev/null +++ 
b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_planar_complex_tma_warpspecialized.hpp @@ -0,0 +1,897 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Functor performing elementwise operations used by Planar Complex Gemm epilogues. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K, optional: Tile_L) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm100PlanarComplexTmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PlanarComplexTmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpT2R = CopyOpT2R_; + using CopyOpG2S = CopyOpG2S_; + 
using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + +private: + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type,ElementD,ElementC>>::type; // prevents void ref breakages + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t 
SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + +public : + struct TensorStorageWithC { + alignas(SmemAlignmentC) cute::ArrayEngine> smem_C_real; + alignas(SmemAlignmentC) cute::ArrayEngine> smem_C_imag; + + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_real; + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_imag; + }; + + struct TensorStorageWithoutC { + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_real; + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D_imag; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + 2 * ((size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8); + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + using TensorStorage = + cute::conditional_t; + TensorStorage tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. 
+ constexpr static int NumAccumulatorMtxs = 2; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C_real = nullptr; + StrideC dC_real{}; + ElementC const* ptr_C_imag = nullptr; + StrideC dC_imag{}; + ElementD* ptr_D_real = nullptr; + StrideD dD_real{}; + ElementD* ptr_D_imag = nullptr; + StrideD dD_imag{}; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor( + make_gmem_ptr(static_cast,ElementD,ElementC> const*>(nullptr)), + repeat_like(append<3>(StrideC{}, _1{}), int32_t(0)), + append<3>(StrideC{}, _0{})), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(append<3>(StrideD{}, _1{}), int32_t(0)), + append<3>(StrideD{}, _0{})), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + + typename ThreadEpilogueOp::Params thread{}; + TMA_C tma_load_c_real; + TMA_C tma_load_c_imag; + TMA_D tma_store_d_real; + TMA_D tma_store_d_imag; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + typename Params::TMA_C tma_load_c_real{}; + typename Params::TMA_C tma_load_c_imag{}; + if constexpr (not cute::is_void_v) { + Tensor tensor_c_real = make_tensor(make_gmem_ptr(args.ptr_C_real), make_layout(make_shape(M,N,L), append<3>(args.dC_real, _0{}))); + Tensor tensor_c_imag = make_tensor(make_gmem_ptr(args.ptr_C_imag), make_layout(make_shape(M,N,L), append<3>(args.dC_imag, _0{}))); + + tma_load_c_real = make_tma_copy(CopyOpG2S{}, tensor_c_real, take<0,2>(SmemLayoutC{}), 
EpilogueTile{}, _1{}); + tma_load_c_imag = make_tma_copy(CopyOpG2S{}, tensor_c_imag, take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); + } + + Tensor tensor_d_real = make_tensor(make_gmem_ptr(args.ptr_D_real), make_layout(make_shape(M,N,L), append<3>(args.dD_real, _0{}))); + Tensor tensor_d_imag = make_tensor(make_gmem_ptr(args.ptr_D_imag), make_layout(make_shape(M,N,L), append<3>(args.dD_imag, _0{}))); + + typename Params::TMA_D tma_store_d_real = + make_tma_copy(CopyOpS2G{}, tensor_d_real, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + typename Params::TMA_D tma_store_d_imag = + make_tma_copy(CopyOpS2G{}, tensor_d_imag, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + + return { + args.thread, + tma_load_c_real, + tma_load_c_imag, + tma_store_d_real, + tma_store_d_imag + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_d = cutlass::detail::get_output_alignment_bits(); + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_d / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_c = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_c / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + } + + 
if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK cta_tile_mnk) { + // Compute number of epilogue subtiles + constexpr int epi_m = size<0>(cta_tile_mnk) / size<0>(EpilogueTile{}); + constexpr int epi_n = size<1>(cta_tile_mnk) / size<1>(EpilogueTile{}); + + return epi_m * epi_n; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c_real.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c_imag.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d_real.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d_imag.get_tma_descriptor()); + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage&) + : params(params_), epilogue_op(params_.thread) {} + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return epilogue_op.is_source_needed(); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class 
CtaCoordMNKL, + class MmaTileMNK, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + bool reverse_epi_n = false) { + using namespace cute; + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + auto coord_shape = make_coord(m_coord, n_coord, l_coord); + + // Tile residue + auto m_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_real_mn = params.tma_load_c_real.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC_imag_mn = params.tma_load_c_imag.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + + Tensor mC_real = coalesce(mC_real_mn, take<0,2>(cta_tile_mnk)); + Tensor mC_imag = coalesce(mC_imag_mn, take<0,2>(cta_tile_mnk)); + + Tensor gC_real = local_tile(mC_real, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor gC_imag = local_tile(mC_imag, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC_real = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_real.begin(); + } + else { + return shared_tensors.smem_D_real.begin(); + } + }(); + auto ptr_sC_imag = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) 
{ + return shared_tensors.smem_C_imag.begin(); + } + else { + return shared_tensors.smem_D_imag.begin(); + } + }(); + + Tensor gC_real_epi = flat_divide(gC_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gC_imag_epi = flat_divide(gC_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor sC_real_epi = make_tensor(make_smem_ptr(ptr_sC_real), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sC_imag_epi = make_tensor(make_smem_ptr(ptr_sC_imag), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s_real = params.tma_load_c_real.get_slice(Int<0>{}); + ThrCopy thrblk_g2s_imag = params.tma_load_c_imag.get_slice(Int<0>{}); + + Tensor bGS_gC_real = thrblk_g2s_real.partition_S(gC_real_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_gC_imag = thrblk_g2s_imag.partition_S(gC_imag_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + Tensor bGS_sC_real = thrblk_g2s_real.partition_D(sC_real_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + Tensor bGS_sC_imag = thrblk_g2s_imag.partition_D(sC_imag_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Acquire the lock for the first stage + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_real_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_real_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_real_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + 
+ // Execute the TMA load for C + if (issue_tma_load) { + copy(params.tma_load_c_real.with(*tma_barrier, mcast_mask), + bGS_gC_real(_,_,_,epi_m,epi_n), bGS_sC_real(_,_,_,load_pipe_producer_state.index())); + copy(params.tma_load_c_imag.with(*tma_barrier, mcast_mask), + bGS_gC_imag(_,_,_,epi_m,epi_n), bGS_sC_imag(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + //static_assert(rank(accumulators) == 4, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + 
static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + auto accumulators_real = accumulators(_,_,_,0); + auto accumulators_imag = accumulators(_,_,_,1); + + auto coord_shape = make_coord(m_coord, n_coord, l_coord); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_real_mn = params.tma_store_d_real.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD_imag_mn = params.tma_store_d_imag.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + + Tensor mD_real = coalesce(mD_real_mn, take<0,2>(cta_tile_mnk)); + Tensor mD_imag = coalesce(mD_imag_mn, take<0,2>(cta_tile_mnk)); + + Tensor gD_real = local_tile(mD_real, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor gD_imag = local_tile(mD_imag, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc_real = accumulators_real(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_imag = accumulators_imag(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_real_epi = flat_divide(tAcc_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tAcc_imag_epi = flat_divide(tAcc_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor gD_real_epi = flat_divide(gD_real, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_imag_epi = flat_divide(gD_imag, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + 
auto ptr_sC_real = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_real.begin(); + } + else { + return shared_tensors.smem_D_real.begin(); + } + }(); + auto ptr_sC_imag = [&]() { + if constexpr (not ReuseSmemC and is_source_supported) { + return shared_tensors.smem_C_imag.begin(); + } + else { + return shared_tensors.smem_D_imag.begin(); + } + }(); + + auto ptr_sD_real = shared_tensors.smem_D_real.begin(); + auto ptr_sD_imag = shared_tensors.smem_D_imag.begin(); + + Tensor sC_real_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC_real), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sC_imag_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC_imag), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + Tensor sD_real_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD_real), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + Tensor sD_imag_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD_imag), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_real_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc_real = thread_t2r.partition_S(tAcc_real_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD_real = thread_t2r.partition_D(sD_real_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAcc_imag = thread_t2r.partition_S(tAcc_imag_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD_imag = thread_t2r.partition_D(sD_imag_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rAcc = make_tensor(append(shape(tTR_sD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + Tensor tTR_rD = make_tensor(append(shape(tTR_sD_real), Int{})); // (T2R,T2R_M,T2R_N,2) + + // Vectorized fragment 
view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC_real = thread_s2r.partition_S(sC_real_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Tensor tSR_sC_imag = thread_s2r.partition_S(sC_imag_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD(_,_,_,_0{})).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v + && decltype(max_common_vector(tSR_rC_layout, tSR_sC_real.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(append(shape(tTR_sD_real), _2{})); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + Tensor tTR_rC_frg = recast>(tTR_rC); // (EPI_V) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD_real = thread_r2s.partition_D(sD_real_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_sD_imag = thread_r2s.partition_D(sD_imag_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d_real.get_slice(Int<0>{}); + Tensor bSG_sD_real = 
thrblk_s2g.partition_S(sD_real_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD_real = thrblk_s2g.partition_D(gD_real_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Tensor bSG_sD_imag = thrblk_s2g.partition_S(sD_imag_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD_imag = thrblk_s2g.partition_D(gD_imag_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // Coordinate tensors and residue for tile quantization + auto m_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + return cute::max(0, c_m); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(cta_tile_mnk)>{}, [&](auto i) { + auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + return cute::max(0, c_n); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + Tensor cD = make_identity_tensor(take<0,2>(cta_tile_mnk)); + Tensor tTR_cD = thread_t2r.partition_D(flat_divide(cD, EpilogueTile{})); + + bool is_source_needed = epilogue_op.is_source_needed(); + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_real_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. 
+ // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d_real, bSG_sD_real(_,_,_,store_pipe_producer_state.index()), bSG_gD_real(_,_,_,epi_m,epi_n)); + copy(params.tma_store_d_imag, bSG_sD_imag(_,_,_,store_pipe_producer_state.index()), bSG_gD_imag(_,_,_,epi_m,epi_n)); + } + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_source_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Begin the wait for the producer load 
results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_source_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_real_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_real_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_real_epi)-1 && iter_n == size<3>(gD_real_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_real_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_real_epi)-1 && iter_n == 0; + } + + if (is_source_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + // Copy source tile from smem to register // residual smem -> reg + copy(tiled_s2r, tSR_sC_real(_,_,_,load_wait_state.index()), tSR_rC(_,_,_,0)); + copy(tiled_s2r, tSR_sC_imag(_,_,_,load_wait_state.index()), tSR_rC(_,_,_,1)); + } + + if (is_source_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_real_mn = tTR_tAcc_real(_,_,_,epi_m,epi_n); + Tensor tTR_tAcc_imag_mn = tTR_tAcc_imag(_,_,_,epi_m,epi_n); + + 
// Compute tmem load predication if necessary + if constexpr (predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_real_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { // acc tmem -> reg + copy(tiled_t2r, tTR_tAcc_real_mn, tTR_rAcc(_,_,_,0)); + copy(tiled_t2r, tTR_tAcc_imag_mn, tTR_rAcc(_,_,_,1)); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rD_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i), tTR_rC_frg(i)); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rD_frg); ++i) { + tTR_rD_frg(i) = epilogue_op(tTR_rAcc_frg(i)); + } + } + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Copy output tile from register to smem + bool issue_smem_store = issue_tmem_load; + if (issue_smem_store) { // after scale, reg -> smem + copy(tiled_r2s, tRS_rD(_,_,_,0), tRS_sD_real(_,_,_,store_pipe_producer_state.index())); + copy(tiled_r2s, tRS_rD(_,_,_,1), tRS_sD_imag(_,_,_,store_pipe_producer_state.index())); + } + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + if (is_source_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { 
+ // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (epilogue_op.is_source_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } + +private: + Params const& params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp index 412a4b7..4ba6032 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index c2b8d84..322035a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp index 5030efd..c57d56d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index af53a1c..5601988 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -507,8 +507,7 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors, TensorMapC const& load_tensormap, - int subtile_idx=-1, - bool wait_until_load_finishes = false) { + int subtile_idx=-1) { using namespace cute; // Indexing variables @@ -595,12 +594,6 @@ class CollectiveEpilogue< // Post-loop fusion callback entry point pld_callbacks.end(); - if (wait_until_load_finishes && did_load) { - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = - {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; - load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); - } - return load_pipe_producer_state; } diff --git a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 062b9a8..c15c472 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ 
b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -751,13 +751,13 @@ class CollectiveEpilogue< ++store_pipe_producer_state; ++issued_stores; - // Wait for the next smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - if constexpr (ReuseSmemC) { + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); // producer_acquire returns when at most StagesD-1 committed stores are pending bool store_finished = issued_stores > StorePipeline::UnacquiredStages; // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits @@ -866,7 +866,13 @@ class CollectiveEpilogue< epi_m_prev = epi_m; epi_n_prev = epi_n; } - + if constexpr (not ReuseSmemC) { + // Wait for the smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + } // Smem reduction callback entry point using current store buffer for workspace cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); @@ -885,7 +891,6 @@ class CollectiveEpilogue< for (int i = 0; i < size(tRS_rD_frg); ++i) { tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); } - // Copy tile from register to smem if constexpr (is_destination_supported) { copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); diff --git 
a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp index 2d5fd85..efb5a4f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/3rd/cutlass/include/cutlass/epilogue/dispatch_policy.hpp index 2e6213f..f1a53e2 100644 --- a/3rd/cutlass/include/cutlass/epilogue/dispatch_policy.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/dispatch_policy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -61,13 +61,28 @@ struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueW // Blackwell direct store schedules struct NoSmemWarpSpecialized1Sm {}; struct NoSmemWarpSpecialized2Sm {}; +struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; +struct PtrArrayPlanarComplexNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayPlanarComplexNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; // Blackwell TMA schedules struct TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2Sm {}; struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; + +struct PlanarComplexTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; +struct PlanarComplexTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; +struct PtrArrayPlanarComplexTmaWarpSpecialized1Sm : PlanarComplexTmaWarpSpecialized1Sm {}; +struct PtrArrayPlanarComplexTmaWarpSpecialized2Sm : PlanarComplexTmaWarpSpecialized2Sm {}; struct TmaWarpSpecialized1SmNvf4 final : TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2SmNvf4 final : TmaWarpSpecialized2Sm {}; struct 
TmaWarpSpecialized1SmMxf4 final : TmaWarpSpecialized1Sm {}; @@ -234,11 +249,63 @@ struct Sm100PtrArrayTmaWarpSpecialized { static_assert(StagesD >= 1, "StagesD must be >= 1"); }; -// default elementwise operator epilogue without smem -struct Sm100NoSmem {}; -struct Sm100NoSmemWarpSpecialized {}; -struct Sm100PtrArrayNoSmem {}; -struct Sm100PtrArrayNoSmemWarpSpecialized {}; +struct Sm100NoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100NoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; +struct Sm100PtrArrayNoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100PtrArrayNoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; +struct Sm100PtrArrayPlanarComplexNoSmem {}; +struct Sm100PtrArrayPlanarComplexNoSmemWarpSpecialized {}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100PlanarComplexTmaWarpSpecialized + : public Sm100TmaWarpSpecialized +{ +}; + + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100PtrArrayPlanarComplexTmaWarpSpecialized + : public Sm100TmaWarpSpecialized +{ +}; template< int StagesC_, diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp index f9febee..5837be0 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/operations.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/operations.hpp index 8cac28f..1ee6851 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/operations.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/operations.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -57,6 +57,7 @@ struct FusionOperation { using ElementSource = void; static constexpr bool IsSourceSupported = false; + static constexpr bool IsResidualSupported = false; // Source is added after activation using ElementScalar = void; static constexpr int AlignmentScalar = 0; @@ -317,6 +318,24 @@ struct PerColLinCombPerColBiasEltAct static constexpr bool IsPerColScaleSupported = true; }; +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColResAddPerColBiasEltAct + : PerColLinCombPerColBiasEltAct { + static constexpr bool IsResidualSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias // if D is 
fp8 // D = scale_d * activation(Z) diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp index d81e3b4..dcc1c95 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -1280,6 +1280,39 @@ struct FusionCallbacks< }; +// -------------------------------------------------------------------- +// Sm100PtrArrayNoSmemWarpSpecialized (direct-store, grouped GEMM) +// -------------------------------------------------------------------- +template < + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... 
Args +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayNoSmemWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> + : FusionCallbacks< + // reuse the ptr-array *TMA* callbacks with 0 stages + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> { + + using Base = FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>; + + // bring ctors into scope + using Base::Base; +}; } // namespace cutlass::epilogue::fusion diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp index a205912..b5dd3e8 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp index d026b15..8a6b559 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp index b769b1f..47feea7 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp index e72e971..22f4800 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 87258c6..0a931dd 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -1306,6 +1306,114 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColResAddPerColBiasEltAct = + Sm90EVT, // beta * C + activation(alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // activation(alpha * acc + bias) + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // 
alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + > + >; + + template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename 
Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + activation(alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // unary op : activation(alpha * acc + bias) + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 6aec0e8..877077c 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -591,7 +591,7 @@ struct Sm90TreeVisitor< auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params_aux.ptr_aux)); + gmem_ptr ptr_aux = make_gmem_ptr(params_aux.ptr_aux); Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) @@ -765,7 +765,7 @@ struct Sm90AuxLoad< auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + gmem_ptr ptr_aux = make_gmem_ptr(params.ptr_aux); Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 72afd1e..a584756 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -871,11 +871,11 @@ struct Sm90ScalarBroadcastPtrArray { template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Get the scalar for batched broadcast - if (size<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } + // Always refresh scalar with the current group index so per-group + // alpha/beta values (provided through pointer arrays) are loaded + // correctly even when the L-stride is zero. + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); return EmptyProducerLoadCallbacks{}; } @@ -904,12 +904,8 @@ struct Sm90ScalarBroadcastPtrArray { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); return ConsumerStoreCallbacks(scalar); } @@ -920,13 +916,15 @@ struct Sm90ScalarBroadcastPtrArray { int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); if (params_ptr->scalar_ptr_arrays[0] != nullptr) { - scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]); + // Pointer-array variant: each entry already points to the scalar of a group. + scalar = *(params_ptr->scalar_ptr_arrays[0][l_coord]); } else if (params_ptr->scalar_ptrs[0] != nullptr) { + // Strided pointer variant. scalar = params_ptr->scalar_ptrs[0][l_offset]; } else { - // batch stride is ignored for nullptr fallback + // Literal fallback. 
scalar = params_ptr->scalars[0]; } @@ -936,15 +934,13 @@ struct Sm90ScalarBroadcastPtrArray { for (int i = 1; i < BroadcastCount; ++i) { if (params_ptr->scalar_ptr_arrays[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset])); + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][l_coord])); } - if (params_ptr->scalar_ptrs[i] != nullptr) { + else if (params_ptr->scalar_ptrs[i] != nullptr) { int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); } else { - // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); } } diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 29b9d1d..ee78ca6 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -1173,8 +1173,9 @@ struct Sm90RowReduction { CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations @@ -1650,8 +1651,9 @@ struct Sm90ColReduction { CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 93720f8..5d4e9de 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp index 330e1fd..b28dba2 100644 --- a/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -93,7 +93,7 @@ Array top_2_reduce(Array a, Array b) { " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 - "}\n" : "=f"(out[0]), "=f"(out[1]) : + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); return out; } @@ -117,8 +117,8 @@ Array top_4_reduce_scalar(Array a, float scalar) { " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); return out; } @@ -187,8 +187,8 @@ Array top_4_reduce(Array a, Array b) { " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? 
max(a2, b1) ** second most likely case " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); return out; @@ -351,7 +351,7 @@ struct Sm90TopKSoftmaxColReduction { // we can track logsumexp instead of tracking two variables (sum of exps and the max). // In addition, subtracting logsumexp from any element and taking its exp is equivalent to // computing its softmax. - // + // // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the // way at all, because any element not in the top-K is going to be masked out and set to 0. // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and @@ -370,7 +370,7 @@ struct Sm90TopKSoftmaxColReduction { ReductionResult() { } CUTLASS_DEVICE - ReductionResult(ElementCompute min, ElementCompute logsumexp): + ReductionResult(ElementCompute min, ElementCompute logsumexp): logsumexp_(logsumexp), min_(min) { } // Warp shuffle broadcast @@ -541,7 +541,7 @@ struct Sm90TopKSoftmaxColReduction { visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, Array const& frg_input) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); @@ -566,7 +566,7 @@ struct Sm90TopKSoftmaxColReduction { CUTLASS_DEVICE void reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; @@ -668,7 +668,7 @@ struct Sm90TopKSoftmaxColReduction { 
CUTLASS_DEVICE void end_loop(int epi_m, int epi_n) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; @@ -690,8 +690,9 @@ struct Sm90TopKSoftmaxColReduction { CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations @@ -739,7 +740,7 @@ struct Sm90TopKSoftmaxColReduction { Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) - // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // Compose the new accumulator R2S layout with the expected tCrC layout to get final // reduction tensor layout. auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/activation.h b/3rd/cutlass/include/cutlass/epilogue/thread/activation.h index 8412b50..e2a04d0 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/activation.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/activation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/conversion_op.h b/3rd/cutlass/include/cutlass/epilogue/thread/conversion_op.h index 432906a..7da27da 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/conversion_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/conversion_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -62,6 +62,7 @@ class Convert { using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementAccumulator_; + using ElementD = ElementOutput; // for use with cute::collective::DefaultEpilogue static int const kCount = Count; @@ -123,6 +124,21 @@ class Convert { return destination_converter(accumulator); } + + // + // Specializations for scalar (for use with cute::collective::DefaultEpilogue) + // + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator, ElementAccumulator const source) const { + NumericConverter destination_converter; + return destination_converter(source); + } + + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator) const { + NumericConverter destination_converter; + return destination_converter(accumulator); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/detail.hpp b/3rd/cutlass/include/cutlass/epilogue/thread/detail.hpp index a132134..dabb4ea 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/detail.hpp +++ 
b/3rd/cutlass/include/cutlass/epilogue/thread/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination.h index 05a1f79..c7dbd14 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 0b6aa71..3f57f73 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index 76d80f2..89a460a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h index 7abed26..0b36443 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h index 2aefe91..21a2972 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h index 9ecb015..5a86fe1 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h index 3e82d2c..cac77bb 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h index a2acd49..65f4f4b 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h index c8a8083..a928a4a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h index 4315a9b..1dab5be 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h index 24b507e..fd85911 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h index 2a7136a..c9f5d86 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h index 212084a..ed40252 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h index 134ddde..4b28f8f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h index bbfa4a3..3102f19 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h index 219ab25..6aa17af 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h index 481eb00..f8c1dda 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h index 438bfa6..a209fc0 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp index b36501b..7e22ca7 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h index 7dd3b3e..9feee18 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/reduction_op.h b/3rd/cutlass/include/cutlass/epilogue/thread/reduction_op.h index c2474c0..1a515c4 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/reduction_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/reduction_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/thread/scale_type.h b/3rd/cutlass/include/cutlass/epilogue/thread/scale_type.h index beed8bf..da46fa7 100644 --- a/3rd/cutlass/include/cutlass/epilogue/thread/scale_type.h +++ b/3rd/cutlass/include/cutlass/epilogue/thread/scale_type.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h index 2dd2265..bc034ae 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h index effb49a..035b686 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h index 45e3602..b3d01bb 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h index ed87a9e..1b3d380 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index 10719f1..690d6e7 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index fb01693..2b42054 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h index 68a98f3..0b1eba1 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h index 2039fe1..707133d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h index f260a5b..b555882 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index ef4fc03..5d63349 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h index 0e023c6..d54403b 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h index dd7a071..db32e10 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h index 030a9c1..31e9e40 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h index 39297f1..3ee6f2a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h index 3c38116..298aa6a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h index 5f5cd47..fcd296d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h index 07115e6..784e47d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue.h index 49143cf..423996a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,10 +38,11 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif + #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h index 57ba7aa..b6b6360 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,15 +37,15 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if !defined(__CUDACC_RTC__) #include #include #endif +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif -#include - -#include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h index 14aac16..47686d5 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h index 7696741..2c8f579 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h index 187d40c..d975f21 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h index e8d6fbc..4b308b2 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,10 +37,8 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h index 7eb68f2..e254298 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h index 7321355..6eb44ef 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,10 +37,8 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h index 6a50a50..050f500 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,16 +38,16 @@ */ #pragma once +#include "cutlass/cutlass.h" -#include +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h index 8459a72..19f2a66 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h index 751ce50..2ec42cd 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -49,16 +49,15 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 312d43c..ec83c16 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,16 +38,15 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index 5699a23..c0c855d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,10 +38,10 @@ */ #pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) -#include -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h new file mode 100644 index 0000000..c3b616c --- /dev/null +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h @@ -0,0 +1,231 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue. + + The epilogue finds max values in each row of the row-major output matrix and stores them. 
+ The max values are also used for a further round of threadblock scoped reduction operation, where + the partial reduction results are stored in a pre-allocated array and used for further full reduction. + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" // cutlass::TensorRef + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class GemvEpilogueWithScalingFactor +{ + public: + using ThreadShape = ThreadShape_; + using ElementCompute = ElementCompute_; // f32 + using ElementAccumulator = ElementAccumulator_; // f32 + using ElementC = ElementC_; // e2m1 + using ElementD = ElementD_; // e2m1 + using ElementSFD = ElementSFD_; // e4m3 + using LayoutOutput = LayoutOutput_; // ColumnMajor + using LayoutSFD = LayoutSFD_; // ColumnMajor + using TensorRefD = TensorRef; + static constexpr int kVectorSize = kVectorSize_; + // number of threads row + static constexpr int kThreadsPerCol = ThreadShape::kM; // 16 + // number of threads col + static constexpr int kThreadsPerRow = ThreadShape::kN; // 8 + static constexpr int kThreadCount = kThreadsPerCol * kThreadsPerRow; // 128 + + static_assert(kVectorSize == kThreadsPerCol, "vector size and number of threads row should be equal"); + static_assert(std::is_same_v && + std::is_same_v, + "Only support Mx1 (ColumnMajor) output and ColumnMajor scaling factor"); + static_assert(std::is_same_v, "ElementCompute should be float type"); + static_assert(cutlass::sizeof_bits::value == 4, "Output should be FP4 type"); + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFD should be FP8 type"); + static_assert(std::is_same_v, "only support same layout for D and SFD"); + + // Hardcode static_assert on threadshape 16x8 to avoid 
bug + static_assert(kThreadsPerCol == 16, "thread shape col false"); + static_assert(kThreadsPerRow == 8, "thread shape row false"); + static_assert(kThreadCount == 128, "thread count false"); + + struct Params + { + TensorRefD tensor_d; + ElementSFD *scale_factor_d_ptr{nullptr}; + ElementCompute alpha{0}; + ElementCompute beta{0}; + float st{0}; + int64_t batch_stride_sfd{0}; // Add batch stride for SFD + int64_t stride_d{0}; // Add stride for D tensor + }; + + /// Shared storage + struct SharedStorage + { + // fp32 + // Each thread store one fp32 +#if 1 + ElementAccumulator reduction_buffer[kThreadsPerCol]; +#else + ElementAccumulator reduction_buffer[kThreadCount]; +#endif + // Buffer for collecting 4-bit values for packed store + uint8_t packed_buffer[kThreadsPerCol]; + }; + + private: + Params const ¶ms_; + SharedStorage &shared_storage_; + float st_scale_down{0}; + + public: + CUTLASS_HOST_DEVICE GemvEpilogueWithScalingFactor(Params const ¶ms, SharedStorage &shared_storage) + : params_(params) + , shared_storage_(shared_storage) + { + const float fp_subtype_max = static_cast(cutlass::platform::numeric_limits::max()); + this->st_scale_down = this->params_.st / fp_subtype_max; + } + + CUTLASS_DEVICE void operator()(ElementAccumulator frag_acc, ElementC frag_c, int batch_idx) + { + const int block_idx = blockIdx.x; + const int thread_idx_col = threadIdx.x; + const int thread_idx_row = threadIdx.y; + + const float st_scale_down = this->st_scale_down; + const float st = this->params_.st; + + // Compute D offset using batch_idx and stride_d + const int output_d_base_offset = blockIdx.x * blockDim.y; + const int d_batch_offset = batch_idx * params_.stride_d; + ElementD* output_ptr = ¶ms_.tensor_d.at({output_d_base_offset + d_batch_offset, 0}); + uint8_t* byte_ptr = reinterpret_cast(output_ptr); + // For 8x16 thread layout, 1 thread per 128 threads write to sf d + // Every block write one SFD to gmem + const bool is_write_sfd_thread = (thread_idx_row == 0); + + // 
Calculate SFD offset using proper batch stride + const int output_sfd_offset = (block_idx / 4) * 512 + block_idx % 4 + batch_idx * params_.batch_stride_sfd; + + auto reduction_buffer = shared_storage_.reduction_buffer; + // fp32 + ElementAccumulator max_accum_row0 = ElementAccumulator(0); + ElementAccumulator max_accum_row1 = ElementAccumulator(0); + + // Threads in a row contain duplicate frag_acc data + if ( thread_idx_col == 0 ) { + // 16 threads write to 16 contiguous banks, no conflict + reduction_buffer[thread_idx_row] = frag_acc; + } + + __syncthreads(); + + if (threadIdx.y == 0) { + auto acc_0 = reduction_buffer[threadIdx.x * 2]; + auto acc_1 = reduction_buffer[threadIdx.x * 2 + 1]; + // compute the max of the 16 partial values using shuffles among the 8 active lanes. + ElementAccumulator max_accum = fabsf(acc_0); + max_accum = cutlass::fast_max(max_accum, fabsf(acc_1)); + + // Shuffle-down tree reduction across the 8 active lanes (mask 0xFF) + // Each iteration halves the number of active lanes + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 4)); // 8->4 + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 2)); // 4->2 + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 1)); // 2->1 + + // Broadcast the final result to all 8 threads + max_accum = __shfl_sync(0xFF, max_accum, 0); + + float pvscale = max_accum * st_scale_down; + ElementSFD qpvscale = static_cast(pvscale); + float qpvscale_up = NumericConverter{}(qpvscale); + float qpvscale_up_rcp = __frcp_rn(qpvscale_up) * st; + uint8_t qval_u8_compare; + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t temp_result; + asm volatile ( + "{\n" + " .reg .f32 output_fp32_0, output_fp32_1;\n" + " .reg .b8 byte0, byte1, byte2, byte3;\n" + " mul.f32 output_fp32_0, %1, %3;\n" + " mul.f32 output_fp32_1, %2, %3;\n" + " cvt.rn.satfinite.e2m1x2.f32 byte0, output_fp32_1, output_fp32_0;\n" + " mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}\n" + : "=r"(temp_result) // Output to uint32_t + :
"f"(acc_0), "f"(acc_1), "f"(qpvscale_up_rcp) + ); + qval_u8_compare = temp_result & 0xFF; + #else + ElementD output_fp4_0 = NumericConverter{}(acc_0 * qpvscale_up_rcp); + ElementD output_fp4_1 = NumericConverter{}(acc_1 * qpvscale_up_rcp); + uint8_t raw_fp4_0 = reinterpret_cast(output_fp4_0) & 0x0F; + uint8_t raw_fp4_1 = reinterpret_cast(output_fp4_1) & 0x0F; + qval_u8_compare = (raw_fp4_1 << 4) | raw_fp4_0; + #endif + byte_ptr[threadIdx.x] = qval_u8_compare; + + arch::global_store(qpvscale, + (void *)(params_.scale_factor_d_ptr + output_sfd_offset), + is_write_sfd_thread); + + } + + } // end of operator() +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h index e3e5abd..5663672 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h index 377524f..fa68a55 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h index 65bf32a..502e377 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp index a5b26e0..7d5b712 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_compute.hpp b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_compute.hpp index 6275a2f..ce1ead5 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_compute.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_compute.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index d894b11..bf3936f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp index 7bc7f80..53beb6c 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitors.hpp b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitors.hpp index f1936f2..1a30c2a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitors.hpp +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/fusion/visitors.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index ec717fb..43a2a52 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h index 6f6d101..2a0846a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 2c011c1..25fda0b 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 7c4692f..2502c2e 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h index 7068c39..80d9482 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h index 9990dbd..a0af5b2 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h index 518ad09..3ab7de8 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h index 49ee22e..67dc94a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h index 0d1f171..571eb23 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index 11ec3d7..13f2b4f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h index a4ed371..7e72663 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h index dfe9571..60bf427 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h index a321f1b..f98468c 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h index 66cc17f..25c3a6f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h index 74d040b..684394c 100644 --- a/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h +++ b/3rd/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h index 58ccbfa..3de43a7 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h index b03cab8..4b51f84 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h index 404be79..c533196 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h index 4c6f10b..f8bc10a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h index fede558..954f7c3 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h index 245499b..c1d9e79 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/simt_policy.h b/3rd/cutlass/include/cutlass/epilogue/warp/simt_policy.h index a1fa65c..d53b21f 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/simt_policy.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/simt_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h b/3rd/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h index 002d859..11f3c5d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h index be7af13..d496988 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h index 7cfa072..91a51f6 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index 134e668..327e733 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index a18a9ac..92e7f1a 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h index 8129dce..1d04604 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h b/3rd/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h index c108fc9..ea668ea 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h b/3rd/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h index 01b1e72..1b78e3d 100644 --- a/3rd/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h +++ b/3rd/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/exmy_base.h b/3rd/cutlass/include/cutlass/exmy_base.h index be207a4..0ab6a87 100644 --- a/3rd/cutlass/include/cutlass/exmy_base.h +++ b/3rd/cutlass/include/cutlass/exmy_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -79,7 +79,7 @@ enum class FpEncoding E8M23, // float E5M2, // FP8 E4M3, // FP8 - UE4M3, // FP8 + UE4M3, // FP8 UE8M0, // FP8 E3M2, // FP6 E2M3, // FP6 @@ -869,7 +869,7 @@ CUTLASS_CONSTEXPR_IF_CXX17 auto fp_encoding_selector() { else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE4M3) { // FP8 return cutlass::detail::FpBitRepresentation{}; } - + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE8M0) { // FP8 return cutlass::detail::FpBitRepresentation{}; } diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/device/detail.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/device/detail.hpp index 129f733..f358c18 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/device/detail.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/device/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp index 7968849..82810b8 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -253,16 +253,59 @@ class DistributedGemmUniversalAdapter { return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration); } + static + auto make_dummy_base_args(Arguments const* args, int device_idx, int iteration, void ** buffer_space) { + + // Set up GEMM arguments for the current stage/iteration + auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration); + + Arguments base_args = args[device_idx]; + base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape); + base_args.mainloop = { + reinterpret_cast(tensor_a_iter.data()), + tensor_a_iter.stride(), + reinterpret_cast(tensor_b_iter.data()), + tensor_b_iter.stride() + }; + base_args.epilogue = { + base_args.epilogue.thread, + 
reinterpret_cast(tensor_c_iter.data()), + tensor_c_iter.stride(), + reinterpret_cast(tensor_d_iter.data()), + tensor_d_iter.stride() + }; + + if constexpr (DistSchedule::RemoteC) { + if (iteration > 0) { + base_args.epilogue.thread.beta = 1.0; + } + else if (iteration == 0){ + base_args.epilogue.thread.beta = 0.0; + } + } + + return base_args; + } + static size_t - get_workspace_size(Arguments const& args) { + get_workspace_size(Arguments const* args, int device_idx) { size_t workspace_bytes = 0; - workspace_bytes = get_buffer_space_size(args); + workspace_bytes = get_buffer_space_size(args[device_idx]); + + void* dummy_buffer_space[TP_]; for (int iteration = 0; iteration < TP_; ++iteration) { + // Workspace sizes can vary if arguments change, therefore we must + // construct args for each iteration exactly as it will be run. + auto args_base = make_dummy_base_args(args, device_idx, iteration, dummy_buffer_space); + // NOTE: assumes underlying kernels align up to alignment requirements on their own, // and that the alignment requirements of the individual kernels match. - workspace_bytes += GemmKernel::get_workspace_size(args); + workspace_bytes += GemmKernel::get_workspace_size(args_base); } return workspace_bytes; diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp index ab91cf8..67f4135 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp index 0445567..0022049 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp index b290031..c5e65c0 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp index 0ec620a..2e0c0ad 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp index 73d52ad..23419bb 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp b/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp index 3a2d332..bf5f14b 100644 --- a/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp +++ b/3rd/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/fast_math.h b/3rd/cutlass/include/cutlass/fast_math.h index 4a758b7..8fa30f9 100644 --- a/3rd/cutlass/include/cutlass/fast_math.h +++ b/3rd/cutlass/include/cutlass/fast_math.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,18 +30,17 @@ **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #include #include #endif #if !defined(__QNX__) -#include +#include CUDA_STD_HEADER(utility) #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" #include "cutlass/coord.h" diff --git a/3rd/cutlass/include/cutlass/float8.h b/3rd/cutlass/include/cutlass/float8.h index 574202e..9aacd78 100644 --- a/3rd/cutlass/include/cutlass/float8.h +++ b/3rd/cutlass/include/cutlass/float8.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -59,12 +59,14 @@ #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUDA_PTX_UE8M0_CVT_ENABLED 1 #endif #if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUDA_PTX_UE8M0_CVT_ENABLED 1 #endif @@ -1635,46 +1637,46 @@ struct numeric_limits : // CUTLASS_HOST_DEVICE -cutlass::float_e4m3_t operator "" _fe4m3(long double x) { +cutlass::float_e4m3_t operator""_fe4m3(long double x) { return cutlass::float_e4m3_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { +cutlass::float_e4m3_t operator""_fe4m3(unsigned long long int x) { return cutlass::float_e4m3_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_ue4m3_t operator "" _fue4m3(long double x) { +cutlass::float_ue4m3_t operator""_fue4m3(long double x) { return cutlass::float_ue4m3_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_ue4m3_t operator "" _fue4m3(unsigned long long int x) { +cutlass::float_ue4m3_t operator""_fue4m3(unsigned long long int x) { return cutlass::float_ue4m3_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e5m2_t operator "" _fe5m2(long double x) { +cutlass::float_e5m2_t operator""_fe5m2(long double x) { return cutlass::float_e5m2_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { +cutlass::float_e5m2_t operator""_fe5m2(unsigned long long int 
x) { return cutlass::float_e5m2_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_ue8m0_t operator "" _fue8m0(long double x) +cutlass::float_ue8m0_t operator""_fue8m0(long double x) { return cutlass::float_ue8m0_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_ue8m0_t operator "" _fue8m0(unsigned long long int x) +cutlass::float_ue8m0_t operator""_fue8m0(unsigned long long int x) { return cutlass::float_ue8m0_t(int(x)); } diff --git a/3rd/cutlass/include/cutlass/float_subbyte.h b/3rd/cutlass/include/cutlass/float_subbyte.h index 547714b..56d5129 100644 --- a/3rd/cutlass/include/cutlass/float_subbyte.h +++ b/3rd/cutlass/include/cutlass/float_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -45,12 +45,14 @@ #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUDA_PTX_FP4FP6_CVT_ENABLED 1 #endif #if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUDA_PTX_FP4FP6_CVT_ENABLED 1 #endif @@ -758,36 +760,36 @@ struct numeric_limits : public float_ // User-defined literals // CUTLASS_HOST_DEVICE -cutlass::float_e2m1_t operator"" _fe2m1(long double x) +cutlass::float_e2m1_t 
operator""_fe2m1(long double x) { return cutlass::float_e2m1_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e2m1_t operator"" _fe2m1(unsigned long long int x) +cutlass::float_e2m1_t operator""_fe2m1(unsigned long long int x) { return cutlass::float_e2m1_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e2m3_t operator"" _fe2m3(long double x) +cutlass::float_e2m3_t operator""_fe2m3(long double x) { return cutlass::float_e2m3_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e2m3_t operator"" _fe2m3(unsigned long long int x) +cutlass::float_e2m3_t operator""_fe2m3(unsigned long long int x) { return cutlass::float_e2m3_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e3m2_t operator"" _fe3m2(long double x) +cutlass::float_e3m2_t operator""_fe3m2(long double x) { return cutlass::float_e3m2_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e3m2_t operator"" _fe3m2(unsigned long long int x) +cutlass::float_e3m2_t operator""_fe3m2(unsigned long long int x) { return cutlass::float_e3m2_t(int(x)); } diff --git a/3rd/cutlass/include/cutlass/floating_point_nvrtc.h b/3rd/cutlass/include/cutlass/floating_point_nvrtc.h index 6496fea..1ce3d2a 100644 --- a/3rd/cutlass/include/cutlass/floating_point_nvrtc.h +++ b/3rd/cutlass/include/cutlass/floating_point_nvrtc.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/functional.h b/3rd/cutlass/include/cutlass/functional.h index 628a807..2ac6e65 100644 --- a/3rd/cutlass/include/cutlass/functional.h +++ b/3rd/cutlass/include/cutlass/functional.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,6 +38,8 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/platform/platform.h" +#include "cutlass/detail/dependent_false.hpp" + #if defined(__CUDACC_RTC__) #include "cutlass/floating_point_nvrtc.h" #endif @@ -54,7 +56,8 @@ #include #endif // _MSC_VER -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTLASS_ARCH_CREDUX_ENABLED #endif @@ -534,6 +537,49 @@ struct maximum_absolute_value_reduction { } }; +// Maximal exponent reduction for zero-mantissa scaling factors +template +struct maximum_absolute_value_zero_mantissa_reduction { + + // Discard mantissa and sign bits for the input. Needs to specify the number of mantissa / exponent bits + template + static CUTLASS_HOST_DEVICE T_ discard_sign_mantissa_impl(T_ x) { + static constexpr UI one = 1; + static constexpr UI n_mantissa = N_Mantissa; + static constexpr UI pos_sign = sizeof(T_) * 8 - 1; // Position of sign bit: bit width - 1. 
+ static constexpr UI mask = ~((one << n_mantissa) - one) & ~(one << pos_sign); + static constexpr UI subnormal_cap = one << n_mantissa; + + UI out = *reinterpret_cast(&x) & mask; + // Subnormals + if (out == 0) { + out = subnormal_cap; + } + return *reinterpret_cast(&out); + } + + // Discard mantissa and sign bits s.t. multipling with this scaling factor only results in an exponent shift + template + static CUTLASS_HOST_DEVICE T_ discard_sign_mantissa(T_ x) { + if constexpr (cute::is_same_v) { + return discard_sign_mantissa_impl(x); + } + else if constexpr (cute::is_same_v) { + return discard_sign_mantissa_impl(x); + } + else { + static_assert(cutlass::detail::dependent_false, "Can't discard mantissa & sign bits for unknown data type"); + } + } + + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + cutlass::maximum max_op; + + return max_op(lhs, discard_sign_mantissa(rhs)); + } +}; + /// Fused multiply-add template struct multiply_add { @@ -655,7 +701,7 @@ struct and_popc_add { } }; -/// Fused multiply-add +/// Fused and-add template struct and_add { CUTLASS_HOST_DEVICE @@ -677,7 +723,7 @@ struct xor_popc_add { } }; -/// Fused multiply-add +/// Fused xor-add template struct xor_add { CUTLASS_HOST_DEVICE @@ -699,7 +745,7 @@ struct or_popc_add { }; -/// Fused multiply-add +/// Fused or-add template struct or_add { CUTLASS_HOST_DEVICE diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl new file mode 100644 index 0000000..5cdbdd7 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl @@ -0,0 +1,298 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// FastFP (9xBF16) TCGEN05 kernels builder +// Interleaved complex kernels that provides support for complex data types +template < + class ArchTag, + class GmemLayoutATag, + int AlignmentA, + class TransformA, + class GmemLayoutBTag, + int AlignmentB, + class TransformB, + class ElementAccumulator, + class TileShape_MNK, // The Cluster-level TileShape + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + cute::tuple, TransformA>, // ElementA + ConjA + GmemLayoutATag, // LayoutA + AlignmentA, + cute::tuple, TransformB>, // ElementB + ConjB + GmemLayoutBTag, // LayoutB + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + (not cute::is_tuple::value && not cute::is_tuple::value) && + (cute::is_base_of_v + ) && + ((sizeof(cutlass::complex) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(cutlass::complex) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + static_assert(cute::is_static_v, "TileShape_MNK has to be static"); + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static constexpr cute::UMMA::Major UmmaMajorACompute = cute::UMMA::Major::K; + static constexpr 
cute::UMMA::Major UmmaMajorBCompute = cute::UMMA::Major::K; + static constexpr bool BuilderTagIsSmem = + cute::is_base_of_v + ; + + using ElementA = complex; + using ElementB = complex; + using ElementAMma = complex< + cutlass::bfloat16_t + >; + using ElementBMma = complex< + cutlass::bfloat16_t + >; + static constexpr int ScalingFactor = + 8; + + using TiledMma = decltype(detail::sm100_make_trivial_fastFP32_tiled_mma()); + using AtomThrID = typename TiledMma::AtomThrID; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + // Take 3 compute buffers into account for swizzle selection + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + // Input transform kernel can not use TMA 2SM instructions. 
+ using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomA, SmemLayoutAtomACompute>; + + static constexpr int MMA_M = cute::size<0,0>(MmaShapeA_MK{}); + using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementA>, + cute::conditional_t<(UmmaMajorACompute == cute::UMMA::Major::K && !BuilderTagIsSmem), + cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, SM100_TMEM_STORE_32dp32b8x>, // TS Implementation + Copy_Atom, ElementAMma>> // SS Implementation + >; + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Input transform kernel can not use TMA 2SM instructions. + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + // Take 3 compute buffers into account for swizzle selection + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomB, SmemLayoutAtomBCompute>; + using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementB>, + Copy_Atom, ElementBMma> + >; + + // SmemCarveout + static constexpr int NumComplexComponents = 2; + static constexpr int NumComputeMtxs = + 3; + static constexpr int NumBandsToCompute = + 5; + static constexpr int AccPromotionInterval = + 1; + static constexpr int SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm = + (cute::is_base_of_v + ); + // 
CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( CLCPipelineStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage); + + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< + ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorACompute, + NumComplexComponents, NumComputeMtxs + >(StageCountType{}); + + // Complex 9xBF16 allows TileShape_N = 64, while SmemLayoutAtomB contains Swizzle<3,4,3>. + static constexpr int Load2TransformPipelineStageCount = size<1>(TileShape_MNK{}) == 64 ? 
get<0>(stage_info) / 2 * 2 : get<0>(stage_info); + static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); + static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info); + + using AccumulatorCopyAtom = cute::SM100_TMEM_LOAD_32dp32b32x; + + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount, + Transform2MmaPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + NumBandsToCompute, + ScalingFactor, + AccPromotionInterval, + ClusterShape_MNK, + AccumulatorCopyAtom> + >; + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomPairA, + CopyAtomPairA, + TransformA, + GmemTiledCopyB, + SmemLayoutAtomPairB, + CopyAtomPairB, + TransformB + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// FastFP (9xBF16) TCGEN05 kernels builder +// CUTLASS library compatibility builder without conjugate +template < + class ArchTag, + class GmemLayoutATag, + int AlignmentA, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // The Cluster-level TileShape + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + cutlass::complex, // ElementA + GmemLayoutATag, // LayoutA + AlignmentA, + cutlass::complex, // ElementB + GmemLayoutBTag, // LayoutB + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + (not cute::is_tuple::value && not cute::is_tuple::value) && + (cute::is_base_of_v + ) && + 
((sizeof(cutlass::complex) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(cutlass::complex) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + using CollectiveOp = typename CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + cute::tuple, cute::identity>, + GmemLayoutATag, + AlignmentA, + cute::tuple, cute::identity>, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + BuilderScheduleTag + >::CollectiveOp; +}; + +} // cutlass::gemm::collective diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl index 3edd928..4b34545 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -61,15 +61,19 @@ sm100_compute_stage_count_or_override_fast_fp32(StageCountAutoCarveout(CtaTileShape_MNK{}); using AtomThrID = typename TiledMma::AtomThrID; constexpr int TmemColumns = 512; + constexpr bool BuilderTagIsSmem = ( + cute::is_base_of_v + ); // Detect 2x2 TMEM layout constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? 
CtaN/2 : CtaN; constexpr int TmemAWordsPerDP = ComplexComponent * NumComputeMtxs * CtaK / 2; - constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v; + constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K && !BuilderTagIsSmem; constexpr bool IsAComputeinSmem = !IsAComputeinTmem; constexpr int AccumulatorStageCount = (IsAComputeinTmem) ? (((TmemAccWordsPerDP * ComplexComponent == 128) ? 2 : 3) * ComplexComponent) : (TmemColumns / TmemAccWordsPerDP); - - constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32); + + constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * (32 + )); constexpr int TmemInAStageCount_Potential = (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000; @@ -87,7 +91,8 @@ sm100_compute_stage_count_or_override_fast_fp32(StageCountAutoCarveout(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + // If ACompute is in TMEM, Acompute buffer has 0 bytes. cutlass::bits_to_bytes(NumComputeMtxs * b_compute_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + - static_cast(transform2mma_pipeline_bytes); + static_cast(transform2mma_pipeline_bytes) + ; constexpr int ABComputeStageCount_Potential = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes); // The number of SMEM buffers for A, B. 
ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount @@ -106,6 +111,7 @@ sm100_compute_stage_count_or_override_fast_fp32(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, float, // ElementA GmemLayoutATag, // LayoutA @@ -131,21 +137,36 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && (not cute::is_tuple::value && not cute::is_tuple::value) && - (cute::is_base_of_v) && + (cute::is_base_of_v + ) && ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> { static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static constexpr cute::UMMA::Major UmmaMajorACompute = + UmmaMajorA; + static constexpr cute::UMMA::Major UmmaMajorBCompute = + UmmaMajorB; + static constexpr bool BuilderTagIsSmem = ( + cute::is_base_of_v + ); using ElementA = float; using ElementB = float; - using ElementAMma = cutlass::bfloat16_t; - using ElementBMma = cutlass::bfloat16_t; - static constexpr int ScalingFactor = 8; + using ElementAMma = + cutlass::bfloat16_t + ; + using ElementBMma = + cutlass::bfloat16_t + ; + static constexpr int ScalingFactor = + 8; - using TiledMma = decltype(detail::sm100_make_trivial_fastFP32_tiled_mma()); + using TiledMma = decltype(detail::sm100_make_trivial_fastFP32_tiled_mma()); using AtomThrID = typename TiledMma::AtomThrID; using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); @@ -163,7 +184,7 @@ struct CollectiveBuilder< using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); // Take 3 compute buffers into account for swizzle selection - using 
SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); // Input transform kernel can not use TMA 2SM instructions. @@ -174,7 +195,7 @@ struct CollectiveBuilder< static constexpr int MMA_M = cute::size<0,0>(MmaShapeA_MK{}); using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< Copy_Atom, ElementA>, - cute::conditional_t<(UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v), + cute::conditional_t<(UmmaMajorACompute == cute::UMMA::Major::K && !BuilderTagIsSmem), cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, SM100_TMEM_STORE_32dp32b8x>, // TS Implementation Copy_Atom, ElementA>> // SS Implementation >; @@ -188,7 +209,7 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); // Take 3 compute buffers into account for swizzle selection - using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< @@ -199,11 +220,16 @@ struct CollectiveBuilder< >; // SmemCarveout - static constexpr int NumBandsToCompute = 5; - static constexpr int AccPromotionInterval = 1; + static constexpr int NumComputeMtxs = + 3; + static constexpr int NumBandsToCompute = + 5; + static constexpr int AccPromotionInterval = + 1; static constexpr int SchedulerPipelineStageCount = 3; - static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); - + static constexpr bool IsArrayOfPointersGemm = + (cute::is_base_of_v + ); // CLCPipeline = PipelineCLCFetchAsync static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); // CLC (scheduler) response @@ -226,9 +252,13 @@ struct CollectiveBuilder< TensorMapStorage); // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int 
Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< - Sm100ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorA>(StageCountType{}); + ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorACompute, + /*Cmplx=*/ 1, /*Mtxs=*/ NumComputeMtxs + >(StageCountType{}); static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 0000000..de9158d --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
+template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCount stage_count) { + return stages; +} + +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAutoCarveout stage_count) { + // For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = + sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage) + + sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class 
GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + using ElementSF = ElementSFA; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + + static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm(); + static constexpr auto Instr = detail::blockscaled::select_instr(); + + using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma::type; + + static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; + + static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only MMA.MXF8F6F4 supports non-K major inputs"); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = 
decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + + static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize; + + using ElementAMma_SmemAllocType = cute::conditional_t; + using ElementBMma_SmemAllocType = cute::conditional_t; + + // using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + // ElementAMma, ElementBMma, ElementAccumulator, + // decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + // UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(cutlass::sizeof_bits::value) * AlignmentB / 8>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = 
decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + using SmemLayoutAtomSFA = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomSFB = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using 
StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. 
Can't create any SMEM buffers for A, B, SFA, and SFB."); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + cute::tuple, + StridePairA, + cute::tuple, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl index b8824c2..3d19e76 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -44,6 +44,7 @@ namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -62,6 +63,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCount stag // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -110,7 +112,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCountAutoCarveout< constexpr auto EpilogueSharedStorage = carveout_bytes; - constexpr auto Stages = (cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout - EpilogueSharedStorage) / + constexpr auto Stages = (CapacityBytes - KernelSmemCarveout - EpilogueSharedStorage) / (MainloopTensorStorage_per_Stage + MainloopPipelineStorage_per_Stage_aligned); return Stages; @@ -121,6 +123,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCountAutoCarveout< ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementPairA, class GmemLayoutATag, int AlignmentA, @@ -134,7 +137,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassBlockScaledSparseTensorOp, ElementPairA, GmemLayoutATag, @@ -148,6 +151,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && // Blockscaled Sparse Gemm cute::is_base_of_v && @@ -272,7 +277,12 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; + // Calculate SMEM capacity based on ArchTag + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_sparse< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementEMma, diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 3556fad..923ce9c 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ 
b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -73,7 +73,7 @@ sm100_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout::SharedStorage); + constexpr auto mainloop_pipeline_bytes = cutlass::round_up(sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage), 128); constexpr auto a_bits = cute::sizeof_bits_v; constexpr auto b_bits = cute::sizeof_bits_v; constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); @@ -92,6 +92,7 @@ sm100_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassBlockScaledTensorOp, ElementPairA, GmemLayoutATag, @@ -119,7 +120,10 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && // Blockscaled Gemm + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v) && @@ -237,8 +241,9 @@ struct CollectiveBuilder< static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N == 256) ? 1 : 2; static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. 
- static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); + static constexpr bool IsGroupGemm = !(cute::is_same_v) && !(cute::is_same_v); + static constexpr bool IsRCGroupGemm = (cute::is_same_v) && !(cute::is_same_v); + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, @@ -249,23 +254,33 @@ struct CollectiveBuilder< 4 // 4 Tensor maps for A, SFA, B and SFB >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. 
Can't create any SMEM buffers for A, B, SFA, and SFB."); using DispatchPolicy = cute::conditional_t, - cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + cute::conditional_t, + cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl index 8617e88..fc9911d 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -102,8 +102,10 @@ sm100_compute_stage_count_or_override_blockwise(StageCountAutoCarveout(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + - cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) + ) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) + ) + cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) + cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})), 128) + @@ -132,7 +134,7 @@ auto sm100_make_simt_gmem_tiled_copy_SFA() { return make_tiled_copy( SmemScalingCopyAtomA{}, Layout>{}, // 32 threads - Layout, Int>>, Stride>>{}); + Layout>>{}); } else { using SmemScalingCopyAtomA = Copy_Atom, Element>; @@ -166,7 +168,7 @@ auto sm100_make_simt_gmem_tiled_copy_SFB() { return make_tiled_copy( SmemScalingCopyAtomB{}, Layout>{}, // 32 threads - Layout, Int>>, Stride>>{}); + Layout>>{}); } else { using SmemScalingCopyAtomB = Copy_Atom, Element>; @@ -237,6 +239,7 @@ sm100_make_trivial_tiled_mma_blockwise() { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATagPair, int AlignmentA, @@ -250,7 +253,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATagPair, @@ -264,6 +267,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && cute::is_tuple_v && cute::is_tuple_v && @@ -369,7 +374,9 @@ struct CollectiveBuilder< IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for 
buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape; using MainloopABPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; @@ -399,7 +406,7 @@ struct CollectiveBuilder< using ScaleTileShape = cute::Shape; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockwise< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementAccumulator, ScaleTileShape, SmemTileShape, MainloopABPipelineStorage, MainloopSFPipelineStorage>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, and scales."); @@ -436,7 +443,6 @@ struct CollectiveBuilder< >; }; - } // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_common.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_common.inl index 6230c61..284def1 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -220,7 +220,7 @@ sm100_cluster_shape_to_tma_atom_A(ClusterShapeMNK cluster_shape_mnk, AtomThrId a } else { // In the case of dynamic cluster, multicast decision is not known at compile time. - // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. return detail::sm90_cluster_shape_to_tma_atom(cute::Int<2>{}); } } @@ -255,7 +255,7 @@ sm100_cluster_shape_to_tma_atom_B(ClusterShapeMNK cluster_shape_mnk, AtomThrId a } else { // In the case of dynamic cluster, multicast decision is not known at compile time. - // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. return detail::sm90_cluster_shape_to_tma_atom(cute::Int<2>{}); } } @@ -281,7 +281,7 @@ sm100_cluster_shape_to_tma_atom_SFB(ClusterShapeMNK cluster_shape_mnk, AtomThrId } else { // In the case of dynamic cluster, multicast decision is not known at compile time. - // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. 
return detail::sm90_cluster_shape_to_tma_atom(cute::Int<2>{}); } } @@ -328,24 +328,24 @@ sm100_make_1sm_trivial_tiled_mma() { return make_tiled_mma(cute::SM100_MMA_S8_SS{}); } - else if constexpr (cute::is_same_v - || cute::is_same_v - || cute::is_same_v + else if constexpr (cute::is_same_v + || cute::is_same_v + || cute::is_same_v || cute::is_same_v || cute::is_same_v || cute::is_same_v || cute::is_same_v || cute::is_same_v ) { - + return make_tiled_mma( cute::MMA_Traits< cute::SM100_MMA_F8F6F4_SS, ElementAMma, - ElementBMma, - ElementAMmaccumulator, - cute::C, - cute::C, + ElementBMma, + ElementAMmaccumulator, + cute::C, + cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, @@ -396,9 +396,9 @@ sm100_make_2sm_trivial_tiled_mma() { return make_tiled_mma(cute::SM100_MMA_S8_2x1SM_SS{}); } - else if constexpr (cute::is_same_v - || cute::is_same_v - || cute::is_same_v + else if constexpr (cute::is_same_v + || cute::is_same_v + || cute::is_same_v || cute::is_same_v || cute::is_same_v || cute::is_same_v @@ -408,12 +408,12 @@ sm100_make_2sm_trivial_tiled_mma() { return make_tiled_mma( cute::MMA_Traits< - cute::SM100_MMA_F8F6F4_2x1SM_SS, + cute::SM100_MMA_F8F6F4_2x1SM_SS, ElementAMma, ElementBMma, - ElementAMmaccumulator, - cute::C, - cute::C, + ElementAMmaccumulator, + cute::C, + cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, @@ -471,7 +471,7 @@ sm100_make_trivial_tiled_mma() { return sm100_make_1sm_trivial_tiled_mma(); } - // Dynamic cluster shape means we cannot assume we can use 2SM MMA + // Dynamic cluster shape means we cannot assume we can use 2SM MMA } else { return sm100_make_1sm_trivial_tiled_mma +constexpr auto +sm100_make_1sm_ts_trivial_tiled_mma() { + + constexpr int M = cute::size<0>(ClusterTileShape_MNK{}); + static_assert(M == 64 || M == 128, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. 
+ constexpr int N = cute::size<1>(ClusterTileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v) { + static_assert(cute::is_same_v, "ElementA and ElementB must match."); + return make_tiled_mma(cute::SM100_MMA_TF32_TS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(cute::is_same_v, "ElementA and ElementB must match."); + return make_tiled_mma(cute::SM100_MMA_F16BF16_TS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_S8_TS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ClusterTileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + UMMA::ScaleIn ANeg = UMMA::ScaleIn::One, + UMMA::ScaleIn BNeg = UMMA::ScaleIn::One +> +constexpr auto +sm100_make_2sm_ts_trivial_tiled_mma() { + + constexpr int M = cute::size<0>(ClusterTileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. 
+ constexpr int N = cute::size<1>(ClusterTileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v) { + static_assert(cute::is_same_v, "For SM100 TF32 MMA, ElementA and ElementB must match."); + return make_tiled_mma(cute::SM100_MMA_TF32_2x1SM_TS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(cute::is_same_v, "For SM100 F16F32 MMA, ElementA and ElementB must match."); + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_TS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_S8_2x1SM_TS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + template< class ElementAMma, class ElementBMma, @@ -493,12 +579,16 @@ template< > constexpr auto sm100_make_trivial_fastFP32_tiled_mma() { + constexpr bool TagHasUmmaSs = ( + cute::is_base_of_v + ); + // MMA_2SM requested if constexpr (cute::is_base_of_v ) { using AtomLayout_MNK = decltype(make_layout(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{}))); constexpr int M = cute::size<0>(TileShape_MNK{}); constexpr int N = cute::size<1>(TileShape_MNK{}); - if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !TagHasUmmaSs) { return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_TS_SCALED{}); } @@ -512,7 +602,7 @@ sm100_make_trivial_fastFP32_tiled_mma() { // using AtomLayout_MNK = Layout; constexpr int M = cute::size<0>(TileShape_MNK{}); constexpr int N = cute::size<1>(TileShape_MNK{}); - if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !TagHasUmmaSs) { return make_tiled_mma(cute::SM100_MMA_F16BF16_TS_SCALED{}); } @@ -531,7 +621,7 @@ sm100_make_trivial_fastFP32_tiled_mma() { // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM 
if constexpr (cute::get<0>(ClusterShape_MNK{}) % 2 == 0 && (cute::get<0>(TileShape_MNK{}) / cute::get<0>(ClusterShape_MNK{})) % 64 == 0) { - if constexpr (!cute::is_base_of_v) { + if constexpr (!TagHasUmmaSs) { return sm100_make_trivial_fastFP32_tiled_mma(); } @@ -541,7 +631,7 @@ sm100_make_trivial_fastFP32_tiled_mma() { } } else { - if constexpr (!cute::is_base_of_v) { + if constexpr (!TagHasUmmaSs) { return sm100_make_trivial_fastFP32_tiled_mma(); } @@ -551,9 +641,9 @@ sm100_make_trivial_fastFP32_tiled_mma() { } } } - // Dynamic cluster shape means we cannot assume we can use 2SM MMA + // Dynamic cluster shape means we cannot assume we can use 2SM MMA else { - if constexpr (!cute::is_base_of_v) { + if constexpr (!TagHasUmmaSs) { return sm100_make_trivial_fastFP32_tiled_mma(); } @@ -569,6 +659,160 @@ sm100_make_trivial_fastFP32_tiled_mma() { } } +template< + class TileShape_MNK, + class ClusterShape_MNK, + class BuilderScheduleTag +> +constexpr auto +sm100_make_trivial_interleaved_complex_tf32_tiled_mma() { + // MMA_2SM requested + if constexpr (cute::is_base_of_v ) { + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + return make_tiled_mma(cute::SM100_MMA_TF32_2x1SM_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN{}); + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v ) { + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 64 || M == 128, "Invalid TileShape_M."); + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. 
+ constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + return make_tiled_mma(cute::SM100_MMA_TF32_TS_INTERLEAVED_CF32CTF32CTF32CF32_TN{}); + } + else if constexpr (cute::is_same_v) { + // Static cluster + if constexpr (cute::is_static_v) { + // For MMA_2SM we need a cluster shape that is multiple of 2x1 + // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM + if constexpr (cute::get<0>(ClusterShape_MNK{}) % 2 == 0 && + cute::size<0>(TileShape_MNK{}) % 128 == 0) { + return sm100_make_trivial_interleaved_complex_tf32_tiled_mma(); + } + else { + return sm100_make_trivial_interleaved_complex_tf32_tiled_mma(); + } + } + // Dynamic cluster shape means we cannot assume we can use 2SM MMA + else { + return sm100_make_trivial_interleaved_complex_tf32_tiled_mma(); + } + } +} + +//Setting mma for Mixed input gemm. Here, ElementAMma should be TACompute +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType +> +constexpr auto +sm100_make_trivial_mixed_input_tiled_mma() { + constexpr int M = cute::size<0>(TileShape_MNK{}); + constexpr int N = cute::size<1>(TileShape_MNK{}); + //MMA 1Sm requested + if constexpr (cute::is_base_of_v ) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_TS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F8F6F4_TS{}); + } + } + else { // If A needs to be transposed by MMA, fall back to SMEM from A MMA instructions + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_SS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma( + cute::MMA_Traits< + cute::SM100_MMA_F8F6F4_SS, + ElementAMma, + 
ElementBMma, + ElementAccumulator, + cute::C, + cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>{}); + } + } + } + //MMA 2Sm requested + else if constexpr (cute::is_base_of_v) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_TS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F8F6F4_2x1SM_TS{}); + } + } + } + else { + static_assert(cutlass::detail::dependent_false == 0, + "Unsupported policy for SM100 collective builder."); + } +} + +template< + class CtaShape_MNK +> +constexpr auto +sm100_simt_f32_warp_shape_mnk_selector() { + using namespace cute; + + constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{}); + + // CTA tile shape M and N are supposed to be divisible by 32. + static_assert(CtaShape_M % 32 == 0, "CtaShape_M needs to be divisible by 32."); + static_assert(CtaShape_N % 32 == 0, "CtaShape_N needs to be divisible by 32."); + + // WarpShape_MNK configuration + // We assume WarpShape_K is always 1 in our SM100 SIMT SGEMM implementation. 
+ if constexpr (CtaShape_M >= CtaShape_N) { + if constexpr (CtaShape_M == 256 && CtaShape_N == 128) { + return cute::Shape<_4, _2, _1>{}; + } + else if constexpr ((CtaShape_M == 64 || CtaShape_M == 32) && CtaShape_N == 32) { + return cute::Shape<_1, _2, _1>{}; + } + else { + return cute::Shape<_2, _2, _1>{}; + } + } + else { + if constexpr (CtaShape_M == 128 && CtaShape_N == 256) { + return cute::Shape<_2, _4, _1>{}; + } + else if constexpr (CtaShape_M == 32 && CtaShape_N == 64) { + return cute::Shape<_1, _2, _1>{}; + } + else { + return cute::Shape<_1, _4, _1>{}; + } + } +} + template < class ElementPairA, diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl new file mode 100644 index 0000000..34ab79d --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + (cute::is_same_v || + (cute::is_same_v && + (((sizeof(ElementA) * AlignmentA) % 
cutlass::gemm::collective::detail::tma_alignment_bytes != 0) || + ((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0))))> +> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load + static constexpr int NumLoadThreads = 128; + + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumLoadThreads, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); 
+ + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, BlockTileA_M, BlockTileA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreads, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. 
+ static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaCpAsyncWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl new file mode 100644 index 0000000..e680143 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl @@ -0,0 +1,264 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class TileShapeMNK, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_interleaved_complex_tf32(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class TileShapeMNK, + int carveout_bytes +> +constexpr int +sm100_compute_stage_count_or_override_interleaved_complex_tf32(StageCountAutoCarveout stage_count) { + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. 
one Load2TransformPipeline = PipelineTmaTransformAsync + constexpr auto load2transform_pipeline_bytes = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v>; + constexpr auto b_bits = cute::sizeof_bits_v>; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(load2transform_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Interleaved complex tf32 TCGEN05 kernels builder +template < + class ArchTag, + class GmemLayoutATag, + class TransformA, + class GmemLayoutBTag, + class TransformB, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + cute::tuple, TransformA>, + GmemLayoutATag, + 2, + cute::tuple, TransformB>, + GmemLayoutBTag, + 2, + cutlass::complex, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + (cute::is_base_of_v || + cute::is_same_v)>> +{ + static_assert(cute::is_static_v, "TileShape_MNK has to be static"); + // ElementA and ElementB are cutlass::complex, which are used as GMEM input and output data type. + // ElementAMma and ElementBMma are cutlass::complex, which are used as SMEM and RF data type. 
+ using ElementA = complex; + using ElementB = complex; + using ElementAccumulator = cutlass::complex; + using ElementAMma = complex; + using ElementBMma = complex; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using TiledMma = decltype(detail::sm100_make_trivial_interleaved_complex_tf32_tiled_mma< + TileShape_MNK,ClusterShape_MNK,BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // Define A and B block shapes for reduced size TMA_LOADs + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // Input transform kernel can not use TMA 2SM instructions. + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::Sm100CollectiveMmaComplexLayoutAtomType; + + static constexpr int MMA_M = cute::size<0>(TileShape_MNK{}); + + using CopyAtomPairA = cutlass::gemm::collective::detail::Sm100CollectiveMmaComplexCopyType< + Copy_Atom, ElementAMma>, + conditional_t + >; + + // Input transform kernel can not use TMA 2SM instructions. 
+ using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<0>(ClusterShape_MNK{}))); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::Sm100CollectiveMmaComplexLayoutAtomType; + + using CopyAtomPairB = cutlass::gemm::collective::detail::Sm100CollectiveMmaComplexCopyType< + Copy_Atom, ElementBMma>, + Copy_Atom, ElementBMma> + >; + + // Calculate SMEM matrix A and B buffers' pipeline stages + static constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + static constexpr uint32_t TransformationStageCount = 4; + static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N >= 128) ? 
1 : 2; + static constexpr uint32_t SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + + // SmemCarveout + // B needs extra smem for smem tranpose (CollectiveMma::TensorStorageTransformed) + static constexpr auto TensorStorageTransformedSmemBStorage = TransformationStageCount * + static_cast(sizeof(ElementBMma)) * size<0>(BlockTileB_N{}) * size<0>(BlockTileA_K{}); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // Transform2MmaPipeline = PipelineUmmaConsumerAsync (CollectiveMma) + static constexpr auto Transform2MmaPipelineStorage = sizeof(typename cutlass::PipelineUmmaConsumerAsync::SharedStorage); + // Mma2AccumPipeline = PipelineUmmaAsync (CollectiveMma) + static constexpr auto Mma2AccumPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? 
sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( CLCPipelineStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + Transform2MmaPipelineStorage + + Mma2AccumPipelineStorage + + TensorStorageTransformedSmemBStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage); + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages_ = detail::sm100_compute_stage_count_or_override_interleaved_complex_tf32< + ReducedSmemCapacityBytes, SmemTileShape>(StageCountType{}); + // Complex kernels allow TileShape_N = 64, while SmemLayoutAtomB contains Swizzle<3,4,3>. + static constexpr int PipelineStages = size<1>(TileShape_MNK{}) == 64 ? 
PipelineStages_ / 2 * 2 : PipelineStages_; + static_assert(PipelineStages >= 2, "Pipeline Stages has to be at least 2"); + + using AccumulatorCopyAtom = cute::SM100_TMEM_LOAD_16dp256b1x; + + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + TransformationStageCount, + ClusterShape_MNK, + AccumulatorCopyAtom + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomPairA, + CopyAtomPairA, + TransformA, + GmemTiledCopyB, + SmemLayoutAtomPairB, + CopyAtomPairB, + TransformB + >; +}; + +} // cutlass::gemm::collective diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl new file mode 100644 index 0000000..70726b6 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl @@ -0,0 +1,349 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
+template< + int CapacityBytes, + class ElementA, + class ElementAMma, + class ElementScale, + class ElementZero, + class ElementB, + class CtaTileShape_MNK, + class TiledMma, + class KernelScheduleType, + UMMA::Major UmmaMajorA, + int ScaleGranularityK, + int stages +> +constexpr cute::tuple +sm100_compute_stage_count_or_override_mixed_input(StageCount stage_count) { + constexpr int Load2TransformStageCount = stages; + constexpr int Transform2MmaStageCount = stages; + constexpr int AccumulatorStageCount = stages; + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +template< + int CapacityBytes, + class ElementA, + class ElementAMma, + class ElementScale, + class ElementZero, + class ElementB, + class CtaTileShape_MNK, + class TiledMma, + class KernelScheduleType, + UMMA::Major UmmaMajorA, + int ScaleGranularityK, + int carveout_bytes +> +constexpr cute::tuple +sm100_compute_stage_count_or_override_mixed_input(StageCountAutoCarveout stage_count) { + + constexpr int CtaM = get<0>(CtaTileShape_MNK{}); + constexpr int CtaN = get<1>(CtaTileShape_MNK{}); + static_assert(CtaN <= 128, "Can't support CtaN>128 tiles"); + constexpr int CtaK = get<2>(CtaTileShape_MNK{}); + using AtomThrID = typename TiledMma::AtomThrID; + + constexpr int TmemColumns = 512; + + constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v; + constexpr bool IsAComputeinSmem = !IsAComputeinTmem; + + // Detect 2x2 TMEM layout + constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN/2 : CtaN; + constexpr int TmemAWordsPerDP = CtaK / 2; + + constexpr int AccumulatorStageCount = (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP); + + constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32); + + constexpr int TmemInAStageCount_Potential = (IsAComputeinTmem) ? 
(TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000; + + // Mainload2Transform Pipeline + constexpr auto load2transform_pipeline_bytes = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; // ElementA introduce here + constexpr auto s_bits = cute::is_void_v ? 0 : cute::sizeof_bits_v; + constexpr auto z_bits = cute::is_void_v ? 0 : cute::sizeof_bits_v; + + constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto b_bits = cute::sizeof_bits_v; // ElementB introduce here + + constexpr int ab_stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + + cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + + static_cast(load2transform_pipeline_bytes) + static_cast(load2mma_pipeline_bytes); + + // Transform2Mma Pipeline + constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + constexpr auto a_compute_bits = cute::sizeof_bits_v; + constexpr int ab_compute_stage_bytes = + cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem) * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + // If ACompute is in TMEM, Acompute buffer has 0 bytes. + static_cast(transform2mma_pipeline_bytes); + + constexpr int ABComputeStageCount_Potential = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes); + + // The number of SMEM buffers for A, B. 
ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount + constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential); + + constexpr int SmemCapacityAfterABComputeCarveout = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes); + + // Can we boost the number of buffers for A and B? + constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes; + + static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2, "Not enough SMEM or TMEM capacity for selected tile size"); + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +} // namespace detail + +template +constexpr int get_ScaleGranularityK() { + if constexpr (cute::is_void_v) { + return 1; + } else { + return size<1,0>(LayoutScale{}); + } +} + + +// Mixed Input MMA kernels builder +template < + class ArchTag, + class ElementAOptionalTuple, + class GmemLayoutATagTuple, + int AlignmentA, + class ElementBOptionalTuple, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // The Cluster-level TileShape + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + ElementAOptionalTuple, // ElementA + GmemLayoutATagTuple, // LayoutA + AlignmentA, + ElementBOptionalTuple, // ElementB + GmemLayoutBTag, // LayoutB + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int) + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v + ) && + (cute::is_base_of_v) && + ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + using GmemLayoutATag = 
detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>; + using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + + static constexpr bool NeitherIsTuple = !cute::is_tuple::value && !cute::is_tuple::value; + static constexpr bool IsANarrow = cute::sizeof_bits_v < cute::sizeof_bits_v; + static constexpr bool IsMixedInput = cute::sizeof_bits_v != cute::sizeof_bits_v; + static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm."); + + static_assert((cute::is_tuple::value ^ cute::is_tuple::value || + (NeitherIsTuple && (cute::sizeof_bits::value != cute::sizeof_bits::value))), + "Either A OR B must be a tuple or the widths of A and B must be different."); + using ElementPairA = cute::conditional_t, ElementAOptionalTuple>; + using ElementPairB = cute::conditional_t, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + static_assert(IsATransformed, "A matrix should be transformed."); + + // For fp32 types, map to tf32 MMA value type. 
+ using ElementMma = cute::conditional_t, tfloat32_t, ElementB>; + + + using ElementAMma = ElementMma; + using ElementBMma = ElementMma; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + static constexpr int ScalingFactor = 1; + + using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma()); + using AtomThrID = typename TiledMma::AtomThrID; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{})); + + // Input transform kernel can not use TMA 2SM instructions. 
+ using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomA, SmemLayoutAtomACompute>; + static constexpr int MMA_M = cute::size<0,0>(MmaShapeA_MK{}); + using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementA>, + cute::conditional_t<(UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v), + cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, SM100_TMEM_STORE_32dp32b8x>, // TS Implementation + Copy_Atom, ElementA>> // SS Implementation + >; + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Input transform kernel can not use TMA 2SM instructions. 
+ using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomB, SmemLayoutAtomBCompute>; + using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementB>, + Copy_Atom, ElementMma> + >; + + //Creating the stride of Transformed Input + using StrideA = cutlass::gemm::TagToStrideA_t; + using LayoutScale = cutlass::gemm::TagToStrideA_t; + + using VoidShapeScale = Shape, _1>, Shape, _1>, _1>; //Dummy Value to create a dummy ScaleConfig + using VoidStrideScale = Stride,Stride<_0, _1>, _1>; + using VoidLayoutScale = Layout; + + using NonVoidLayoutScale = cute::conditional_t< + cute::is_void_v, VoidLayoutScale, LayoutScale>; + + using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{})); + + // SmemCarveout + static constexpr int SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? 
sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( CLCPipelineStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage); + + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ScaleGranularityK = get_ScaleGranularityK(); + static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_mixed_input< + ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB, CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{}); + + static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); + static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); + static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info); + + static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell"); + + using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount, + Transform2MmaPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + >; + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StridePairA, + ElementPairB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomPairA, + CopyAtomPairA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomPairB, + CopyAtomPairB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective diff --git 
a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 0000000..d9cb128 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,171 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using 
ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + 
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. 
+ static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl index fc4aa4a..08f90a7 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -46,11 +46,13 @@ struct Sm100DenseGemmTmaUmmaCarveout { // AccumulatorPipeline = PipelineUmmaAsync static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); // CLCPipeline = PipelineCLCFetchAsync - static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // For pointer-array and grouped GEMM, we have two CLC responses, one for TMA updater, one for the TMA/MMA/Epilogue warps. + static constexpr int NumCLCResponses = (IsArrayOfPointersGemm ? 2 : 1); + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage) * NumCLCResponses; // LoadOrderBarrier = OrderedSequenceBarrier<1,2> static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); // CLC (scheduler) response - static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize * NumCLCResponses; // CLC Throttle pipeline storage static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc @@ -59,8 +61,14 @@ struct Sm100DenseGemmTmaUmmaCarveout { static constexpr auto TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t); // Tensormap Storage static constexpr auto TensorMapStorage = - IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * NumTensorMaps /* for A and B */ : + IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * NumTensorMaps * 5 /* We have five tensormaps smem */ : + 0; + + // TensorMapReady pipeline storage (specific to grouped/array kernels) + static constexpr auto TensorMapReadyPipelineStorage = + IsArrayOfPointersGemm ? 
sizeof(typename cutlass::PipelineAsync::SharedStorage) : 0; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + CLCPipelineStorage + @@ -69,7 +77,8 @@ struct Sm100DenseGemmTmaUmmaCarveout { CLCThrottlePipelineStorage + CLCResponseStorage + TmemBasePtrsStorage + - TensorMapStorage + TensorMapStorage + + TensorMapReadyPipelineStorage ); }; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl new file mode 100644 index 0000000..dee762f --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +///////////////////////////////////////////////////////////////////////////////////////////////// +// Planar Complex f16/bf16 TCGEN05 kernels builder +template < + class ArchTag, + class ElementA, + class GmemLayoutATag, + class TransformA, + class ElementB, + class GmemLayoutBTag, + class TransformB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + arch::OpClassTensorOp, + cute::tuple, + GmemLayoutATag, + 8, + cute::tuple, + GmemLayoutBTag, + 8, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + // Element Types AB should be set as real type in Planar 
Complex f16/bf16 TCGEN05 kernels builder. + (cute::is_same_v || cute::is_same_v) && + (cute::is_same_v || cute::is_same_v) && + // Planar Complex f16/bf16 kernels don't support auto-scheduling for mainloop builder. + cute::is_base_of_v>> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag, UMMA::ScaleIn::One>()); + using TiledMmaANeg = decltype(detail::sm100_make_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag, UMMA::ScaleIn::Neg>()); + using TiledMmaPair = cutlass::gemm::collective::detail::Sm100CollectiveMmaPlanarComplexTiledMmaType; + + using AtomThrID = typename TiledMma::AtomThrID; + + // Define A and B block shapes for reduced size TMA_LOADs + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using GmemTiledCopyA = decltype(detail::sm100_cluster_shape_to_tma_atom_A(ClusterShape_MNK{}, AtomThrID{})); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementA, BlockTileA_M, 
BlockTileA_K>()); + + using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{})); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementB, BlockTileB_N, BlockTileB_K>()); + + // Calculate SMEM matrix A and B buffers' pipeline stages + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Ptr-arry gemm requires extra TensorMap storage + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = IsArrayOfPointersGemm ? AccumulatorPipelineStageCount + 1: 1; + + static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< + ClusterShape_MNK, + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + detail::CLCResponseSize, + IsArrayOfPointersGemm, + 4 // 4 Tensor maps for A_{imag|real} and B_{imag|real} + >::KernelSmemCarveout; + + // Reduce SMEM capacity available for buffers considering barrier allocations. 
+ + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + // Use complex type to calculate SMEM stage count + using ComplexElementA = cutlass::complex; + using ComplexElementB = cutlass::complex; + + using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; + static constexpr int PipelineStages = detail::sm100_compute_stage_count_or_override< + ReducedSmemCapacityBytes, ComplexElementA, ComplexElementB, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedPlanarComplex< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMmaPair, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + TransformA, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + TransformB + >; +}; + +} // cutlass::gemm::collective diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl new file mode 100644 index 0000000..40dcfae --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl @@ -0,0 +1,219 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template< + class LayoutA, + int AlignmentA, + class LayoutB, + int AlignmentB, + class CtaShape_MNK, + class WarpShape_MNK +> +constexpr auto +sm100_make_simt_f32_tiled_mma() { + using namespace cute; + + constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{}); + + constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{}); + constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{}); + constexpr int WarpShape_K = cute::size<2>(WarpShape_MNK{}); + + // Use Permutation to achieve a [4 x 4] value layout for each thread. + // Ideally, we want the tiled mma to be such that loads from shared memory are 128 bit wide. + // While as we are using CtaShape_K = 16, when A and B are K-major, we use tranpose + 8 byte padding to avoid smem bank conflict, + // so we could only use 64 bit smem load. + // When A and B are MN-major, we use 128 bit smem load. + using PermutationA = Layout, _2>, Stride< _1, _4, _2>>; + using PermutationB = Layout, _4>, Stride< _4, _1>>; + + // For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] value layouts. + // Then totally each warp will hold [32 x 16] value layouts. + // So WarpShape_M needs to be equal or smaller than CtaShape_M / 32 and WarpShape_N needs to be equal or smaller than CtaShape_N / 16. 
+ static_assert(WarpShape_M <= CtaShape_M / 32, "WarpShape_M is too large, it needs to be equal or smaller than CtaShape_M / 32."); + static_assert(WarpShape_N <= CtaShape_N / 16, "WarpShape_N is too large, it needs to be equal or smaller than CtaShape_N / 16."); + + constexpr int WarpStride_M = (WarpShape_M != 1) * NumThreadsPerWarp; + constexpr int WarpStride_N = WarpShape_M * NumThreadsPerWarp; + + // We first introduce a [8 x 4] thread layouts in 1 warp. + // And inside this [8 x 4] thread layouts, each 4 threads will be arranged as [2 x 2]. + // Then we could set different WarpShape to finalize how many warps we use in our tiled mma. + // For example : + // With 128 threads in the tiled mma, we could set the WarpShapeMNK as [2 x 2 x 1], [1 x 4 x 1] and [4 x 1 x 1]. + // With 64 threads in the tiled mma, we could set the WarpShapeMNK as [1 x 2 x 1] and [2 x 1 x 1]. + return make_tiled_mma( + MMA_Atom{}, + Layout>, Shape <_2, _2, Int>, _1>, + Stride< Stride<_1, _8, Int>, Stride<_2, _4, Int>, _1>>{}, + Tile< + PermutationA, + PermutationB, + Underscore>{}); +} + +} // namespace detail + +template < + class ArchTag, + class GmemLayoutATag, + int AlignmentA, + class GmemLayoutBTag, + int AlignmentB, + class CtaShape_MNK, + class ClusterShape_MNK, + int stages, + class BuilderScheduleTag> +struct CollectiveBuilder< + ArchTag, + arch::OpClassSimt, + float, + GmemLayoutATag, + AlignmentA, + float, + GmemLayoutBTag, + AlignmentB, + float, + CtaShape_MNK, + ClusterShape_MNK, + StageCount, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + ((sizeof(float) * AlignmentA) % detail::cp_async_min_alignment_bytes == 0) && + ((sizeof(float) * AlignmentB) % detail::cp_async_min_alignment_bytes == 0) >> { + static_assert(cute::size<2>(CtaShape_MNK{}) == 16, "SM100 SIMT SGEMM Kernels only support TileShape_K = 16."); + + // This kernel is specialized for F32 data type. 
+ using ElementA = float; + using ElementB = float; + + using M = decltype(cute::size<0>(CtaShape_MNK{})); + using N = decltype(cute::size<1>(CtaShape_MNK{})); + using K = decltype(cute::size<2>(CtaShape_MNK{})); + + using WarpShape_MNK = decltype(detail::sm100_simt_f32_warp_shape_mnk_selector()); + + static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp; + + using TiledMma = decltype( + detail::sm100_make_simt_f32_tiled_mma< + GmemLayoutATag, + AlignmentA, + GmemLayoutBTag, + AlignmentB, + CtaShape_MNK, + WarpShape_MNK>()); + + // for K major layouts, add a smem alignment offset to avoid bank conflicts + static constexpr int SmemAlignmentOffsetA = cutlass::gemm::detail::is_mn_major_A() ? 0 : 2; + static constexpr int SmemAlignmentOffsetB = cutlass::gemm::detail::is_mn_major_B() ? 0 : 2; + static constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + static constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + + // Shared memory layout is [M x K] in M-major + using SmemLayoutAtomA = cute::Layout, + cute::Stride<_1, Int>>; + // A M-major use 128bit smem load. + // A K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load. + using SmemCopyAtomA = std::conditional_t(), + cute::Copy_Atom, ElementA>, + cute::Copy_Atom, ElementA>>; + + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype( + detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, ThreadCount, AlignmentA, TagToStrideA_t, M, K>()); + + // Shared memory layout is [N x K] in N-major + using SmemLayoutAtomB = cute::Layout, + cute::Stride<_1, Int>>; + // B N-major use 128bit smem load. + // B K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load. 
+ using SmemCopyAtomB = std::conditional_t(), + cute::Copy_Atom, ElementB>, + cute::Copy_Atom, ElementB>>; + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype( + detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, ThreadCount, AlignmentB, TagToStrideB_t, N, K>()); + + static constexpr bool IsArrayOfPointersGemm = cute::is_same_v; + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm80CpAsync + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + CtaShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl index c7d380a..40dfda2 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -44,6 +44,7 @@ namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -60,6 +61,7 @@ sm100_compute_stage_count_or_override_sparse(StageCount stage_count) { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -104,7 +106,7 @@ sm100_compute_stage_count_or_override_sparse(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassSparseTensorOp, ElementA, GmemLayoutATag, @@ -296,6 +299,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && (not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && not cute::is_sparse_v) && @@ -375,7 +380,12 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; + // Calculate SMEM capacity based on ArchTag + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_sparse< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementEMma, diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index e7f5235..3e4d830 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -153,6 +153,7 @@ check_input_datatypes() { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATag, int AlignmentA, @@ -166,7 +167,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATag, @@ -180,10 +181,14 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && // Dense Gemm / PtrArrayDenseGemm ( + (not cute::is_same_v) && + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v)) && // Alignment check @@ -263,11 +268,17 @@ struct CollectiveBuilder< // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. 
- static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); + // Perform checks for both StrideA and StrideB to filter out Ragged Continguous Group Gemm + static constexpr bool IsGroupGemm = !(cute::is_same_v) && !(cute::is_same_v); + static constexpr bool IsRCGroupGemm = (cute::is_same_v) && !(cute::is_same_v); + + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, @@ -277,23 +288,34 @@ struct CollectiveBuilder< IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + using SmemTileShape = cute::Shape; using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. 
Can't create any SMEM buffers for A, and B."); using DispatchPolicy = cute::conditional_t, + cute::conditional_t, + cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecialized< PipelineStages, SchedulerPipelineStageCount, diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl new file mode 100644 index 0000000..5986439 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl @@ -0,0 +1,550 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm103_compute_stage_count_or_override_blockscaled(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr auto +sm103_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout stage_count) { + // For F8F6F4 MMA sub-bytes, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. 
smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes * 2 + stage_sfa_bytes + stage_sfb_bytes); + + constexpr int ab_buffer = (CapacityBytes - carveout_bytes) / stage_bytes; + constexpr int sb_buffer = ab_buffer + (CapacityBytes - carveout_bytes - ab_buffer * stage_bytes) / (mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + return make_tuple(ab_buffer, sb_buffer); +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + int SFVectorSize +> +constexpr auto +sm103_make_blockscaled_1sm_tiled_mma() { + using AtomLayout_MNK = Layout; + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 64 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM103::SM103_MXF4_ULTRA_SS_VS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM103 
collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + int SFVectorSize +> +constexpr auto +sm103_make_blockscaled_2sm_tiled_mma() { + using AtomLayout_MNK = Layout{}))>; + + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 64 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM103::SM103_MXF4_ULTRA_2x1SM_SS_VS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM103 collective builder."); + } +} + + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class ClusterTileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class BuilderScheduleTag +> +constexpr auto +sm103_make_blockscaled_tiled_mma() { + constexpr uint32_t SFVectorSize = find_vector_size(); + + // MMA_2SM requested + if constexpr (cute::is_base_of_v) { + return sm103_make_blockscaled_2sm_tiled_mma(); + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v) { + return sm103_make_blockscaled_1sm_tiled_mma(); + } + // Auto scheduling requested + else if constexpr (cute::is_same_v) { + if constexpr (cute::get<0>(ClusterShape_MNK{}) % 2 == 0) { + return sm103_make_blockscaled_2sm_tiled_mma(); + } + else { + return sm103_make_blockscaled_1sm_tiled_mma(); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported policy for SM103 collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major 
UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag, + bool Is2SM +> +struct Sm103TrivialBlockscaledMma {}; + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag +> +struct Sm103TrivialBlockscaledMma< ElementAMma, + ElementBMma, + ElementAccumulator, + ElementSF, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + SFVectorSize, + BuilderScheduleTag, + true /*Is2SM*/> { + using type = decltype(sm103_make_blockscaled_2sm_tiled_mma()); + }; + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag +> +struct Sm103TrivialBlockscaledMma< ElementAMma, + ElementBMma, + ElementAccumulator, + ElementSF, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + SFVectorSize, + BuilderScheduleTag, + false /*Is2SM*/> { + using type = decltype(sm103_make_blockscaled_1sm_tiled_mma()); +}; + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm103_block_scale_input() { + // Allowed input element datatype for block-scaling GEMM + return ( cute::is_same_v || + cute::is_same_v); +} + +template +constexpr +auto sm103_sfa_smem_atom_layout() { + constexpr int SF_BUFFERS_PER_TILE_K = BlockScaleConfig::SFVecSize == 16 ? 
4 : 2; + auto mma_sfa_tiler = make_shape(get<0,0>(MmaShapeA_MK{})*get<1>(MmaShapeA_MK{}), get<0,1>(MmaShapeA_MK{}) * get<2>(MmaShapeA_MK{}) / Int{}); + return tiled_product(typename BlockScaleConfig::SfAtom{}, + make_layout(shape_div(mma_sfa_tiler, product_each(shape(typename BlockScaleConfig::SfAtom{}))))); +} + +template +constexpr +auto sm103_sfb_smem_atom_layout() { +auto sSFB = [&]() { + constexpr int MMA_N = get<0>(MmaShapeB_NK{}); + constexpr int NonPow2N = 192; + constexpr int NonPow2N_RoundUp = 256; + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + constexpr int MMA_N_SFB = (MMA_N == NonPow2N) ? NonPow2N_RoundUp : MMA_N; + constexpr int SF_BUFFERS_PER_TILE_K = BlockScaleConfig::SFVecSize == 16 ? 4 : 2; + auto mma_sfb_tiler = make_shape(Int{}, get<1>(MmaShapeB_NK{}) / Int{}); + if constexpr(Int{} == Int<128>{}) { + return tiled_product(typename BlockScaleConfig::SfAtom{}, + make_layout(shape_div(mma_sfb_tiler,product_each(shape(typename BlockScaleConfig::SfAtom{}))))); + + } + else { + using SfKMajorAtom256 = Layout< Shape< Shape<_32,_4, _2>, Shape, _4>>, + Stride(mma_sfb_tiler)/SFVecSize/4*512>>, Stride< _0, _1>>>; + return tiled_product(SfKMajorAtom256{}, + make_layout(shape_div(mma_sfb_tiler,product_each(shape(SfKMajorAtom256{}))))); + } + }(); + return sSFB; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + ArchTag, + 
arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v + ) && + // Not paired input, Not Complex input + (cute::is_tuple_v && cute::is_tuple_v && + not cute::is_complex_v && not cute::is_complex_v) && + // Blockscaled Gemm + (cute::is_base_of_v || + cute::is_base_of_v || + cute::is_same_v) && + // Alignment check + detail::sm1xx_blockscaled_gemm_is_aligned(ElementPairA{}))>, + AlignmentA, + remove_cvref_t(ElementPairB{}))>, + AlignmentB, + BuilderScheduleTag>()>> +{ + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementSF = remove_cvref_t(ElementPairA{}))>; + + static_assert(cute::is_tuple::value, "Expecting ElementPairA to be a tuple."); + static_assert(cute::is_tuple::value, "Expecting ElementPairB to be a tuple."); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(cute::size<2>(TileShape_MNK{}) == _768{}, "TileShape_K should 768 for MMA kernels"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static_assert(cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B(), "Only K major inputs are supported"); + + static_assert(cutlass::gemm::collective::detail::is_sm103_block_scale_input(), "Incorrect type for A matrix"); + static_assert(cutlass::gemm::collective::detail::is_sm103_block_scale_input(), "Incorrect type for B matrix"); + + static_assert(cute::is_same_v || + cute::is_same_v, "Incorrect scale factor type"); + + // Data type used by MMA instruction + using ElementAMma = 
decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static constexpr uint32_t SFVectorSize = detail::find_vector_size(); + + static constexpr bool is_2sm = cute::is_base_of_v || + (cute::is_same_v && + (cute::is_static_v && cute::get<0>(ClusterShape_MNK{}) % 2 == 0)); + + using TiledMma = typename cutlass::gemm::collective::detail::Sm103TrivialBlockscaledMma::type; + + using AtomThrID = typename TiledMma::AtomThrID; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + + using ElementAMma_SmemAllocType = uint8_t; + // ElementAMma; + using ElementBMma_SmemAllocType = uint8_t; + // ElementBMma; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + // + // Construct SMEM layout (SmemLayoutAtom) for A and SFA + 
// + using SmemLayoutAtomA = UMMA::Layout_K_SW128_Atom; + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size + static constexpr int MMA_M = cute::size<0>(TileShape_MNK{}) / cute::size(AtomThrID{}); + using SmemLayoutAtomSFA = decltype(detail::sm103_sfa_smem_atom_layout()); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + + // + // Construct SMEM layout(SmemLayoutAtom)for B and SFB + // + + using SmemLayoutAtomB = UMMA::Layout_K_SW128_Atom; + static constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + using SmemLayoutAtomSFB = decltype(detail::sm103_sfb_smem_atom_layout(TileShape_MNK{})),SFVectorSize>()); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + // + // Construct Strides for A, SFA, B, and SFB + // + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); + + // + // Others + // + + static constexpr cutlass::sm103::detail::KernelPrefetchType PrefetchType = cute::is_base_of_v + || 
cute::is_base_of_v + ? cutlass::sm103::detail::KernelPrefetchType::Disable : + cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch; + + static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N == 256) ? 1 : 2; + static constexpr uint32_t SchedulerPipelineStageCount = 3; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // LoadOrderBarrier = OrderedSequenceBarrier<1,2> + static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = AccumulatorPipelineStageCount * sizeof(uint32_t); + // Tensormap Storage + static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; + static constexpr auto TensorMapStorage = IsArrayOfPointersGemm ? 
sizeof(cute::TmaDescriptor) * 4 /* for A, B, SFA and SFB */ : 0; + // TMA Load Prefetch Storage + static constexpr auto TmaPrefetchStorage = 0; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + LoadOrderBarrierStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage + + TmaPrefetchStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape, Int, _128>; // SmemAllocTypes are uint8_t. We always allocate 128bytes + static constexpr auto PipelineStages = cutlass::gemm::collective::detail::sm103_compute_stage_count_or_override_blockscaled< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + + using DispatchPolicy = typename cute::conditional_t(PipelineStages), + get<1>(PipelineStages), + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK, + PrefetchType + >, + cutlass::gemm::MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + get<0>(PipelineStages), + get<1>(PipelineStages), + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK, + PrefetchType + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StridePairA, + ElementPairB, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // 
namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl index 862d430..99b1323 100755 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -186,13 +186,13 @@ struct CollectiveBuilder< // Basic storage block for new Scaling Factor Layouts using mnBasicBlockShape = Shape<_32,_4>; using mnBasicBlockStride = Stride<_16,_4>; - using kBasicBlockShape = Shape, Int>; + using kBasicBlockShape = Shape, Int>; using kBasicBlockStride = Stride<_0, _1>; using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{})); using sSFA_strideM = sSF_strideMN; - using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{})); using sSFA_strideK = decltype(prepend(make_stride( Int{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{})); @@ -209,11 +209,6 @@ struct CollectiveBuilder< using 
SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); - static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< - detail::sm120_smem_capacity_bytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); - - static constexpr uint32_t SchedulerPipelineStageCount = 3; - using StrideA = cutlass::gemm::TagToStrideA_t; using StrideB = cutlass::gemm::TagToStrideB_t; using InternalStrideA = cute::remove_pointer_t; @@ -232,6 +227,34 @@ struct CollectiveBuilder< cute::is_base_of_v, "Invalid builder schedule tag for grouped GEMM"); + + static constexpr uint32_t SchedulerPipelineStageCount = 3; + + static constexpr int CLCResponseSize = sizeof(typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100,1>::CLCResponse{}); + + static constexpr auto SchedulerPipelineStorage = IsGroupedGemmKernel ? sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>) + : sizeof(typename cutlass::PipelineCLCFetchAsync>::SharedStorage); + static constexpr auto CLCResponseStorage = IsGroupedGemmKernel ? 0 : (SchedulerPipelineStageCount * + CLCResponseSize); + static constexpr auto TensorMapStorage = + IsGroupedGemmKernel ? sizeof(cute::TmaDescriptor) * 2 /* We have two tensormaps smem */ : + 0; + + // TensorMapReady pipeline storage (specific to grouped/array kernels) + static constexpr auto TensorMapReadyPipelineStorage = + IsGroupedGemmKernel ? 
sizeof(typename cutlass::PipelineAsync::SharedStorage) : + 0; + + static constexpr int ReducedSmemCapacityBytes = detail::sm120_smem_capacity_bytes - + SchedulerPipelineStorage - + TensorMapStorage - + TensorMapReadyPipelineStorage - + CLCResponseStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< + ReducedSmemCapacityBytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + + using KernelSchedule = cute::conditional_t(TileShape_MNK{}), Blk_MN{}) * Blk_MN{}, ceil_div(size<1>(TileShape_MNK{}), Blk_MN{}) * Blk_MN{}, shape<2>(TileShape_MNK{}))); @@ -279,13 +279,13 @@ struct CollectiveBuilder< // Basic storage block for new Scaling Factor Layouts using mnBasicBlockShape = Shape<_32,_4>; using mnBasicBlockStride = Stride<_16,_4>; - using kBasicBlockShape = Shape, Int>; + using kBasicBlockShape = Shape, Int>; using kBasicBlockStride = Stride<_0, _1>; using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{})); using sSFA_strideM = sSF_strideMN; - using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{})); using sSFA_strideK = decltype(prepend(make_stride( Int{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{})); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl index 4b5858c..6dd884c 100644 --- 
a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_common.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_common.inl index 45e201b..a1ccdd1 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_common.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl index b75573a..490a8ad 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl index 36ed318..a7fe826 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_common.inl index a6444e0..2dd6fda 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -141,6 +141,18 @@ constexpr uint32_t find_vector_size() { cute::is_same_v || cute::is_same_v || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v ) { return 16; } @@ -490,6 +502,7 @@ check_input_datatypes() { || (cute::is_same_v) || (cute::is_same_v) || (cute::is_same_v) + || (cute::is_same_v) // SM100 BS ptr_array || (cute::is_same_v) || (cute::is_same_v) @@ -566,6 +579,8 @@ check_input_datatypes() { ((SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 64 && cute::is_base_of_v) @@ -745,6 +760,8 @@ select_instr() { (SfVectorSize == 32 && cute::is_same_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) + || (SfVectorSize == 32 && cute::is_base_of_v) + || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 64 && cute::is_base_of_v diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_sparse_config.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_sparse_config.inl index 2afe099..3833c6c 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_sparse_config.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm1xx_sparse_config.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright 
(c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl index b1f4f1f..ae08658 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index c75af3a..01a737f 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -52,21 +52,21 @@ namespace cutlass::gemm::collective { namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
-template +template constexpr int compute_stage_count_or_override(StageCount stage_count) { return stages; } // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int compute_stage_count_or_override(cute::Int stage_count) { return stages; } // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int compute_stage_count_or_override(StageCountAutoCarveout stage_count) { constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); @@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_co } // Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale. -template +template constexpr int compute_stage_count_with_blockwise_scale(StageCountAutoCarveout stage_count) { constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); @@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout } // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. -template +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { return stages; @@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() { } // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. 
-template +template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { @@ -456,12 +463,12 @@ public: static constexpr int PipelineStages = IsMixedInput ? ( IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) : + RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) : detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) + RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) ) : detail::compute_stage_count_or_override(StageCountType{}); + ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{}); using DispatchPolicy = cute::conditional_t or - cute::is_same_v or - cute::is_same_v or - cute::is_same_v) and + (cute::is_same_v or + cute::is_same_v or + cute::is_same_v or + cute::is_same_v) and not detail::is_use_rmem_A() > > { @@ -1105,7 +1112,7 @@ struct CollectiveBuilder< cute::is_base_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now."); + static_assert(IsFP8Input, "Warp Specialized gemm with FP8 Blockwise (Software) Scaling is only compatible with FP8 inputs version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -1133,6 +1140,7 @@ struct CollectiveBuilder< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + // Reserve 128B for 8 stages of tile scheduling static constexpr size_t SchedulerPipelineStorage = cute::is_pointer_v> ? 
sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>) : 0; @@ -1146,8 +1154,8 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8>; + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl index 09da42c..82a1499 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl index 541b45e..fdfdc1d 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl +++ b/3rd/cutlass/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/collective_builder.hpp b/3rd/cutlass/include/cutlass/gemm/collective/collective_builder.hpp index b03c79c..e5f51df 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/collective_builder.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/collective_builder.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,18 +39,27 @@ #include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" #include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" -#if !defined(__CUDACC_RTC__) -#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" -#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl" +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl" -#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl" +#include 
"cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl" #endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp b/3rd/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp index aae7334..bf2a61f 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/collective_mma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/collective_mma.hpp index f65dd70..2cc5697 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/collective_mma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/collective_mma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,11 +37,12 @@ #include "cutlass/gemm/collective/sm70_mma_twostage.hpp" #include "cutlass/gemm/collective/sm80_mma_multistage.hpp" +#include "cutlass/gemm/collective/sm80_mma_array_multistage.hpp" #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp" @@ -54,22 +55,36 @@ #if !defined(__CUDACC_RTC__) #include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" +#include 
"cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm120_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp" -#include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp" -#include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp" -#endif // !defined(__CUDACC_RTC__) +#include 
"cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp" +#include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp" +#endif // !defined(__CUDACC_RTC__) diff --git a/3rd/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp b/3rd/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp index a2faa1f..1a7fd0b 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp b/3rd/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp index 6ff3a94..f9fd989 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -56,7 +56,6 @@ struct GmmaFP8Accumulation { static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); private: - TensorAccum& accum_; TensorAccum accum_temp_; uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. @@ -65,8 +64,10 @@ struct GmmaFP8Accumulation { uint32_t reset_accum_flag_; // accum needs to be zeroed or not. // promote or `add` the partial accumulators to main accumulator (FADD). + template CUTLASS_DEVICE - void promote_core() { + void promote_core(TensorAccumOrig &accum_) { + CUTE_STATIC_ASSERT_V(size(accum_) == size(accum_temp_)); warpgroup_wait<0>(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accum_); ++i) { @@ -75,8 +76,10 @@ struct GmmaFP8Accumulation { } // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
+ template CUTLASS_DEVICE - void scale_core(ElementAccumulator const &scale) { + void scale_core(TensorAccumOrig &accum_, ElementAccumulator const &scale) { + CUTE_STATIC_ASSERT_V(size(accum_) == size(accum_temp_)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accum_); ++i) { accum_(i) += accum_temp_(i) * scale; @@ -84,16 +87,17 @@ struct GmmaFP8Accumulation { } template < + class TensorAccumOrig, class EngineScale, class LayoutScale> CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { + void scale_core(TensorAccumOrig &accum_, const cute::Tensor &scale) { using TensorScale = cute::Tensor; static_assert(is_static::value, "Scale Layout should be static"); static_assert(is_rmem::value , "Scale tensor must be rmem resident."); - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + CUTE_STATIC_ASSERT_V(size(accum_) == size(accum_temp_)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accum_); ++i) { @@ -102,12 +106,13 @@ struct GmmaFP8Accumulation { } template < + class TensorAccumOrig, class EngineScaleA, class LayoutScaleA, class EngineScaleB, class LayoutScaleB> CUTLASS_DEVICE - void scale_core(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + void scale_core(TensorAccumOrig &accum_, const cute::Tensor &scaleA, const cute::Tensor &scaleB) { using TensorScaleA = cute::Tensor; using TensorScaleB = cute::Tensor; @@ -116,8 +121,10 @@ struct GmmaFP8Accumulation { static_assert(is_rmem::value, "ScaleA tensor must be rmem resident."); static_assert(is_rmem::value, "ScaleB tensor must be rmem resident."); - static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape."); - static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape."); + + CUTE_STATIC_ASSERT_V(size(accum_) == size(accum_temp_)); + CUTE_STATIC_ASSERT_V(size(accum_) == size(scaleA)); + CUTE_STATIC_ASSERT_V(size(accum_) == size(scaleB)); 
CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accum_); ++i) { @@ -128,16 +135,15 @@ struct GmmaFP8Accumulation { public: CUTLASS_DEVICE GmmaFP8Accumulation( - TensorAccum &accum, + TensorAccum &accum_temp, uint32_t accum_promotion_interval, uint32_t mma_count_per_mainloop_iteration) - : accum_(accum), + : accum_temp_(accum_temp), accum_promotion_interval_(accum_promotion_interval), mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), mma_count_(0), reset_accum_flag_(0) { - accum_temp_ = cute::make_fragment_like(accum); } // @@ -160,21 +166,23 @@ struct GmmaFP8Accumulation { // /// promote (add) the results from the MMA accumulators to main accumulator if needed. + template CUTLASS_DEVICE - void promote_if_needed() { + void promote_if_needed(TensorAccumOrig &accum_) { mma_count_ += mma_count_per_mainloop_iteration_; reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { - promote_core(); + promote_core(accum_); mma_count_ = 0; } } /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + template CUTLASS_DEVICE - void promote_residue_if_needed() { + void promote_residue_if_needed(TensorAccumOrig &accum_) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - promote_core(); + promote_core(accum_); } } @@ -183,95 +191,104 @@ struct GmmaFP8Accumulation { // /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. 
+ template CUTLASS_DEVICE - void scale_if_needed(ElementAccumulator const &scale) { + void scale_if_needed(TensorAccumOrig &accum_, ElementAccumulator const &scale) { mma_count_ += mma_count_per_mainloop_iteration_; reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { - scale_core(scale); + scale_core(accum_, scale); mma_count_ = 0; } } template < + class TensorAccumOrig, class EngineScale, class LayoutScale> CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { + void scale_if_needed(TensorAccumOrig &accum_, const cute::Tensor &scale) { mma_count_ += mma_count_per_mainloop_iteration_; reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { - scale_core(scale); + scale_core(accum_, scale); mma_count_ = 0; } } template < + class TensorAccumOrig, class EngineScaleA, class LayoutScaleA, class EngineScaleB, class LayoutScaleB> CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + void scale_if_needed(TensorAccumOrig &accum_, const cute::Tensor &scaleA, const cute::Tensor &scaleB) { mma_count_ += mma_count_per_mainloop_iteration_; reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { - scale_core(scaleA, scaleB); + scale_core(accum_, scaleA, scaleB); mma_count_ = 0; } } /// scale (multiply_add) the results from the MMA accumulators to main accumulator without checking the counter. 
+ template CUTLASS_DEVICE - void scale(ElementAccumulator const &scale) { - scale_core(scale); + void scale(TensorAccumOrig &accum_, ElementAccumulator const &scale) { + scale_core(accum_, scale); } template < + class TensorAccumOrig, class EngineScale, class LayoutScale> CUTLASS_DEVICE - void scale(const cute::Tensor &scale) { - scale_core(scale); + void scale(TensorAccumOrig &accum_, const cute::Tensor &scale) { + scale_core(accum_, scale); } template < + class TensorAccumOrig, class EngineScaleA, class LayoutScaleA, class EngineScaleB, class LayoutScaleB> CUTLASS_DEVICE - void scale(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { - scale_core(scaleA, scaleB); + void scale(TensorAccumOrig &accum_, const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + scale_core(accum_, scaleA, scaleB); } /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + template CUTLASS_DEVICE - void scale_residue_if_needed(ElementAccumulator const &scale) { + void scale_residue_if_needed(TensorAccumOrig &accum_, ElementAccumulator const &scale) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scale); + scale_core(accum_, scale); } } template < + class TensorAccumOrig, class EngineScale, class LayoutScale> CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { + void scale_residue_if_needed(TensorAccumOrig &accum_, const cute::Tensor &scale) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scale); + scale_core(accum_, scale); } } template < + class TensorAccumOrig, class EngineScaleA, class LayoutScaleA, class EngineScaleB, class LayoutScaleB> CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + void scale_residue_if_needed(TensorAccumOrig &accum_, const cute::Tensor &scaleA, const cute::Tensor &scaleB) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scaleA, scaleB); + scale_core(accum_, scaleA, scaleB); } } 
}; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index 2665ef1..edacf6a 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -143,6 +143,11 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). 
+ constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + using ElementPairA = ElementPairA_; using ElementPairB = ElementPairB_; using ElementAMma = typename TiledMma::ValTypeA; @@ -571,13 +576,18 @@ struct CollectiveMma< }; } + struct TensorMaps : cute::aligned_struct<256, _0> { + cute::TmaDescriptor tma_desc_a; + cute::TmaDescriptor tma_desc_b; + cute::TmaDescriptor tma_desc_sfa; + cute::TmaDescriptor tma_desc_sfb; + }; + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = 4; - constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); - // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count); + // Allocate gmem space for input tensormaps per each SM. + return (sm_count * sizeof(TensorMaps) * NumTmaDescriptorsPerSm); } template @@ -674,7 +684,7 @@ struct CollectiveMma< /// mcast_mask_b - tma multicast mask for B /// mcast_mask_sfa - tma multicast mask for SFA /// mcast_mask_sfb - tma multicast mask for SFB - template + template CUTLASS_DEVICE auto load_init( ProblemShape_MNKL const& problem_shape_MNKL, @@ -682,6 +692,7 @@ struct CollectiveMma< TensorStorage& shared_tensors, TensorMapStorage& shared_tensormaps, int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t num_groups, int32_t init_group) const { using X = Underscore; @@ -788,15 +799,19 @@ struct CollectiveMma< uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); - // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); - - return cute::make_tuple( - gA_mkl, gB_nkl, // for scheduler - tAgA_mkl, tBgB_nkl, tAsA, tBsB, // 
for input tensor values - tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values - mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, // multicast masks - input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb); // multicast masks + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } } /// Set up the data needed by this collective for mma compute. @@ -895,7 +910,8 @@ struct CollectiveMma< cute::tuple> const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count, - bool did_batch_change) { + bool did_batch_change, + [[maybe_unused]] int curr_batch) { auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, @@ -1066,8 +1082,10 @@ struct CollectiveMma< } } else { - // Wait for tmem accumulator buffer to become empty with a flipped phase - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + if (k_tile_count > 0) { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } } CUTLASS_PRAGMA_NO_UNROLL @@ -1116,19 +1134,15 @@ struct CollectiveMma< // Methods to perform different parts of TMA/Tensormap modifications // + template CUTLASS_DEVICE auto tensormaps_init( Params const& mainloop_params, TensorMapStorage& shared_tensormaps, int32_t const sm_count, int32_t const sm_idx) const { - cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; - - 
cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; - cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; - cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; - cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + TensorMaps* gmem_tensormap = &(reinterpret_cast(mainloop_params.tensormaps)[sm_idx * NumTmaDescriptorsPerSm]); if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later @@ -1148,9 +1162,30 @@ struct CollectiveMma< copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); } + __syncwarp(); - return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb); + struct TensorMapArray { + + TensorMaps *tensor_maps; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(void* tensormaps) : tensor_maps(reinterpret_cast(tensormaps)) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(&tensor_maps[idx].tma_desc_a, &tensor_maps[idx].tma_desc_b, &tensor_maps[idx].tma_desc_sfa, &tensor_maps[idx].tma_desc_sfb); + } + }; + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(gmem_tensormap); + } else { + return cute::make_tuple(&gmem_tensormap->tma_desc_a, &gmem_tensormap->tma_desc_b, &gmem_tensormap->tma_desc_sfa, &gmem_tensormap->tma_desc_sfb); + } } // Replace address for the global tensor (to be done by single thread) @@ -1244,7 +1279,7 @@ struct CollectiveMma< } // The entire warp must call this function collectively (that is, the instructions are aligned) - template + template CUTLASS_DEVICE void tensormaps_perform_update( @@ -1252,10 +1287,9 @@ struct CollectiveMma< Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { if (cute::elect_one_sync()) { - // Replacing global_address for the next batch - 
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); if constexpr (IsGroupedGemmKernel) { auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); @@ -1263,23 +1297,34 @@ struct CollectiveMma< tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_MNKL); } + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); } // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release ( TensorMapStorage& shared_tensormaps, - cute::tuple const& input_tensormaps) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); + cute::tuple const& input_tensormaps + ) { + + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } } + // Entire warp must do this (i.e. 
it's aligned) tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp new file mode 100644 index 0000000..b3fc231 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp @@ -0,0 +1,1294 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/detail/collective/moe_stride_utils.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class 
TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. 
+ using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. 
+ // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA 
= GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. 
PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = 
cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + ArrayElementB const** ptr_B{nullptr}; + ElementSF const* ptr_SFA{nullptr}; + ElementSF const** ptr_SFB{nullptr}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + 
TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const* ptr_A; + ArrayElementB const** ptr_B; + ElementSF const* ptr_SFA; + ElementSF const** ptr_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K_A = int32_t(size<2>(TileShape{})); + auto init_K_B = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + auto ptr_A_first_batch = recast_ptr(args.ptr_A); + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_K_A = get<2>(problem_shape_MNK); + + auto shape_a = make_shape(init_M, init_K_A, problem_shapes.groups()); + InternalStrideA stride_a = cutlass::make_internal_packed_stride(InternalStrideA{}, shape_a); + InternalStrideB stride_b = InternalStrideB{}; + + InternalLayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K_A, problem_shapes.groups())); + InternalLayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K_B, 1)); + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K_A, problem_shapes.groups()), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K_B, init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = recast_ptr(args.ptr_SFA); + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B 
tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + args.ptr_A, + reinterpret_cast(args.ptr_B), + args.ptr_SFA, + reinterpret_cast(args.ptr_SFB), + }; + } + + struct TensorMaps : cute::aligned_struct<256, _0> { + cute::TmaDescriptor tma_desc_b; + cute::TmaDescriptor tma_desc_sfb; + }; + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + // Allocate gmem space for input tensormaps per each SM. 
+ return (sm_count * sizeof(TensorMaps) * NumTmaDescriptorsPerSm); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile 
epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t num_groups, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,num_groups)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Represent the full tensor of Scale factors + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + + layout_SFA = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, num_groups)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = 
make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = 
tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb); // multicast masks + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapB, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, + int curr_batch) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + 
CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + 
class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its 
data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current 
mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + + TensorMaps* gmem_tensormap = &(reinterpret_cast(mainloop_params.tensormaps)[sm_idx * NumTmaDescriptorsPerSm]); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + 
copy(recast(pB_tensormap), recast(sB_tensormap)); + + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + + __syncwarp(); + + struct TensorMapArray { + + TensorMaps *tensor_maps; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(void* tensormaps) : tensor_maps(reinterpret_cast(tensormaps)) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(&tensor_maps[idx].tma_desc_b, &tensor_maps[idx].tma_desc_sfb); + } + }; + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(gmem_tensormap); + } else { + return cute::make_tuple(&gmem_tensormap->tma_desc_b, &gmem_tensormap->tma_desc_sfb); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + TmaInternalElementB const* 
ptr_B = nullptr; + auto internal_shape_b = make_shape(static_cast(N), static_cast(K), 1); + InternalStrideB stride_b = cutlass::make_internal_packed_stride(InternalStrideB{}, internal_shape_b); + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), stride_b); + + ElementSF const* ptr_SF = nullptr; + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + Tensor tensor_sfb = make_tensor(ptr_SF, layout_SFB); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + } + // Ensure warp is converged before issuing tensormap fence 
release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps + ) { + + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + } + + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_{}; + LayoutSFB layout_SFB_{}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp 
b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000..157df9c --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,1032 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class 
ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> { + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + using TiledMma_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = 
decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. 
+ // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = remove_cvref_t(StridePairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v or cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v or cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using 
SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = 
typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ATmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ATmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + ArrayElementB const* ptr_B{nullptr}; + ElementSF const* ptr_SFA{nullptr}; + ElementSF const* ptr_SFB{nullptr}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename 
TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + + ArrayElementB const* ptr_B{nullptr}; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + 
+ const auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + auto shape_a = make_shape(M, K, L); + auto stride_a = cutlass::make_internal_packed_stride(StrideA{}, shape_a); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(shape_a, stride_a)); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, layout_SFB); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + cluster_layout_sfb_vmnk); + + return { + tma_load_a, + tma_load_sfa, + tma_load_sfb, + args.ptr_B, + layout_SFA, + layout_SFB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + // static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + 
implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. 
+ Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( 
make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMma_SF{}.get_slice(blockIdx.x % size(typename TiledMma_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, 
make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB // for input scale factor tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Setting the stride of B + auto shape_b = make_shape(N, K, L); + StrideB stride_b = cutlass::make_internal_packed_stride(StrideB{}, shape_b); + + // convert to subptr iterator if necessary + auto ptr_B = recast_ptr(params.ptr_B); + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(ptr_B), shape_b, stride_b); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = 
local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing 
the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, + tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + // class KTileCount, + // class GTensorPartitionedA, + // class STensorA, + class TileCoordMNKL, + class KTileIterator, + class... TLoadParams // see load_init_tma + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + // KTileCount k_tiles = get<0>(load_inputs); + // GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + // STensorA tAsA = get<2>(load_inputs); + + auto [k_tiles, + tAgA_mkl, tAsA, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + 
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + // auto [M,N,K,L] = problem_shape_MNKL; + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // 
Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 
0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class CtaTileCoord, + class... 
TMmaParams + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = 
mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage_tma), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage_tma), thr_tCtSFB_s2t); + } + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage_tma), + tCrB(_,_,k_block,read_stage_cpasync), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index 79a97be..e949f31 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ 
b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -604,14 +604,15 @@ struct CollectiveMma< implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); // Check for SFA SFB layout requirement - const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); - const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); - implementable = implementable && (layout_sfa_ref == args.layout_SFA); + const auto layout_sfa_ref = take<0,2>(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + const auto layout_sfb_ref = take<0,2>(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + implementable = implementable && (layout_sfa_ref == take<0,2>(args.layout_SFA)); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); } - implementable = implementable && (layout_sfb_ref == args.layout_SFB); + implementable = implementable && (layout_sfb_ref == take<0,2>(args.layout_SFB)); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); } diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp index bcf8862..e20371d 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp +++ 
b/3rd/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -772,14 +772,14 @@ struct CollectiveMma< } // Check for SFA SFB layout requirement - const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); - const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); - implementable = implementable && (layout_sfa_ref == args.layout_SFA); + const auto layout_sfa_ref = take<0,2>(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + const auto layout_sfb_ref = take<0,2>(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + implementable = implementable && (layout_sfa_ref == take<0,2>(args.layout_SFA)); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); } - implementable = implementable && (layout_sfb_ref == args.layout_SFB); + implementable = implementable && (layout_sfb_ref == take<0,2>(args.layout_SFB)); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); } diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp index d832a1f..2af0146 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -123,6 +123,11 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + using ElementA = ElementA_; using ElementAMma = typename TiledMma::ValTypeA; using StrideA = StrideA_; @@ -417,7 +422,7 @@ struct CollectiveMma< constexpr uint32_t NumInputTensors = 2; constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count); + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); } template @@ -497,7 +502,7 @@ struct CollectiveMma< /// tBsB - partitioned smem tensor for B /// mcast_mask_a - tma multicast mask for A /// mcast_mask_b - tma multicast mask for B - template + template CUTLASS_DEVICE auto load_init( ProblemShape_MNKL const& problem_shape_MNKL, @@ -505,6 +510,7 @@ struct CollectiveMma< TensorStorage& shared_tensors, TensorMapStorage& shared_tensormaps, int32_t const sm_count, 
int32_t const sm_idx, + [[maybe_unused]] int32_t num_groups, [[maybe_unused]] int32_t init_group) const { using X = Underscore; @@ -550,14 +556,20 @@ struct CollectiveMma< uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); - - return cute::make_tuple( + auto ret = cute::make_tuple( gA_mkl, gB_nkl, // for scheduler tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values - mcast_mask_a, mcast_mask_b, // multicast masks - input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + mcast_mask_a, mcast_mask_b // multicast masks + ); + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } } /// Set up the data needed by this collective for mma compute. 
@@ -612,7 +624,8 @@ struct CollectiveMma< cute::tuple> const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count, - bool did_batch_change) { + bool did_batch_change, + [[maybe_unused]] int curr_batch) { auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, @@ -739,6 +752,7 @@ struct CollectiveMma< // Methods to perform different parts of TMA/Tensormap modifications // + template CUTLASS_DEVICE auto tensormaps_init( Params const& mainloop_params, @@ -747,8 +761,8 @@ struct CollectiveMma< int32_t const sm_idx) const { cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; - cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; - cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[(sm_idx + sm_count) * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later @@ -762,7 +776,29 @@ struct CollectiveMma< } __syncwarp(); - return cute::make_tuple(tma_desc_a, tma_desc_b); + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_a; + cute::TmaDescriptor* tma_desc_b; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* tma_desc_a, cute::TmaDescriptor* tma_desc_b) : tma_desc_a(tma_desc_a), tma_desc_b(tma_desc_b) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_a + idx, tma_desc_b + idx); + } + }; + + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(tma_desc_a, tma_desc_b); + } + else { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } } // Replace address for the global tensor (to be done by single thread) @@ -826,7 +862,7 @@ struct CollectiveMma< } // The entire warp must call this function collectively (that is, the instructions are aligned) - template 
+ template CUTLASS_DEVICE void tensormaps_perform_update( @@ -834,7 +870,8 @@ struct CollectiveMma< Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); @@ -849,18 +886,24 @@ struct CollectiveMma< // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release ( TensorMapStorage& shared_tensormaps, - cute::tuple const& input_tensormaps) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); + cute::tuple const& input_tensormaps + ) { + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } } // Entire warp must do this (i.e. it's aligned) tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp index 812553a..291bb11 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp index 0a90566..15dd91b 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -181,6 +181,11 @@ struct CollectiveMma< static constexpr uint32_t NumTransformationThreads = 128; static constexpr uint32_t NumAccumThreads = 128; + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 64; + static constexpr uint32_t TransformRegisterRequirement = 184; + static constexpr uint32_t AccumRegisterRequirement = 256; + // Get the Algorithm parameters constexpr static int NumComputeMtxs = 3; constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp new file mode 100644 index 0000000..45b9cb4 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp @@ -0,0 +1,1202 @@ 
+/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + complex, + StrideA_, + complex, + StrideB_, + TiledMma_, + 
GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + using TileShape = TileShape_; + using TiledMma = TiledMma_; + + // ElementA and ElementB are cutlass::complex, which are used as GMEM input and output data type. + using ElementA = complex; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = complex; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + // For a complex kernel, the MMA output type is real valued, but ElementAccumulator is a complex type for the GETT reference kernel + using ElementAccumulator = complex; + using ElementAccumulatorRaw = typename TiledMma::ValTypeC; + +private: + // ElementAMma and ElementBMma are cutlass::complex, which are used as SMEM and RF data type. + // ElementAMmaRaw and ElementBMmaRaw are cutlass::bfloat16_t, which is the real internal data type set in TMA descriptor and used in TCGEN05 calculation. 
+ using ElementAMma = typename TiledMma::ValTypeA; // complex + using ElementAMmaRaw = typename ElementAMma::value_type; // bfloat16_t + using ElementBMma = typename TiledMma::ValTypeB; // complex + using ElementBMmaRaw = typename ElementBMma::value_type; // bfloat16_t + +public: + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + + static_assert(cute::is_same_v, "Underlying input type for A should be float"); + static_assert(cute::is_same_v, "Underlying input type for B should be float"); + static_assert(cute::is_same_v, "Underlying compute type for A should be bfloat16_t"); + static_assert(cute::is_same_v, "Underlying compute type for A should be bfloat16_t"); + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ArchTag = typename DispatchPolicy::ArchTag; + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, 
+ AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 64; + static constexpr uint32_t TransformRegisterRequirement = 184; + static constexpr uint32_t AccumRegisterRequirement = 256; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int ConjSwapMode = 2; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + static_assert(StagesPerTile * AccPromotionInterval == size<2>(CtaShapeA_MK{}), "PromotionInterval*InstructionK doesn't evenly divide CTA shape"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || NumBandsToCompute == 4 || NumBandsToCompute == 3), + 
"9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + tuple_cat(CtaShapeA_MK{}, tuple, Int>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + tuple_cat(CtaShapeB_NK{}, tuple, Int, Int>{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> smem_ACompute; + 
alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + 
make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A; + ElementB const** ptr_B; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Tensor shapes for Ptr-Array are initialized correctly here. 
+ auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + reinterpret_cast(workspace), 
+ reinterpret_cast(args.ptr_A), + reinterpret_cast(args.ptr_B) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto [M,N,K,L] = problem_shape.get_host_problem_shape(0); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, 
class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage, + int32_t const sm_count, int32_t const sm_idx) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = 
create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs(Emulation), SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs(Complex,Emulation), SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = 
make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + // For compute, cast to 4 raw elements instead of 2 complex elements. + auto tArA_x4 = recast>(tArA); + auto tArA_temp_x4 = recast>(tArA_temp); + auto tArACompute_x4 = recast>(tArACompute); + + auto tBrB_x4 = recast>(tBrB); + auto tBrB_temp_x4 = recast>(tBrB_temp); + auto tBrBCompute_x4 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + /// NOTE: sm100_mma_warpspecialized_interleaved_complex_tf32.hpp introduced about expanding: + /// re(a_complex * b_complex) -> (a_re, a_im) . (b_re,-b_im) = a . b_conj + /// im(a_complex * b_complex) -> (a_re, a_im) . (b_im, b_re) = a . 
b_swap + /// However, 16b types need to be packed for swapping and negation. + /// Hence, (re | im | re | im) is reordered into (re_x2 | im_x2). + cute::transform(tBrB_x4, tBrB_x4, [&] (auto& f4) -> Array {return {f4[0], f4[2], f4[1], f4[3]};}); + // Conversion b -> b_conj goes first, hence TransformB has a not preceding it. + if constexpr (not cute::is_same_v) { + cute::transform(tBrB_x4, tBrB_x4, [&] (auto& f4) { + auto f2_x2 = *reinterpret_cast,2>*>(&f4); + f2_x2[1] = cutlass::negate>{}(f2_x2[1]); + return *reinterpret_cast*>(&f2_x2); + }); + } + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x4, tBrBCompute_x4, + cutlass::NumericArrayConverter::convert); + // Store as B_conj (for producing C_re) + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,0,comp_mtx_index,transform2mma_producer_index)); + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x4, tBrB_temp_x4, + cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x4, tBrB_temp_x4, tBrB_x4, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x4, tBrB_x4, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + + // Convert B_conj to B_swap + cute::transform(tBrBCompute_x4, tBrBCompute_x4, [&] (auto& h4) { + // Reinterpret as packed types + auto h2_x2_conj = *reinterpret_cast,2>*>(&h4); + cutlass::negate> neg; + Array,2> h2_x2_swap{ neg(h2_x2_conj[1]), h2_x2_conj[0] }; + return *reinterpret_cast*>(&h2_x2_swap); + }); + // Store as B_swap (for producing C_im) + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,1,comp_mtx_index,transform2mma_producer_index)); + } + + // Loads from SMEM are done. 
Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + // ( re | im | re | im ) -> ( re_x2 | im_x2 ) + cute::transform(tArA_x4, tArA_x4, [&] (auto& f4) -> Array{return {f4[0], f4[2], f4[1], f4[3]};}); + if constexpr (cute::is_same_v) { + cute::transform(tArA_x4, tArA_x4, [&] (auto& f4) { + auto f2_x2 = *reinterpret_cast,2>*>(&f4); + f2_x2[1] = cutlass::negate>{}(f2_x2[1]); + return *reinterpret_cast*>(&f2_x2); + }); + } + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x4, tArACompute_x4, + cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x4, tArA_temp_x4, + cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x4, tArA_temp_x4, tArA_x4, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x4, tArA_x4, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + 
transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + // If the operand comes from SMEM, create SMEM copy. + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + take<0,3>(tensor_compute.layout())); + + // Source copy is based on the source operand of copy. 
+ auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + 
auto const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, --remaining_accum_promotions) { + // Accum stages are organized as (C_real | C_imag | C_real | C_imag | ...) + CUTLASS_PRAGMA_UNROLL + for (int re_im = 0; re_im < 2; ++re_im) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1) && (re_im == 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,re_im,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,re_im,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = 
tCrB(_,_,_,re_im,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. + + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), 
tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + Tensor tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { 
+ Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + } (); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. 
+ // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{},_), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2_x2 = recast,2>>(tTR_rGlobAcc);// (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block cute::remove_cvref_t {return {cutlass::plus>{}(f2_x2[0], r2), f2_x2[1]};}); + } + else { + cute::transform(tTR_rGlobAcc_float2_x2, tTR_rAcc_float2, tTR_rGlobAcc_float2_x2, + [&] (auto& f2_x2, auto& i2) -> cute::remove_cvref_t {return {f2_x2[0], cutlass::plus>{}(f2_x2[1], i2)};}); + } + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))) && (re_im == 1); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + } + + // Interleave back (real_x2 | imag_x2) to (real | imag | real | imag) + cute::transform(tTR_rGlobAcc_float2_x2, tTR_rGlobAcc_float2_x2, [&] (auto& f2_x2) -> cute::remove_cvref_t { + Array c0{f2_x2[0][0], f2_x2[1][0]}; + Array c1{f2_x2[0][1], f2_x2[1][1]}; + return {c0, c1}; + }); 
+ + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(gA_tensormap)); + copy(recast(pB_tensormap), recast(gB_tensormap)); + } + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) const { + Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(gA_tensormap), recast(sA_tensormap)); + copy(recast(gB_tensormap), recast(sB_tensormap)); + + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + 
tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + int32_t next_batch, + uint32_t lane_predicate) { + if (lane_predicate) { + // Bringing tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + // Perform using same thread as the one that issued TMA store, separate these out as far as possible to hide latency + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp new file mode 100644 index 
0000000..33507c9 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp @@ -0,0 +1,992 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for complex kernels +template < + int ComputationPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int TransformationPipelineStageCount_, + class AccumulatorCopyAtom_, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedInterleavedComplexTF32< + ComputationPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + TransformationPipelineStageCount_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + complex, + StrideA_, + complex, + StrideB_, + TiledMma_, + 
GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + using TileShape = TileShape_; + using TiledMma = TiledMma_; + + // ElementA and ElementB are cutlass::complex, which are used as GMEM input and output data type. + using ElementA = complex; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = complex; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + +private: + // ElementAMma and ElementBMma are cutlass::complex, which are used as SMEM and RF data type. + // ElementAMmaRaw and ElementBMmaRaw are cutlass::tfloat32_t, which is the real internal data type set in TMA descriptor and used in TCGEN05 calculation. + using ElementAMma = typename TiledMma::ValTypeA; // complex + using ElementAMmaRaw = typename ElementAMma::value_type; // tfloat32_t + using ElementBMma = typename TiledMma::ValTypeB; // complex + using ElementBMmaRaw = typename ElementBMma::value_type; // tfloat32_t + +public: + // For a complex kernel, the MMA output type is real valued, but ElementAccumulator is a complex type for the GETT reference kernel + using ElementAccumulator = cutlass::complex; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedInterleavedComplexTF32< + ComputationPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + TransformationPipelineStageCount_, + ClusterShape, + AccumulatorCopyAtom_>; + static constexpr bool 
IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ArchTag = typename DispatchPolicy::ArchTag; + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::ComputationPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::TransformationPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 152; + static constexpr uint32_t TransformRegisterRequirement = 200; + static constexpr uint32_t AccumRegisterRequirement = 152; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 2; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename 
SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(CtaShapeB_NK{}, Int{}))); + + static_assert(DispatchPolicy::ComputationPipelineStageCount >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(DispatchPolicy::TransformationPipelineStageCount >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must have A operand from TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } input; + + union TensorStorageTransformed { + alignas(1024) 
cute::ArrayEngine smem_ACompute; // smem_ACompute is actually in tmem + alignas(1024) cute::ArrayEngine> smem_BCompute; + } compute; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof(ElementAMma))) + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof(ElementBMma))); + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + cute::TmaDescriptor* tensormaps; + ElementA const** 
ptr_A; + ElementB const** ptr_B; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Tensor shapes for Ptr-Array are initialized correctly here. 
+ auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + reinterpret_cast(workspace), 
+ reinterpret_cast(args.ptr_A), + reinterpret_cast(args.ptr_B) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto [M,N,K,L] = problem_shape.get_host_problem_shape(0); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + return append( + partition_shape_C(TiledMma{}, take<0,2>(TileShape{})), + Int<2>{}); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,TMEM_PIPE,2) + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class 
STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE cute::tuple + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + int32_t const sm_count, int32_t const sm_idx) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = 
create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class SrcCopyA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcCopyB, class DstCopyB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + cutlass::arch::NamedBarrier transform_barrier(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAtACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsBCompute : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + auto [unused_tAgA, src_copy_A, dst_copy_A, tAsA, tAtACompute, + unused_tBgB, src_copy_B, dst_copy_B, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_conj = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_swap = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + + uint32_t skip_wait = (k_tile_count <= 0); + auto 
load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input A matrix from SMEM + copy(src_copy_A, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + // Copy the input B matrix from SMEM + copy(src_copy_B, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + + // First MMA, A.real * B.real - A.imag * B.imag + // Compose [real, -imag] copy for A TMEM + // Reflect the conjugation of B through A + if constexpr (cute::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tArA); i++) { + tArA_conj(i) = {tArA(i).real(), -tArA(i).imag()}; + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tArA); i++) { + tArA_conj(i) = tArA(i); + } + } + // Write to TMEM + copy(dst_copy_A, tArA_conj, tAtACompute(_,_,_,_,0,transform2mma_producer_index)); + + // Second MMA, A.imag * B.real + A.real * B.imag + // Compose [imag, real] copy for A TMEM + // Reflect the conjugation of B through A + auto transform_element = [] (ElementAMma const& tArA_i) -> ElementAMma { + if constexpr (cute::is_same_v && cute::is_same_v) { // CC + return {-tArA_i.imag(), -tArA_i.real()}; + } + else if constexpr (cute::is_same_v && not cute::is_same_v) { // 
CN/CT + return {-tArA_i.imag(), tArA_i.real()}; + } + else if constexpr (not cute::is_same_v && cute::is_same_v) { // NC/TC + return {tArA_i.imag(), -tArA_i.real()}; + } + else { // TN/NT/NN/TT + return {tArA_i.imag(), tArA_i.real()}; + } + }; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tArA); i++) { + tArA_swap(i) = transform_element(tArA(i)); + } + + // Write to TMEM + copy(dst_copy_A, tArA_swap, tAtACompute(_,_,_,_,1,transform2mma_producer_index)); + + // Write the B matrix to SMEM without any changes + copy(dst_copy_B, tBrB, tBsBCompute(_,_,_,_,transform2mma_producer_index)); + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_barrier.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = 
make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] (auto tensor_input, auto input_copy_atom, + auto tensor_compute, auto make_fragment, auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + // Source copy is based on the source operand of TMEM_STORE copy. + auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom, ElementAMma>{}, reg2tmem_tiled_copy); + return cute::make_tuple(smem2reg_tiled_copy, reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + // If the operand comes from SMEM, create SMEM copy. + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0).layout()); + + // Source copy is based on the source operand of copy. 
+ auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, reg2smem_tiled_copy); + auto thr_smem2reg_tiled_copy = smem2reg_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(smem2reg_tiled_copy, reg2smem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAtACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [src_copy_B, dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, src_copy_A, dst_copy_A, tAsA, tAtACompute, + gB_nkl, src_copy_B, dst_copy_B, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto 
next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + constexpr int RealAccumIndex = 0; + constexpr int ImagAccumIndex = 1; + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC_real = accumulators(_,_,_,RealAccumIndex,mma2accum_pipeline_producer_state_index); + auto tCtC_imag = accumulators(_,_,_,ImagAccumIndex,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < StagesPerTile; ++k_block) { + + auto tCrA_conj = tCrA(_,_,_,Int<0>{},transform2mma_pipeline_consumer_state_index); + auto tCrA_swap = tCrA(_,_,_,Int<1>{},transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,transform2mma_pipeline_consumer_state_index); + + // A conjugate * B + cute::gemm(tiled_mma, tCrA_conj(_,_,k_block), tCrB0(_,_,k_block), tCtC_real); // A[0]*B[0] + // A swapped * B + cute::gemm(tiled_mma, tCrA_swap(_,_,k_block), tCrB0(_,_,k_block), tCtC_imag); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = 
transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + Tensor tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + } (); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom, EpilogueTile) { + return accumulators; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, 
Int<1>{}); + Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(gA_tensormap)); + copy(recast(pB_tensormap), recast(gB_tensormap)); + } + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) const { + Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(gA_tensormap), recast(sA_tensormap)); + copy(recast(gB_tensormap), recast(sB_tensormap)); + + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + int32_t next_batch, + uint32_t lane_predicate) { + if (lane_predicate) { + // Bringing 
tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + // Perform using same thread as the one that issued TMA store, separate these out as far as possible to hide latency + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), 
Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp new file mode 100644 index 0000000..44bdc8c --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp @@ -0,0 +1,963 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int 
SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, + class TileShape_, // Static cluster shape or dynamic (int, int, _1) + class ElementA_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMmaPair_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedPlanarComplex< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMmaPair_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using TiledMmaPair = TiledMmaPair_; + using TiledMma = typename TiledMmaPair::TiledMmaAPosAtom; + using TiledMmaANeg = typename TiledMmaPair::TiledMmaANegAtom; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedPlanarComplex< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), 
size<2>(TileShape{})))); + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M, K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N, 
K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A_real; + cute::ArrayEngine> smem_A_imag; + cute::ArrayEngine> smem_B_real; + cute::ArrayEngine> smem_B_imag; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor 
smem_tensormap_A_real; + cute::TmaDescriptor smem_tensormap_A_imag; + cute::TmaDescriptor smem_tensormap_B_real; + cute::TmaDescriptor smem_tensormap_B_imag; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = 2 * ( + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * (cosize(take<0,3>(SmemLayoutA{})) * static_cast(cute::sizeof_bits::value))) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * (cosize(take<0,3>(SmemLayoutB{})) * static_cast(cute::sizeof_bits::value)))); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A_real{nullptr}; + StrideA dA_real{}; + ElementA const** ptr_A_imag{nullptr}; + StrideA dA_imag{}; + ElementB const** ptr_B_real{nullptr}; + StrideB dB_real{}; + ElementB const** ptr_B_imag{nullptr}; + StrideB dB_imag{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), 
StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a_real; + TMA_A tma_load_a_imag; + TMA_B tma_load_b_real; + TMA_B tma_load_b_imag; + TMA_A tma_load_a_real_fallback; + TMA_A tma_load_a_imag_fallback; + TMA_B tma_load_b_real_fallback; + TMA_B tma_load_b_imag_fallback; + dim3 cluster_shape_fallback; + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A_real; + ElementA const** ptr_A_imag; + ElementB const** ptr_B_real; + ElementB const** ptr_B_imag; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_real_ = is_fallback_cluster ? ¶ms.tma_load_a_real_fallback : ¶ms.tma_load_a_real; + observed_tma_load_a_imag_ = is_fallback_cluster ? ¶ms.tma_load_a_imag_fallback : ¶ms.tma_load_a_imag; + observed_tma_load_b_real_ = is_fallback_cluster ? ¶ms.tma_load_b_real_fallback : ¶ms.tma_load_b_real; + observed_tma_load_b_imag_ = is_fallback_cluster ? ¶ms.tma_load_b_imag_fallback : ¶ms.tma_load_b_imag; + } + else { + observed_tma_load_a_real_ = ¶ms.tma_load_a_real; + observed_tma_load_a_imag_ = ¶ms.tma_load_a_imag; + observed_tma_load_b_real_ = ¶ms.tma_load_b_real; + observed_tma_load_b_imag_ = ¶ms.tma_load_b_imag; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Tensor shapes for Ptr-Array are initialized correctly here. 
+ auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_real_first_batch = nullptr; + ElementA const* ptr_A_imag_first_batch = nullptr; + + ElementB const* ptr_B_real_first_batch = nullptr; + ElementB const* ptr_B_imag_first_batch = nullptr; + + Tensor tensor_a_real = make_tensor(ptr_A_real_first_batch, make_layout(make_shape(M,K,mock_L), args.dA_real)); + Tensor tensor_a_imag = make_tensor(ptr_A_imag_first_batch, make_layout(make_shape(M,K,mock_L), args.dA_imag)); + + Tensor tensor_b_real = make_tensor(ptr_B_real_first_batch, make_layout(make_shape(N,K,mock_L), args.dB_real)); + Tensor tensor_b_imag = make_tensor(ptr_B_imag_first_batch, make_layout(make_shape(N,K,mock_L), args.dB_imag)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = conditional_return(make_shape(hw_info.cluster_shape_fallback.x, hw_info.cluster_shape_fallback.y, Int<1>{}), ClusterShape{}); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a_real = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_real, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_imag = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_imag, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b_real = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_real, + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b_imag = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_imag, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_real_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_real, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_A tma_load_a_imag_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_imag, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_real_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_real, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_imag_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_imag, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a_real, + tma_load_a_imag, + tma_load_b_real, + tma_load_b_imag, + tma_load_a_real_fallback, + tma_load_a_imag_fallback, + tma_load_b_real_fallback, + tma_load_b_imag_fallback, + hw_info.cluster_shape_fallback, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A_real), + reinterpret_cast(args.ptr_A_imag), + reinterpret_cast(args.ptr_B_real), + reinterpret_cast(args.ptr_B_imag) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); + } 
+ + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + auto [M,N,K,L] = problem_shape.get_host_problem_shape(0); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = 128 / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = 128 / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + return append(partition_shape_C(TiledMma{}, take<0,2>(TileShape{})), Int<2>{}); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,TMEM_PIPE,2) + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. 
+ Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_(real/imag)_mkl - The tiled tma tensor for input A_(real/imag) + /// gB_(real/imag)_nkl - The tiled tma tensor for input B_(real/imag) + /// tAsA_(real/imag) - partitioned smem tensor for A_(real/imag) + /// tBsB_(real/imag) - partitioned smem tensor for B_(real/imag) + /// mcast_mask_a - tma multicast mask for A_(real/imag) + /// mcast_mask_b - tma multicast mask for B_(real/imag) + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t num_groups, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_real_mkl = observed_tma_load_a_real_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mA_imag_mkl = observed_tma_load_a_imag_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_real_nkl = observed_tma_load_b_real_->get_tma_tensor(make_shape(N,K,mock_L)); + Tensor mB_imag_nkl = observed_tma_load_b_imag_->get_tma_tensor(make_shape(N,K,mock_L)); + 
+ // Tile the tensors and defer the slice + Tensor gA_real_mkl = local_tile(mA_real_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gA_imag_mkl = local_tile(mA_imag_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_N, BLK_K, m, k, l) + + Tensor gB_real_nkl = local_tile(mB_real_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + Tensor gB_imag_nkl = local_tile(mB_imag_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_real_mkl = cta_mma.partition_A(gA_real_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgA_imag_mkl = cta_mma.partition_A(gA_imag_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor tCgB_real_nkl = cta_mma.partition_B(gB_real_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + Tensor tCgB_imag_nkl = cta_mma.partition_B(gB_imag_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA_real = make_tensor(make_smem_ptr(shared_tensors.smem_A_real.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sA_imag = make_tensor(make_smem_ptr(shared_tensors.smem_A_imag.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + Tensor sB_real = make_tensor(make_smem_ptr(shared_tensors.smem_B_real.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + Tensor sB_imag = make_tensor(make_smem_ptr(shared_tensors.smem_B_imag.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_real_mkl, tAsA_real] = tma_partition(*observed_tma_load_a_real_, + get<2>(cta_coord_vmnk), 
make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA_real), group_modes<0,3>(tCgA_real_mkl)); + auto [tAgA_imag_mkl, tAsA_imag] = tma_partition(*observed_tma_load_a_imag_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA_imag), group_modes<0,3>(tCgA_imag_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_real_nkl, tBsB_real] = tma_partition(*observed_tma_load_b_real_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB_real), group_modes<0,3>(tCgB_real_nkl)); + auto [tBgB_imag_nkl, tBsB_imag] = tma_partition(*observed_tma_load_b_imag_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB_imag), group_modes<0,3>(tCgB_imag_nkl)); + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + auto ret = cute::make_tuple( + gA_real_mkl, gA_imag_mkl, gB_real_nkl, gB_imag_nkl, // for scheduler + tAgA_real_mkl, tAgA_imag_mkl, tBgB_real_nkl, tBgB_imag_nkl, // for input tensor values + tAsA_real, tAsA_imag, tBsB_real, tBsB_imag, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + Tensor sA_real = make_tensor(make_smem_ptr(shared_tensors.smem_A_real.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA_imag = make_tensor(make_smem_ptr(shared_tensors.smem_A_imag.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + Tensor sB_real = make_tensor(make_smem_ptr(shared_tensors.smem_B_real.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB_imag = make_tensor(make_smem_ptr(shared_tensors.smem_B_imag.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA_real = TiledMma::make_fragment_A(sA_real); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_imag = TiledMma::make_fragment_A(sA_imag); // (MMA,MMA_M,MMA_K,PIPE) + + Tensor tCrB_real = TiledMma::make_fragment_B(sB_real); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB_imag = TiledMma::make_fragment_B(sB_imag); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA_real)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB_real)); // PIPE + + TiledMma tiled_mma_a_pos; + TiledMmaANeg tiled_mma_a_neg; + + return cute::make_tuple(tiled_mma_a_pos, tiled_mma_a_neg, tCrA_real, tCrA_imag, tCrB_real, tCrB_imag); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, + [[maybe_unused]] int curr_batch) { + + auto [unused_gA_real, unused_gA_imag, 
unused_gB_real, unused_gB_imag, + tAgA_real_mkl, tAgA_imag_mkl, tBgB_real_nkl, tBgB_imag_nkl, + tAsA_real, tAsA_imag, tBsB_real, tBsB_imag, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA_real = tAgA_real_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgA_imag = tAgA_imag_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + Tensor tBgB_real = tBgB_real_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgB_imag = tBgB_imag_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_real_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA_real(_,*k_tile_iter), tAsA_real(_,write_stage)); + copy(observed_tma_load_a_imag_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA_imag(_,*k_tile_iter), tAsA_imag(_,write_stage)); + + copy(observed_tma_load_b_real_->with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB_real(_,*k_tile_iter), tBsB_real(_,write_stage)); + 
copy(observed_tma_load_b_imag_->with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB_imag(_,*k_tile_iter), tBsB_imag(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + /* This helps avoid early exit of ctas in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor const& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 4 && size<3>(FrgLayout{}) == _2{}, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N, _2)"); + + auto [tiled_mma_a_pos, tiled_mma_a_neg, tCrA_real, tCrA_imag, tCrB_real, tCrB_imag] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma_a_pos.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_a_neg.accumulate_ = 
UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + auto accumulators_real = accumulators(_,_,_,0); + auto accumulators_imag = accumulators(_,_,_,1); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA_real); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + + // Calculate real acc, 1st step + // realAcc += realA * realB + cute::gemm(tiled_mma_a_pos, tCrA_real(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_real); + + // Calculate imag acc, 1st step + if constexpr (cute::is_same_v) { + // imagAcc += realA * (-imagB) + cute::gemm(tiled_mma_a_neg, tCrA_real(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_imag); + } + else { + // imagAcc += realA * imagB + cute::gemm(tiled_mma_a_pos, tCrA_real(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_imag); + } + + tiled_mma_a_pos.accumulate_ = UMMA::ScaleOut::One; + tiled_mma_a_neg.accumulate_ = UMMA::ScaleOut::One; + + // Calculate real acc, 2nd step + if constexpr (cute::is_same_v) { + // realAcc -= imagA * imagB + cute::gemm(tiled_mma_a_neg, 
tCrA_imag(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_real); + } + else { + // realAcc += imagA * imagB + cute::gemm(tiled_mma_a_pos, tCrA_imag(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_real); + } + + // Calculate imag acc, 2nd step + if constexpr (cute::is_same_v) { + // imagAcc += (-imagA) * realB + cute::gemm(tiled_mma_a_neg, tCrA_imag(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_imag); + } + else { + // imagAcc += imagA * realB + cute::gemm(tiled_mma_a_pos, tCrA_imag(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_imag); + } + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a_real = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + cute::TmaDescriptor* tma_desc_a_imag = &gmem_tensormap[(sm_idx + sm_count) * NumTmaDescriptorsPerSm]; + + cute::TmaDescriptor* tma_desc_b_real = &gmem_tensormap[(sm_idx + 2 * sm_count) * NumTmaDescriptorsPerSm]; + cute::TmaDescriptor* tma_desc_b_imag = &gmem_tensormap[(sm_idx + 3 * sm_count) * NumTmaDescriptorsPerSm]; + + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_real_tensormap = make_tensor(observed_tma_load_a_real_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_real_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A_real), Int<1>{}, Int<1>{}); + Tensor pA_imag_tensormap = make_tensor(observed_tma_load_a_imag_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_imag_tensormap = 
make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A_imag), Int<1>{}, Int<1>{}); + + Tensor pB_real_tensormap = make_tensor(observed_tma_load_b_real_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_real_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B_real), Int<1>{}, Int<1>{}); + Tensor pB_imag_tensormap = make_tensor(observed_tma_load_b_imag_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_imag_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B_imag), Int<1>{}, Int<1>{}); + + copy(recast(pA_real_tensormap), recast(sA_real_tensormap)); + copy(recast(pA_imag_tensormap), recast(sA_imag_tensormap)); + + copy(recast(pB_real_tensormap), recast(sB_real_tensormap)); + copy(recast(pB_imag_tensormap), recast(sB_imag_tensormap)); + } + __syncwarp(); + + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_a_real; + cute::TmaDescriptor* tma_desc_a_imag; + cute::TmaDescriptor* tma_desc_b_real; + cute::TmaDescriptor* tma_desc_b_imag; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* tma_desc_a_real, cute::TmaDescriptor* tma_desc_a_imag, + cute::TmaDescriptor* tma_desc_b_real, cute::TmaDescriptor* tma_desc_b_imag) + : tma_desc_a_real(tma_desc_a_real), tma_desc_a_imag(tma_desc_a_imag), + tma_desc_b_real(tma_desc_b_real), tma_desc_b_imag(tma_desc_b_imag) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_a_real + idx, tma_desc_a_imag + idx, + tma_desc_b_real + idx, tma_desc_b_imag + idx); + } + }; + + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(tma_desc_a_real, tma_desc_a_imag, tma_desc_b_real, tma_desc_b_imag); + } + else { + return cute::make_tuple(tma_desc_a_real, tma_desc_a_imag, tma_desc_b_real, tma_desc_b_imag); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + 
TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A_real, + mainloop_params.ptr_A_real[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A_imag, + mainloop_params.ptr_A_imag[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B_real, + mainloop_params.ptr_B_real[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B_imag, + mainloop_params.ptr_B_imag[next_batch]); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + [[maybe_unused]]ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps + ) { + // Entire warp must do this (i.e., it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A_real); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_A_imag); + + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormap.smem_tensormap_B_real); + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), 
shared_tensormap.smem_tensormap_B_imag); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_real_ = nullptr; + typename Params::TMA_A const* observed_tma_load_a_imag_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_real_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_imag_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp new file mode 100644 index 0000000..9bd85ae --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp @@ -0,0 +1,900 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/detail/collective/moe_stride_utils.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + 
SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). 
+ constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a 
non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using 
TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + ArrayElementB const** ptr_B{nullptr}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), 
uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const* ptr_A; + ArrayElementB const** ptr_B; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K_A = get<2>(init_shape); + auto init_K_B = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + auto ptr_A_first_batch = recast_ptr(args.ptr_A); + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_K_A = get<2>(problem_shape_MNK); + + auto shape_a = make_shape(init_M, init_K_A, problem_shapes.groups()); + InternalStrideA stride_a = cutlass::make_internal_packed_stride(InternalStrideA{}, shape_a); + + InternalStrideB stride_b = InternalStrideB{}; + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K_A, problem_shapes.groups()), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K_B, init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + args.ptr_A, + reinterpret_cast(args.ptr_B) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 1; + constexpr size_t SizeOfCuTensorMap 
= sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + return partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage 
tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t num_groups, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,num_groups)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile 
the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { 
+ // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, + int curr_batch) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if 
(cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for 
tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = 
make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_b; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* tma_desc_b) : tma_desc_b(tma_desc_b) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_b + idx); + } + }; + + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(tma_desc_b); + } + else { + return cute::make_tuple(tma_desc_b); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementB const* ptr_B = nullptr; + auto internal_shape_b = make_shape(static_cast(N), static_cast(K), 1); + InternalStrideB stride_b = cutlass::make_internal_packed_stride(InternalStrideB{}, internal_shape_b); + + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), stride_b); + + 
cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps + ) { + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000..f210d5d --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using 
AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + using DispatchPolicy = MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster shape + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = 
GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreads = size(GmemTiledCopyA{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using MmaSmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + append(LoadShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + return { + args.ptr_A, + args.dA, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + 
partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.ptr_A), make_shape(M,K,L), params.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cA_mk = make_identity_tensor(make_shape(M,K)); + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgA_mk = local_tile(cA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), LoadSmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyA gmem_to_smem_a_tiled_copy; + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreads; 
+ auto thr_copy_a = gmem_to_smem_a_tiled_copy.get_slice(thread_idx); + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // gmem + cgA_mk, cgB_nk, // crd + sA, sB, // smem + problem_shape_MNKL, + gmem_to_smem_a_tiled_copy, gmem_to_smem_b_tiled_copy, + thr_copy_a, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), MmaSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class CTensorA, class CTensorB, + class STensorA, class STensorB, + class ProblemShape_MNKL, + class TiledCopyA, class TiledCopyB, + class ThreadCopyA, class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + // Unpack from load_inputs + GTensorA tAgA_mkl = get<0>(load_inputs); + GTensorB tBgB_nkl = get<1>(load_inputs); + CTensorA cgA_mk = get<2>(load_inputs); + CTensorB cgB_nk = get<3>(load_inputs); + STensorA sA = get<4>(load_inputs); + STensorB sB = get<5>(load_inputs); + ProblemShape_MNKL problem_shape_MNKL = get<6>(load_inputs); + TiledCopyA gmem_to_smem_a_tiled_copy = get<7>(load_inputs); + TiledCopyB gmem_to_smem_b_tiled_copy = get<8>(load_inputs); + ThreadCopyA thr_copy_a = get<9>(load_inputs); + ThreadCopyB thr_copy_b = get<10>(load_inputs); + auto [M,N,K,L] = problem_shape_MNKL; + + // Slice out the work coord from partitioned tensors + Tensor gA_in = tAgA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgA_mk_in = cgA_mk(_, _, get<0>(cta_coord_mnkl), _); + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gA_in); + + // Shift tensor so residue_k is at origin (Can't read any k_coord < 
residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, k_residue, 0), gA_in); + Tensor gB = domain_offset(make_coord(0, k_residue, 0), gB_in); + + Tensor cA = domain_offset(make_coord(0, k_residue, 0), cgA_mk_in); + Tensor cB = domain_offset(make_coord(0, k_residue, 0), cgB_nk_in); + + auto tAgA = thr_copy_a.partition_S(gA); + auto tAsA = thr_copy_a.partition_D(sA); + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + Tensor tAcA = thr_copy_a.partition_S(cA); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tAgA and tBgB + Tensor tAcAk = tAcA(_,_,_,*k_tile_iter); + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = elem_less(get<0>(tAcAk(0,m,0)), M); // blk_m coord < M + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // 0-th stage with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0 && k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if ( 
int(get<1>(tAcAk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_a_tiled_copy, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + auto mainloop_pipe_producer_state_curr = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state_curr, barrier_token); + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state_curr.index(); + + copy_if(gmem_to_smem_a_tiled_copy, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state_curr, cutlass::arch::cpasync_barrier_arrive); + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages 
to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB + > + CUTLASS_DEVICE auto + mma(MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_consumer_state, + cute::tuple, cute::Tensor> const& accumulators_pair, + cute::tuple const& mma_inputs, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + int read_stage = mainloop_pipe_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + --k_tile_count; + ++mainloop_pipe_consumer_state; + } + + return mainloop_pipe_consumer_state; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp 
b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000..2a05693 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,752 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class 
GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using 
ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced 
instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using 
RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + ArrayElementB const* ptr_B{nullptr}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + + ArrayElementB const* ptr_B{nullptr}; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + 
observed_tma_load_a_ = ¶ms.tma_load_a; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + auto shape_a = make_shape(M, K, L); + StrideA stride_a = cutlass::make_internal_packed_stride(StrideA{}, shape_a); + Tensor tensor_a = make_tensor(ptr_A, make_layout(shape_a, stride_a)); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + return { + tma_load_a, + args.ptr_B, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), 
StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA // for input tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + + auto shape_b = make_shape(N, K, L); + StrideB stride_b = 
cutlass::make_internal_packed_stride(StrideB{}, shape_b); + + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), shape_b, stride_b); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. 
+ // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class KTileCount, + class GTensorPartitionedA, + class STensorA, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + KTileCount k_tiles = get<0>(load_inputs); + GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + STensorA tAsA = get<2>(load_inputs); + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return 
cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue 
!= 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still 
inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, 
tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp index fe5ee3c..4644eae 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp index e76818d..3f008c0 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -347,20 +347,20 @@ struct CollectiveMma< template< class KTileCount, - class GTensorPartitionedScaleA, class GTensorPartitionedScaleB, - class IdentTensorPartitionedScaleA, class IdentTensorPartitionedScaleB, + class GTensorScaleA, class GTensorScaleB, + class IdentTensorScaleA, class IdentTensorScaleB, class STensorScaleA, class STensorScaleB > struct LoadSFParams { // for scheduler KTileCount k_tiles; - GTensorPartitionedScaleA tSFAgSFA_mkl; - GTensorPartitionedScaleB tSFBgSFB_nkl; - IdentTensorPartitionedScaleA tSFAIdentSFA_mkl; - IdentTensorPartitionedScaleB tSFBIdentSFB_nkl; - STensorScaleA tSFAsSFA; - STensorScaleB tSFBsSFB; + GTensorScaleA gSFA_mkl; + GTensorScaleB gSFB_nkl; + IdentTensorScaleA identSFA_mkl; + IdentTensorScaleB identSFB_nkl; + STensorScaleA sSFA; + STensorScaleB sSFB; LayoutSFA layout_SFA; LayoutSFB layout_SFB; @@ -368,14 +368,14 @@ struct CollectiveMma< CUTLASS_DEVICE LoadSFParams ( KTileCount k_tiles_, - GTensorPartitionedScaleA tSFAgSFA_mkl_, GTensorPartitionedScaleB tSFBgSFB_nkl_, - IdentTensorPartitionedScaleA 
tSFAIdentSFA_mkl_, IdentTensorPartitionedScaleB tSFBIdentSFB_nkl_, - STensorScaleA tSFAsSFA_, STensorScaleB tSFBsSFB_, + GTensorScaleA gSFA_mkl_, GTensorScaleB gSFB_nkl_, + IdentTensorScaleA identSFA_mkl_, IdentTensorScaleB identSFB_nkl_, + STensorScaleA sSFA_, STensorScaleB sSFB_, LayoutSFA layout_SFA_, LayoutSFB layout_SFB_) : k_tiles(k_tiles_) - , tSFAgSFA_mkl(tSFAgSFA_mkl_), tSFBgSFB_nkl(tSFBgSFB_nkl_) - , tSFAIdentSFA_mkl(tSFAIdentSFA_mkl_), tSFBIdentSFB_nkl(tSFBIdentSFB_nkl_) - , tSFAsSFA(tSFAsSFA_), tSFBsSFB(tSFBsSFB_) + , gSFA_mkl(gSFA_mkl_), gSFB_nkl(gSFB_nkl_) + , identSFA_mkl(identSFA_mkl_), identSFB_nkl(identSFB_nkl_) + , sSFA(sSFA_), sSFB(sSFB_) , layout_SFA(layout_SFA_), layout_SFB(layout_SFB_) {} }; @@ -732,35 +732,16 @@ struct CollectiveMma< static_assert(rank(decltype(gSFA_mkl){}) == 5); static_assert(rank(decltype(gSFB_nkl){}) == 5); - // 1 thread copies entire set of scalar - GmemTiledCopySFA scale_copy_a{}; - GmemTiledCopySFB scale_copy_b{}; - - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) - Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) - Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) - - Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); - - Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) - Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) - Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); - - static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); - 
static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); - LoadSFParams load_params { size<3>(gSFA_mkl), - tSFAgSFA_mkl, tSFBgSFB_nkl, // for input scale tensor values - tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, // for predicating scale tensor copies - tSFAsSFA, tSFBsSFB, // for scale tensor values + gSFA_mkl, gSFB_nkl, // for input scale tensor values + identSFA_mkl, identSFB_nkl, // for predicating scale tensor copies + sSFA, sSFB, // for scale tensor values mainloop_params.layout_SFA, // for predicating scale tensor copies mainloop_params.layout_SFB // for predicating scale tensor copies }; @@ -922,24 +903,44 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count) { auto [unused_k_tiles, - tSFAgSFA_mkl, tSFBgSFB_nkl, - tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, - tSFAsSFA, tSFBsSFB, + gSFA_mkl, gSFB_nkl, + identSFA_mkl, identSFB_nkl, + sSFA, sSFB, layout_SFA, layout_SFB] = load_inputs; // slice out the work coord from partitioned tensors GmemTiledCopySFA scale_copy_a{}; GmemTiledCopySFB scale_copy_b{}; - Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gSFA_k_compact = filter_zeros( + gSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor gSFB_k_compact = filter_zeros( + gSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_N_CPT, BLK_K_CPT, k_cpt) + + Tensor identSFA_k_compact = filter_zeros( + identSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFA_k_compact.stride()); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor identSFB_k_compact = filter_zeros( + identSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFB_k_compact.stride()); // (BLK_N_CPT, BLK_K_CPT, k_cpt) + + Tensor sSFA_compact = filter_zeros(sSFA); // (BLK_M_CPT, BLK_K_CPT, P) + Tensor sSFB_compact = filter_zeros(sSFB); // (BLK_N_CPT, BLK_K_CPT, P) + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + 
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); + + Tensor tSFAgSFA_k_compact = thr_scale_copy_a.partition_S(gSFA_k_compact); // (CPY, BLK_M, BLK_K, k) + Tensor tSFAIdentSFA_k_compact = thr_scale_copy_a.partition_S(identSFA_k_compact); // (CPY, BLK_M, BLK_K, k) - Tensor tSFBgSFB = tSFBgSFB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tSFAsSFA_compact = thr_scale_copy_a.partition_D(sSFA_compact); - Tensor thr_tile_SFA_k = tSFAIdentSFA_mkl(_0{}, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); - Tensor thr_tile_pSFA = make_tensor(shape(filter_zeros(thr_tile_SFA_k(_,_,_0{}), tSFAgSFA(_0{},_,_,_0{}).stride()))); - Tensor thr_tile_SFB_k = tSFBIdentSFB_nkl(_0{}, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tSFBgSFB_k_compact = thr_scale_copy_b.partition_S(gSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBIdentSFB_k_compact = thr_scale_copy_b.partition_S(identSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBsSFB_compact = thr_scale_copy_b.partition_D(sSFB_compact); - Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); + Tensor thr_tile_pSFA = make_fragment_like(tSFAgSFA_k_compact(_0{},_,_,_0{})); + Tensor thr_tile_pSFB = make_fragment_like(tSFBgSFB_k_compact(_0{},_,_,_0{})); // Issue the loads CUTLASS_PRAGMA_NO_UNROLL @@ -949,18 +950,22 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFA); ++i) { - Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); - thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); + Tensor tSFAIdentSFA_compact = tSFAIdentSFA_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFA(i) = elem_less(tSFAIdentSFA_compact(i), + shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); } CUTLASS_PRAGMA_UNROLL for (int i = 0; 
i < size(thr_tile_pSFB); ++i) { - Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); - thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); + Tensor tSFBIdentSFB_compact = tSFBIdentSFB_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFB(i) = elem_less(tSFBIdentSFB_compact(i), + shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); } - copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); - copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); + copy_if(scale_copy_a, thr_tile_pSFA, tSFAgSFA_k_compact(_,_,_,*k_tile_iter), + tSFAsSFA_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); + copy_if(scale_copy_b, thr_tile_pSFB, tSFBgSFB_k_compact(_,_,_,*k_tile_iter), + tSFBsSFB_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); __syncwarp(); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp index 54c3bd5..1be8060 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -192,6 +192,11 @@ struct CollectiveMma< static constexpr uint32_t NumTransformationThreads = 128; static constexpr uint32_t NumAccumThreads = 128; + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 64; + static constexpr uint32_t TransformRegisterRequirement = 184; + static constexpr uint32_t AccumRegisterRequirement = 256; + // Get the Algorithm parameters constexpr static int NumComputeMtxs = 3; constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp new file mode 100644 index 0000000..a8fe8c4 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp @@ -0,0 +1,1077 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels: Interleaved complex variants +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int 
NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + complex, + StrideA_, + complex, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + using TileShape = TileShape_; + using TiledMma = TiledMma_; + + // ElementA and ElementB are cutlass::complex, which are used as GMEM input and output data type. + using ElementA = complex; + using StrideA = StrideA_; + using ElementB = complex; + using StrideB = StrideB_; + + // For a complex kernel, the MMA output type is real valued, but ElementAccumulator is a complex type for the GETT reference kernel + using ElementAccumulator = complex; + using ElementAccumulatorRaw = typename TiledMma::ValTypeC; + +private: + // ElementAMma and ElementBMma are cutlass::complex, which are used as SMEM and RF data type. + // ElementAMmaRaw and ElementBMmaRaw are cutlass::bfloat16_t, which is the real internal data type set in TMA descriptor and used in TCGEN05 calculation. 
+ using ElementAMma = typename TiledMma::ValTypeA; // complex + using ElementAMmaRaw = typename ElementAMma::value_type; // bfloat16_t + using ElementBMma = typename TiledMma::ValTypeB; // complex + using ElementBMmaRaw = typename ElementBMma::value_type; // bfloat16_t + +public: + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + + static_assert(cute::is_same_v, "Underlying input type for A should be float"); + static_assert(cute::is_same_v, "Underlying input type for B should be float"); + static_assert(cute::is_same_v, "Underlying compute type for A should be bfloat16_t"); + static_assert(cute::is_same_v, "Underlying compute type for A should be bfloat16_t"); + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ArchTag = typename DispatchPolicy::ArchTag; + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + 
AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 64; + static constexpr uint32_t TransformRegisterRequirement = 184; + static constexpr uint32_t AccumRegisterRequirement = 256; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int ConjSwapMode = 2; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + static_assert(StagesPerTile * AccPromotionInterval == size<2>(CtaShapeA_MK{}), "PromotionInterval*InstructionK doesn't evenly divide CTA shape"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || NumBandsToCompute == 4 || NumBandsToCompute == 3), + 
"9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + tuple_cat(CtaShapeA_MK{}, tuple, Int>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + tuple_cat(CtaShapeB_NK{}, tuple, Int, Int>{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> smem_ACompute; + 
alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A 
tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( 
+ GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x 
== params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { 
+ // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = 
make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, 
SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs(Emulation), SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs(Complex,Emulation), SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + // For compute, cast to 4 raw elements instead of 2 complex elements. + auto tArA_x4 = recast>(tArA); + auto tArA_temp_x4 = recast>(tArA_temp); + auto tArACompute_x4 = recast>(tArACompute); + + auto tBrB_x4 = recast>(tBrB); + auto tBrB_temp_x4 = recast>(tBrB_temp); + auto tBrBCompute_x4 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto 
curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + /// NOTE: sm100_mma_warpspecialized_interleaved_complex_tf32.hpp introduced about expanding: + /// re(a_complex * b_complex) -> (a_re, a_im) . (b_re,-b_im) = a . b_conj + /// im(a_complex * b_complex) -> (a_re, a_im) . (b_im, b_re) = a . b_swap + /// However, 16b types need to be packed for swapping and negation. + /// Hence, (re | im | re | im) is reordered into (re_x2 | im_x2). + cute::transform(tBrB_x4, tBrB_x4, [&] (auto& f4) -> Array {return {f4[0], f4[2], f4[1], f4[3]};}); + // Conversion b -> b_conj goes first, hence TransformB has a not preceding it. + if constexpr (not cute::is_same_v) { + cute::transform(tBrB_x4, tBrB_x4, [&] (auto& f4) { + auto f2_x2 = *reinterpret_cast,2>*>(&f4); + f2_x2[1] = cutlass::negate>{}(f2_x2[1]); + return *reinterpret_cast*>(&f2_x2); + }); + } + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x4, tBrBCompute_x4, + cutlass::NumericArrayConverter::convert); + // Store as B_conj (for producing C_re) + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,0,comp_mtx_index,transform2mma_producer_index)); + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x4, tBrB_temp_x4, + cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x4, tBrB_temp_x4, tBrB_x4, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x4, tBrB_x4, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + + // Convert B_conj to B_swap + 
cute::transform(tBrBCompute_x4, tBrBCompute_x4, [&] (auto& h4) { + // Reinterpret as packed types + auto h2_x2_conj = *reinterpret_cast,2>*>(&h4); + cutlass::negate> neg; + Array,2> h2_x2_swap{ neg(h2_x2_conj[1]), h2_x2_conj[0] }; + return *reinterpret_cast*>(&h2_x2_swap); + }); + // Store as B_swap (for producing C_im) + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,1,comp_mtx_index,transform2mma_producer_index)); + } + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + // ( re | im | re | im ) -> ( re_x2 | im_x2 ) + cute::transform(tArA_x4, tArA_x4, [&] (auto& f4) -> Array{return {f4[0], f4[2], f4[1], f4[3]};}); + if constexpr (cute::is_same_v) { + cute::transform(tArA_x4, tArA_x4, [&] (auto& f4) { + auto f2_x2 = *reinterpret_cast,2>*>(&f4); + f2_x2[1] = cutlass::negate>{}(f2_x2[1]); + return *reinterpret_cast*>(&f2_x2); + }); + } + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x4, tArACompute_x4, + cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x4, tArA_temp_x4, + cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x4, tArA_temp_x4, tArA_x4, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x4, tArA_x4, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + 
cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. 
Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. + Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + // If the operand comes from SMEM, create SMEM copy. + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _4,_1>>{}, + take<0,3>(tensor_compute.layout())); + + // Source copy is based on the source operand of copy. 
+ auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + auto 
const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, --remaining_accum_promotions) { + // Accum stages are organized as (C_real | C_imag | C_real | C_imag | ...) + CUTLASS_PRAGMA_UNROLL + for (int re_im = 0; re_im < 2; ++re_im) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1) && (re_im == 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,re_im,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,re_im,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = 
tCrB(_,_,_,re_im,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. + + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), 
tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + Tensor tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { 
+ Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + } (); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. 
+ // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{},_), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2_x2 = recast,2>>(tTR_rGlobAcc);// (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block cute::remove_cvref_t {return {cutlass::plus>{}(f2_x2[0], r2), f2_x2[1]};}); + } + else { + cute::transform(tTR_rGlobAcc_float2_x2, tTR_rAcc_float2, tTR_rGlobAcc_float2_x2, + [&] (auto& f2_x2, auto& i2) -> cute::remove_cvref_t {return {f2_x2[0], cutlass::plus>{}(f2_x2[1], i2)};}); + } + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))) && (re_im == 1); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + } + + // Interleave back (real_x2 | imag_x2) to (real | imag | real | imag) + cute::transform(tTR_rGlobAcc_float2_x2, tTR_rGlobAcc_float2_x2, [&] (auto& f2_x2) -> cute::remove_cvref_t { + Array c0{f2_x2[0][0], f2_x2[1][0]}; + Array c1{f2_x2[0][1], f2_x2[1][1]}; + return {c0, c1}; + }); 
+ + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + +protected: + + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp new file mode 100644 index 0000000..a1f2501 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp @@ -0,0 +1,880 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +namespace detail { +template +struct Sm100CollectiveMmaComplexLayoutAtomType { + using InputLayoutAtom = InputLayoutAtom_; + using ComputeLayoutAtom = ComputeLayoutAtom_; +}; + +template +struct Sm100CollectiveMmaComplexCopyType { + using InputCopyAtom = InputCopyAtom_; + using ComputeCopyAtom = ComputeCopyAtom_; +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for complex kernels +template < + int ComputationPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int TransformationPipelineStageCount_, + class AccumulatorCopyAtom_, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + 
MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32< + ComputationPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + TransformationPipelineStageCount_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + complex, + StrideA_, + complex, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + using TileShape = TileShape_; + using TiledMma = TiledMma_; + + // ElementA and ElementB are cutlass::complex, which are used as GMEM input and output data type. + using ElementA = complex; + using StrideA = StrideA_; + using ElementB = complex; + using StrideB = StrideB_; + +private: + // ElementAMma and ElementBMma are cutlass::complex, which are used as SMEM and RF data type. + // ElementAMmaRaw and ElementBMmaRaw are cutlass::tfloat32_t, which is the real internal data type set in TMA descriptor and used in TCGEN05 calculation. 
+ using ElementAMma = typename TiledMma::ValTypeA; // complex + using ElementAMmaRaw = typename ElementAMma::value_type; // tfloat32_t + using ElementBMma = typename TiledMma::ValTypeB; // complex + using ElementBMmaRaw = typename ElementBMma::value_type; // tfloat32_t + +public: + // For a complex kernel, the MMA output type is real valued, but ElementAccumulator is a complex type for the GETT reference kernel + using ElementAccumulator = cutlass::complex; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32< + ComputationPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + TransformationPipelineStageCount_, + ClusterShape, + AccumulatorCopyAtom_>; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ArchTag = typename DispatchPolicy::ArchTag; + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::ComputationPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + 
DispatchPolicy::TransformationPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = 152; + static constexpr uint32_t TransformRegisterRequirement = 200; + static constexpr uint32_t AccumRegisterRequirement = 152; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 2; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + 
static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(CtaShapeB_NK{}, Int{}))); + + static_assert(DispatchPolicy::ComputationPipelineStageCount >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(DispatchPolicy::TransformationPipelineStageCount >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must have A operand from TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename 
Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } input; + + union TensorStorageTransformed { + alignas(1024) cute::ArrayEngine smem_ACompute; // smem_ACompute is actually in tmem + alignas(1024) cute::ArrayEngine> smem_BCompute; + } compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof(ElementAMma))) + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof(ElementBMma))); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = 
tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + 
prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + return append( + partition_shape_C(TiledMma{}, take<0,2>(TileShape{})), + Int<2>{}); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,TMEM_PIPE,2) + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE cute::tuple + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = 
pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return 
cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class SrcCopyA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcCopyB, class DstCopyB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + cutlass::arch::NamedBarrier transform_barrier(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAtACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsBCompute : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + auto [unused_tAgA, src_copy_A, dst_copy_A, tAsA, tAtACompute, + unused_tBgB, src_copy_B, dst_copy_B, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_conj = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_swap = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; 
--k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input A matrix from SMEM + copy(src_copy_A, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + // Copy the input B matrix from SMEM + copy(src_copy_B, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + + // First MMA, A.real * B.real - A.imag * B.imag + // Compose [real, -imag] copy for A TMEM + // Reflect the conjugation of B through A + if constexpr (cute::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tArA); i++) { + tArA_conj(i) = {tArA(i).real(), -tArA(i).imag()}; + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tArA); i++) { + tArA_conj(i) = tArA(i); + } + } + // Write to TMEM + copy(dst_copy_A, tArA_conj, tAtACompute(_,_,_,_,0,transform2mma_producer_index)); + + // Second MMA, A.imag * B.real + A.real * B.imag + // Compose [imag, real] copy for A TMEM + // Reflect the conjugation of B through A + auto transform_element = [] (ElementAMma const& tArA_i) -> ElementAMma { + if constexpr (cute::is_same_v && cute::is_same_v) { // CC + return {-tArA_i.imag(), -tArA_i.real()}; + } + else if constexpr (cute::is_same_v && not cute::is_same_v) { // CN/CT + return {-tArA_i.imag(), tArA_i.real()}; + } + else if constexpr (not cute::is_same_v && cute::is_same_v) { // NC/TC + return {tArA_i.imag(), -tArA_i.real()}; + } + else { // TN/NT/NN/TT + return {tArA_i.imag(), tArA_i.real()}; + } + }; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 
size(tArA); i++) { + tArA_swap(i) = transform_element(tArA(i)); + } + + // Write to TMEM + copy(dst_copy_A, tArA_swap, tAtACompute(_,_,_,_,1,transform2mma_producer_index)); + + // Write the B matrix to SMEM without any changes + copy(dst_copy_B, tBrB, tBsBCompute(_,_,_,_,transform2mma_producer_index)); + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_barrier.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = 
make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] (auto tensor_input, auto input_copy_atom, + auto tensor_compute, auto make_fragment, auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. + Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + // Source copy is based on the source operand of TMEM_STORE copy. 
+ auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom, ElementAMma>{}, reg2tmem_tiled_copy); + return cute::make_tuple(smem2reg_tiled_copy, reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + // If the operand comes from SMEM, create SMEM copy. + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0).layout()); + + // Source copy is based on the source operand of copy. + auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, reg2smem_tiled_copy); + auto thr_smem2reg_tiled_copy = smem2reg_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(smem2reg_tiled_copy, reg2smem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAtACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [src_copy_B, dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, src_copy_A, dst_copy_A, tAsA, tAtACompute, + gB_nkl, src_copy_B, dst_copy_B, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + 
Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + constexpr int RealAccumIndex = 0; + constexpr int ImagAccumIndex = 1; + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC_real = accumulators(_,_,_,RealAccumIndex,mma2accum_pipeline_producer_state_index); + auto tCtC_imag = accumulators(_,_,_,ImagAccumIndex,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < StagesPerTile; ++k_block) { + + auto tCrA_conj = tCrA(_,_,_,Int<0>{},transform2mma_pipeline_consumer_state_index); + auto tCrA_swap = 
tCrA(_,_,_,Int<1>{},transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,transform2mma_pipeline_consumer_state_index); + + // A conjugate * B + cute::gemm(tiled_mma, tCrA_conj(_,_,k_block), tCrB0(_,_,k_block), tCtC_real); // A[0]*B[0] + // A swapped * B + cute::gemm(tiled_mma, tCrA_swap(_,_,k_block), tCrB0(_,_,k_block), tCtC_imag); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + Tensor tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + } (); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom, EpilogueTile) { + return accumulators; + } + +protected: + + template + CUTLASS_DEVICE + 
constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000..97fbd33 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -0,0 +1,1294 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" +#include "cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for Mixed Input Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape, + class TileShape_, + class ElementAOptionalTuple_, + class StridePairA_, + class ElementBOptionalTuple_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>, + TileShape_, + ElementAOptionalTuple_, 
+ StridePairA_, + ElementBOptionalTuple_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ +public: + // + // Type Aliases + // + + using ConversionMode = cutlass::detail::ConversionMode; + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using KernelSchedule = typename DispatchPolicy::Schedule; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + using ElementAOptionalTuple = ElementAOptionalTuple_; + using ElementBOptionalTuple = ElementBOptionalTuple_; + +private: + + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + using ElementScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple_>; + using ElementScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ElementZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ElementZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. 
Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutScale = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || + (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = GmemTiledCopyA_; + + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = 
SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using SmemCopyAtomScale = Copy_Atom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemLayoutAtomACompute = cute::conditional_t; + using InternalSmemLayoutAtomBCompute = cute::conditional_t; + + using InternalInputCopyAtomA = cute::conditional_t; + using InternalInputCopyAtomB = cute::conditional_t; + using InternalComputeCopyAtomA = cute::conditional_t; + using InternalComputeCopyAtomB = cute::conditional_t; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using ArchTag = typename DispatchPolicy::ArchTag; + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, + "Compute type A should be cutlass::bfloat16_t or cutlass::half_t or cutlass::float_e4m3_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + 
DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + static constexpr int ScaleGranularityMN = size<0,0>(LayoutScale{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutScale{}); + using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig< + ScaleGranularityMN, + ScaleGranularityK>; + + using ScaleTileShape = cute::conditional_t(TileShape{}), size<2>(TileShape{}))), + decltype(make_shape(size<1>(TileShape{}), size<2>(TileShape{})))>; + + using SmemLayoutAtomScaleFull = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); + + // Getting the SmemSizeMN and SmemSizeK from the mixed_dtype blockwise utils. + using SmemLayoutAtomScale = decltype(slice(make_coord(make_coord(_,0),make_coord(_,0)), SmemLayoutAtomScaleFull{})); + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Thread Counts + static 
constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; //Maintains compatibility with input_transform kernel + + // Get the Algorithm parameters + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutScale = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomScale{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = 
cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Load2MmaPipelineStorage = typename Load2MmaPipeline::SharedStorage; + alignas(16) Load2MmaPipelineStorage load2mma_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct<128, _0> { + + struct TensorStorageUntransformed { + alignas(512) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + }; + + struct TensorStorageTransformedAinSmem { + // We require alignas(1024) here because the smem_ACompute may not be aligned to 1024 by default. + // We need 1024B alignment of smem_ACompute because we are using Swizzle<3,4,3> here. + // The Swizzle<3,4,3> aligns with 1024B. If we don't align the data, the compiler cannot deduce + // the base pointer of the data. + // This alignment allows us to perform the function swizzle(layout(i) * base_ptr). 
+ alignas(1024) cute::ArrayEngine> smem_ACompute; + }; + + union TensorStorageTransformedAinTmem { + cute::ArrayEngine smem_ACompute; // No smem_ACompute + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes_A = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + Utils::compute_tma_transaction_bytes_extra_transform(); + static constexpr uint32_t TmaTransactionBytes_B = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytes_A + TmaTransactionBytes_B; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementScale const* ptr_S{nullptr}; + LayoutScale layout_S{}; + ElementZero const* ptr_Z{nullptr}; + }; + + struct TMAScaleParams { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_Scale = decltype(make_tma_atom_A_sm100( + GmemTiledCopyScale{}, + make_tensor(static_cast(nullptr), LayoutScale{}), + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_Scale tma_load_scale; + TMA_Scale tma_load_zero; + + }; + + struct EmptyScaleParams {}; + + // Device side kernel params + struct Params : public cute::conditional_t { + + 
using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + + uint32_t tma_transaction_bytes{TmaTransactionBytes}; + SwappedStrideA dA{}; + SwappedStrideB dB{}; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + uint32_t tma_transaction_bytes = TmaTransactionBytes; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return { + {}, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr (ModeHasScales) { + ElementScale const* ptr_S = args.ptr_S; + + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), args.layout_S); + typename Params::TMA_Scale tma_load_scale = make_tma_atom_A_sm100( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + typename Params::TMAScaleParams scale_params{tma_load_scale, {}}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), args.layout_S); + typename Params::TMA_Scale tma_load_zero = make_tma_atom_A_sm100( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMAScaleParams scale_params{tma_load_scale, tma_load_zero}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); 
+ } + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_S = cutlass::detail::get_input_alignment_bits(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + bool check_aligned_A = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + bool check_aligned_B = cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_S = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_Z = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not 
handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!check_mode_args) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); + } + if (!check_aligned_A) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + 
cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_A( + Params const& params, + Load2TransformPipeline load2xform_pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // 
LOCK mainloop_load2xform_pipeline_state for _writing_ + load2xform_pipeline.producer_acquire(load2xform_pipeline_state, load2xform_pipeline_flag); + + int tile_A_write_stage = load2xform_pipeline_state.index(); + + BarrierType* load2xform_tma_barrier = load2xform_pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop load2transform pipeline + ++load2xform_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // TMA load for A k_tile + copy(observed_tma_load_a_->with(*load2xform_tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,tile_A_write_stage)); + + if constexpr (ModeHasScales) { + auto tSgS_mkl = get<0>(extra_input_partitions); + auto tSgS = tSgS_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tSsS = get<1>(extra_input_partitions); + copy(params.tma_load_scale.with(*load2xform_tma_barrier, mcast_mask_a), tSgS(_,*k_tile_iter), tSsS(_,tile_A_write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ_mkl = get<2>(extra_input_partitions); + auto tZgZ = tZgZ_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tZsZ = get<3>(extra_input_partitions); + copy(params.tma_load_zero.with(*load2xform_tma_barrier, mcast_mask_a), tZgZ(_,*k_tile_iter), tZsZ(_,tile_A_write_stage)); + } + } + else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + } + + + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class 
STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_B( + Params const& params, + Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // LOCK mainloop_load2mma_pipeline_state for _writing_ + load2mma_pipeline.producer_acquire(load2mma_pipeline_state, load2mma_pipeline_flag); + + int tile_B_write_stage = load2mma_pipeline_state.index(); + + BarrierType* load2mma_tma_barrier = load2mma_pipeline.producer_get_barrier(load2mma_pipeline_state); + + // Advance mainloop load2mma pipeline + ++load2mma_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + // TMA load for B k_tile + copy(observed_tma_load_b_->with(*load2mma_tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,tile_B_write_stage)); + + ++k_tile_iter; + } + + return cute::make_tuple(load2mma_pipeline_state, k_tile_iter); + + } + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + if constexpr 
(KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple()); + } + else if constexpr (ModeHasScales) { + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor mS_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gS_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + + Tensor tCgS_mkl = cta_mma.partition_A(gS_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + // Project the cta_layout for tma_scale along the n-modes + auto [tSgS_mkl, tSsS] = tma_partition(params.tma_load_scale, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sS), group_modes<0,3>(tCgS_mkl)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gZ_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + Tensor tCgZ_mkl = cta_mma.partition_A(gZ_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + // Project the cta_layout for tma_scale along the n-modes + auto [tZgZ_mkl, tZsZ] = tma_partition(params.tma_load_zero, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sZ), group_modes<0,3>(tCgZ_mkl)); + return cute::make_tuple( + gA_mkl, gB_nkl, // 
for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS, tZgZ_mkl, tZsZ)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class... Ts + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple> input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAsACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM or TMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); //(Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest (Register) + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + constexpr int K_BLOCK_MAX = size<3>(tArA); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for 
( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); // read stage + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); //write stage + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + // Copy scale/zero vector from SMEM + Utils::copy_scale_zeros_for_transform(partitioned_extra_info, load2transform_consumer_index); + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Dequantize A with scale/zero in RF + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; k_block ++){ + Utils::dequantize_A_kblock_for_transform(tArA, tArACompute, partitioned_extra_info, k_block); + } + + // Dequantized A is stored into either Smem or Tmem + copy(dst_copy_A, tArACompute, tAsACompute(_,_,_,_,transform2mma_producer_index)); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers 
+ load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto r2t_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0)); + auto thr_r2t_tiled_copy = r2t_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2t_tiled_copy.partition_S(tensor_input2x); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + auto partitioned_tensor_compute = thr_r2t_tiled_copy.partition_D(fragment_compute); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + + // Source copy is based on the source operand of TMEM_STORE copy. 
+ auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + return cute::make_tuple(smem2reg_tiled_copy, r2t_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto r2s_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0).layout()); + + auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, r2s_tiled_copy); + auto thr_r2s_tiled_copy = r2s_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2s_tiled_copy.partition_S(tensor_input); //(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + auto partitioned_tensor_compute = thr_r2s_tiled_copy.partition_D(tensor_compute_ind_sw);//(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + + return cute::make_tuple(smem2reg_tiled_copy, AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + // Partition of thread -> shared and thread -> RF + auto fragment_compute = TiledMma::make_fragment_A(sS); + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto r2t_tiled_copy = make_tmem_copy(ComputeCopyAtomA{}, fragment_compute(_,_,_,0)); + auto src_copy_scale = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + + auto partitioned_extra_info = Utils::partition_extra_transform_info(TiledMma{}, src_copy_scale, shared_storage); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Load2MmaPipeline 
load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + auto next_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + auto load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + ++next_load2mma_pipeline_consumer_state; + + + // tCrA : (MMA), MMA_M, MMA_K, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + // Clear the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2mma_pipeline.consumer_wait(curr_load2mma_pipeline_consumer_state, 
load2mma_flag); + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int load2mma_pipeline_consumer_state_index = curr_load2mma_pipeline_consumer_state.index(); //read_stage + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); //read_stage + + auto tCrA0 = tCrA(_,_,_,transform2mma_pipeline_consumer_state_index); + auto tCrB0 = tCrB(_,_,_,load2mma_pipeline_consumer_state_index); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block ++) { + cute::gemm(tiled_mma, tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + load2mma_pipeline.consumer_release(curr_load2mma_pipeline_consumer_state); + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_load2mma_pipeline_consumer_state = next_load2mma_pipeline_consumer_state; + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + + ++next_load2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_load2mma_pipeline_consumer_state, curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return 
tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor tCrB = tiled_mma.make_fragment_B(sB); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + return accumulators; + } + +private: + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp 
b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp new file mode 100644 index 0000000..3d9c11e --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp @@ -0,0 +1,829 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +namespace detail { +template +struct Sm100CollectiveMmaPlanarComplexTiledMmaType { + using TiledMmaAPosAtom = TiledMmaAPos_; + using TiledMmaANegAtom = TiledMmaANeg_; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, + class TileShape_, // Static cluster shape or dynamic (int, int, _1) + class ElementA_, // (MmaAtomShapeM, 
MmaAtomShapeN, TileK) + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMmaPair_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedPlanarComplex< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMmaPair_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using TiledMmaPair = TiledMmaPair_; + using TiledMma = typename TiledMmaPair::TiledMmaAPosAtom; + using TiledMmaANeg = typename TiledMmaPair::TiledMmaANegAtom; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedPlanarComplex< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; 
+ using StrideB = StrideB_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M, K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N, K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A_real; + cute::ArrayEngine> smem_A_imag; + cute::ArrayEngine> smem_B_real; + cute::ArrayEngine> smem_B_imag; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = 2 * ( + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * (cosize(take<0,3>(SmemLayoutA{})) * static_cast(cute::sizeof_bits::value))) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * (cosize(take<0,3>(SmemLayoutB{})) * static_cast(cute::sizeof_bits::value)))); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A_real{nullptr}; + StrideA dA_real{}; + ElementA const* ptr_A_imag{nullptr}; + StrideA dA_imag{}; + ElementB const* ptr_B_real{nullptr}; + StrideB dB_real{}; + ElementB const* ptr_B_imag{nullptr}; + StrideB dB_imag{}; + }; + + template< + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB + > + struct LoadParams { + // For scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_real_mkl; + GTensorPartitionedA tAgA_imag_mkl; + GTensorPartitionedB tBgB_real_nkl; + GTensorPartitionedB tBgB_imag_nkl; + STensorA tAsA_real; + STensorA tAsA_imag; + STensorB tBsB_real; + STensorB tBsB_imag; + // for input tensor values + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_real_mkl_, GTensorPartitionedA tAgA_imag_mkl_, + GTensorPartitionedB tBgB_real_nkl_, GTensorPartitionedB tBgB_imag_nkl_, + STensorA tAsA_real_, STensorA tAsA_imag_, + STensorB tBsB_real_, STensorB tBsB_imag_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_) + : k_tiles(k_tiles_) + , tAgA_real_mkl(tAgA_real_mkl_), tAgA_imag_mkl(tAgA_imag_mkl_) + , tBgB_real_nkl(tBgB_real_nkl_), tBgB_imag_nkl(tBgB_imag_nkl_) + , tAsA_real(tAsA_real_), tAsA_imag(tAsA_imag_) + , 
tBsB_real(tBsB_real_), tBsB_imag(tBsB_imag_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} + }; + + template + struct MmaParams { + TiledMma tiled_mma_a_pos; + TiledMmaANeg tiled_mma_a_neg; + FragmentA tCrA_real; + FragmentA tCrA_imag; + FragmentB tCrB_real; + FragmentB tCrB_imag; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_a_pos_, TiledMmaANeg tiled_mma_a_neg_, + FragmentA tCrA_real_, FragmentA tCrA_imag_, + FragmentB tCrB_real_, FragmentB tCrB_imag_) + : tiled_mma_a_pos(tiled_mma_a_pos_), tiled_mma_a_neg(tiled_mma_a_neg_) + , tCrA_real(tCrA_real_), tCrA_imag(tCrA_imag_) + , tCrB_real(tCrB_real_), tCrB_imag(tCrB_imag_) {} + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a_real; + TMA_A tma_load_a_imag; + TMA_B tma_load_b_real; + TMA_B tma_load_b_imag; + TMA_A tma_load_a_real_fallback; + TMA_A tma_load_a_imag_fallback; + TMA_B tma_load_b_real_fallback; + TMA_B tma_load_b_imag_fallback; + dim3 cluster_shape_fallback; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == 
params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_real_ = is_fallback_cluster ? ¶ms.tma_load_a_real_fallback : ¶ms.tma_load_a_real; + observed_tma_load_a_imag_ = is_fallback_cluster ? ¶ms.tma_load_a_imag_fallback : ¶ms.tma_load_a_imag; + observed_tma_load_b_real_ = is_fallback_cluster ? ¶ms.tma_load_b_real_fallback : ¶ms.tma_load_b_real; + observed_tma_load_b_imag_ = is_fallback_cluster ? ¶ms.tma_load_b_imag_fallback : ¶ms.tma_load_b_imag; + } + else { + observed_tma_load_a_real_ = ¶ms.tma_load_a_real; + observed_tma_load_a_imag_ = ¶ms.tma_load_a_imag; + observed_tma_load_b_real_ = ¶ms.tma_load_b_real; + observed_tma_load_b_imag_ = ¶ms.tma_load_b_imag; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A_real = recast_ptr(args.ptr_A_real); + auto ptr_A_imag = recast_ptr(args.ptr_A_imag); + + auto ptr_B_real = recast_ptr(args.ptr_B_real); + auto ptr_B_imag = recast_ptr(args.ptr_B_imag); + + Tensor tensor_a_real = make_tensor(ptr_A_real, make_layout(make_shape(M,K,L), args.dA_real)); + Tensor tensor_a_imag = make_tensor(ptr_A_imag, make_layout(make_shape(M,K,L), args.dA_imag)); + + Tensor tensor_b_real = make_tensor(ptr_B_real, make_layout(make_shape(N,K,L), args.dB_real)); + Tensor tensor_b_imag = make_tensor(ptr_B_imag, make_layout(make_shape(N,K,L), args.dB_imag)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename 
TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = conditional_return(make_shape(hw_info.cluster_shape_fallback.x, hw_info.cluster_shape_fallback.y, Int<1>{}), ClusterShape{}); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a_real = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_real, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_imag = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_imag, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b_real = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_real, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b_imag = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_imag, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_real_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_real, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_A tma_load_a_imag_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a_imag, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_real_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_real, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_imag_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b_imag, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + 
cluster_layout_vmnk_fallback); + + return { + tma_load_a_real, + tma_load_a_imag, + tma_load_b_real, + tma_load_b_imag, + tma_load_a_real_fallback, + tma_load_a_imag_fallback, + tma_load_b_real_fallback, + tma_load_b_imag_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = 128 / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = 128 / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_real_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_a_imag_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_real_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_imag_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = append( + partition_shape_C(TiledMma{}, take<0,2>(TileShape{})), + Int<2>{}); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,TMEM_PIPE,2) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return 
cute::make_tuple(tmem_storage.accumulators(_,_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_(real/imag)_mkl - The tiled tma tensor for input A_(real/imag) + /// gB_(real/imag)_nkl - The tiled tma tensor for input B_(real/imag) + /// tAsA_(real/imag) - partitioned smem tensor for A_(real/imag) + /// tBsB_(real/imag) - partitioned smem tensor for B_(real/imag) + /// mcast_mask_a - tma multicast mask for A_(real/imag) + /// mcast_mask_b - tma multicast mask for B_(real/imag) + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_real_mkl = observed_tma_load_a_real_->get_tma_tensor(make_shape(M,K,L)); + Tensor mA_imag_mkl = observed_tma_load_a_imag_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_real_nkl = observed_tma_load_b_real_->get_tma_tensor(make_shape(N,K,L)); + Tensor mB_imag_nkl = observed_tma_load_b_imag_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_real_mkl 
= local_tile(mA_real_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gA_imag_mkl = local_tile(mA_imag_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_N, BLK_K, m, k, l) + + Tensor gB_real_nkl = local_tile(mB_real_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + Tensor gB_imag_nkl = local_tile(mB_imag_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_real_mkl = cta_mma.partition_A(gA_real_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgA_imag_mkl = cta_mma.partition_A(gA_imag_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor tCgB_real_nkl = cta_mma.partition_B(gB_real_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + Tensor tCgB_imag_nkl = cta_mma.partition_B(gB_imag_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA_real = make_tensor(make_smem_ptr(shared_tensors.smem_A_real.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sA_imag = make_tensor(make_smem_ptr(shared_tensors.smem_A_imag.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + Tensor sB_real = make_tensor(make_smem_ptr(shared_tensors.smem_B_real.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + Tensor sB_imag = make_tensor(make_smem_ptr(shared_tensors.smem_B_imag.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_real_mkl, tAsA_real] = tma_partition(*observed_tma_load_a_real_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA_real), 
group_modes<0,3>(tCgA_real_mkl)); + auto [tAgA_imag_mkl, tAsA_imag] = tma_partition(*observed_tma_load_a_imag_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA_imag), group_modes<0,3>(tCgA_imag_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_real_nkl, tBsB_real] = tma_partition(*observed_tma_load_b_real_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB_real), group_modes<0,3>(tCgB_real_nkl)); + auto [tBgB_imag_nkl, tBsB_imag] = tma_partition(*observed_tma_load_b_imag_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB_imag), group_modes<0,3>(tCgB_imag_nkl)); + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + LoadParams load_params { + shape<3>(gA_real_mkl), // for scheduler + tAgA_real_mkl, tAgA_imag_mkl, tBgB_real_nkl, tBgB_imag_nkl, // for input tensor values + tAsA_real, tAsA_imag, tBsB_real, tBsB_imag, // for input tensor values + mcast_mask_a, mcast_mask_b + }; + return load_params; + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + Tensor sA_real = make_tensor(make_smem_ptr(shared_tensors.smem_A_real.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA_imag = make_tensor(make_smem_ptr(shared_tensors.smem_A_imag.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + Tensor sB_real = make_tensor(make_smem_ptr(shared_tensors.smem_B_real.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB_imag = make_tensor(make_smem_ptr(shared_tensors.smem_B_imag.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA_real = TiledMma::make_fragment_A(sA_real); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_imag = TiledMma::make_fragment_A(sA_imag); // (MMA,MMA_M,MMA_K,PIPE) + + Tensor tCrB_real = TiledMma::make_fragment_B(sB_real); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB_imag = TiledMma::make_fragment_B(sB_imag); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA_real)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB_real)); // PIPE + + TiledMma tiled_mma_a_pos; + TiledMmaANeg tiled_mma_a_neg; + MmaParams mma_params { + tiled_mma_a_pos, tiled_mma_a_neg, + tCrA_real, tCrA_imag, + tCrB_real, tCrB_imag + }; + + return mma_params; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_real_mkl, tAgA_imag_mkl, tBgB_real_nkl, tBgB_imag_nkl, + tAsA_real, tAsA_imag, tBsB_real, tBsB_imag, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA_real 
= tAgA_real_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgA_imag = tAgA_imag_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + Tensor tBgB_real = tBgB_real_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgB_imag = tBgB_imag_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_real_->with(*tma_barrier, mcast_mask_a), tAgA_real(_,*k_tile_iter), tAsA_real(_,write_stage)); + copy(observed_tma_load_a_imag_->with(*tma_barrier, mcast_mask_a), tAgA_imag(_,*k_tile_iter), tAsA_imag(_,write_stage)); + + copy(observed_tma_load_b_real_->with(*tma_barrier, mcast_mask_b), tBgB_real(_,*k_tile_iter), tBsB_real(_,write_stage)); + copy(observed_tma_load_b_imag_->with(*tma_barrier, mcast_mask_b), tBgB_imag(_,*k_tile_iter), tBsB_imag(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + /* This helps avoid early 
exit of ctas in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 4 && size<3>(FrgLayout{}) == _2{}, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N, _2)"); + + auto [tiled_mma_a_pos, tiled_mma_a_neg, tCrA_real, tCrA_imag, tCrB_real, tCrB_imag] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma_a_pos.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_a_neg.accumulate_ = UMMA::ScaleOut::Zero; + + auto accumulators = get<0>(accumulators_pair); + auto accumulators_real = accumulators(_,_,_,0); + auto accumulators_imag = accumulators(_,_,_,1); + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + 
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA_real); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + + // Calculate real acc, 1st step + // realAcc += realA * realB + cute::gemm(tiled_mma_a_pos, tCrA_real(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_real); + + // Calculate imag acc, 1st step + if constexpr (cute::is_same_v) { + // imagAcc += realA * (-imagB) + cute::gemm(tiled_mma_a_neg, tCrA_real(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_imag); + } else { + // imagAcc += realA * imagB + cute::gemm(tiled_mma_a_pos, tCrA_real(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_imag); + } + + tiled_mma_a_pos.accumulate_ = UMMA::ScaleOut::One; + tiled_mma_a_neg.accumulate_ = UMMA::ScaleOut::One; + + // Calculate real acc, 2nd step + if constexpr (cute::is_same_v) { + // realAcc -= imagA * imagB + cute::gemm(tiled_mma_a_neg, tCrA_imag(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_real); + } else { + // realAcc += imagA * imagB + cute::gemm(tiled_mma_a_pos, tCrA_imag(_,_,k_block,read_stage), tCrB_imag(_,_,k_block,read_stage), accumulators_real); + } + + // Calculate imag acc, 2nd step + if constexpr (cute::is_same_v) { + // imagAcc += (-imagA) * realB + cute::gemm(tiled_mma_a_neg, tCrA_imag(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_imag); + } 
else { + // imagAcc += imagA * realB + cute::gemm(tiled_mma_a_pos, tCrA_imag(_,_,k_block,read_stage), tCrB_real(_,_,k_block,read_stage), accumulators_imag); + } + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_real_ = nullptr; + typename Params::TMA_A const* observed_tma_load_a_imag_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_real_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_imag_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp index d2d8172..185ab05 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp new file mode 100644 index 0000000..2030b9a --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp @@ -0,0 +1,1685 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static 
cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. 
+ using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 
4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = 
remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : 
cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), 
make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + ElementSF const** 
ptr_SFA; + ElementSF const** ptr_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shapes, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. 
+ stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = recast(make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a))); + Tensor tensor_b = recast(make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + + typename 
Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + #if 0 + print("tma_load_a:\n"); + print(tma_load_a); + print("tma_load_a.tma_desc:\n"); print(tma_load_a.tma_desc_); print("\n"); + + print("tma_load_b:\n"); + print(tma_load_b); + print("tma_load_b.tma_desc:\n"); print(tma_load_b.tma_desc_); print("\n"); + + 
print("layout_SFA: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa:\n"); + print(tma_load_sfa); + print("tma_load_sfa.tma_desc:\n"); print(tma_load_sfa.tma_desc_); print("\n"); + + print("layout_SFB: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb:\n"); + print(tma_load_sfb); + print("tma_load_sfb.tma_desc:\n"); print(tma_load_sfb.tma_desc_); print("\n"); + + print("layout_sfa: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa_fallback:\n"); + print(tma_load_sfa_fallback); + print("tma_load_sfa_fallback.tma_desc:\n"); print(tma_load_sfa_fallback.tma_desc_); print("\n"); + + print("layout_sfb: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb_fallback:\n"); + print(tma_load_sfb_fallback); + print("tma_load_sfb_fallback.tma_desc:\n"); print(tma_load_sfb_fallback.tma_desc_); print("\n"); + #endif + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + reinterpret_cast(args.ptr_SFB) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + 
+ template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + 
slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE auto + get_mkl_shape_tensor ( + ProblemShape_MNKL const& problem_shape_MNKL) const { + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); + return gA_mkl; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), 
Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = 
tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init_ab(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[init_group]; + layout_SFB = params.layout_SFB[init_group]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), 
tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + auto input_tensormaps = tensormaps_init_sf(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale 
factor tensor values + mcast_mask_sfa, mcast_mask_sfb, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 
0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, 
+ TmaPrefetchFn&& tma_prefetch_fn) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + auto input_tensormaps = get<8>(load_inputs); + + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps)); + tensormaps_fence_acquire(get<1>(input_tensormaps)); + } + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != 
cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + auto tma_copy_traits_a = observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a); + auto tma_copy_traits_b = observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b); + + if (cute::elect_one_sync()) { + copy(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapSFA, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + 
load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + auto input_tensormaps_sf = get<6>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps_sf)); + tensormaps_fence_acquire(get<1>(input_tensormaps_sf)); + } + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. 
+ CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + auto tma_copy_traits_sfa = observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa); + auto tma_copy_traits_sfb = observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + 
prefetch(tma_copy_traits_sfa, tAgSFA_compact_prefetch); + prefetch(tma_copy_traits_sfb, tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = 
get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 
6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. + pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. 
+ if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. 
+ tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. 
+ tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init_ab( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + 
mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + ElementA const* ptr_A = nullptr; + Tensor tensor_a = recast(make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group])); + + ElementB const* ptr_B = nullptr; + Tensor tensor_b = recast(make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple 
const& input_ab_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address_ab(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_ab(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_ab(shared_tensormaps, input_ab_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_ab ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_ab_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_B); + + } + + // SF tensormap ops + CUTLASS_DEVICE auto + tensormaps_init_sf( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped 
GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa, + prob_shape_SFA, prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps_sf, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next 
batch + tensormaps_replace_global_address_sf(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_sf(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_sf(shared_tensormaps, input_tensormaps_sf); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_sf ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps_sf) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* input_tma_desc) { + cute::tma_descriptor_fence_acquire(input_tma_desc); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp new file mode 100644 index 0000000..920d7e7 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp @@ -0,0 +1,1276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static 
cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. 
+ using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 
4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = 
remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = 
typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), 
StrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? 
¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + + } + } + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = recast(make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA))); + Tensor tensor_b = recast(make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + typename Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename 
TiledMma::AtomThrID{}) + ); + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 
1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa_fallback.get_tma_descriptor()); + 
cute::prefetch_tma_descriptor(params.tma_load_sfb_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, 
+ _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), 
make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(params.layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); 
// (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, 
cta_coord_sfb_vmnk); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_sfa, mcast_mask_sfb // multicast masks + ); + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 
0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, 
+ TmaPrefetchFn&& tma_prefetch_fn + ) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per 
k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_a_, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(*observed_tma_load_b_, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = 
get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. 
+ CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + #if 0 + if(threadIdx.x == 256 && blockIdx.x == 1 && blockIdx.y == 0) { + print("tAgSFA_compact: "); print(tAgSFA_compact); print("\n"); + print("tBgSFB_compact: "); print(tBgSFB_compact); print("\n"); + } + #endif + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_sfa_, tAgSFA_compact_prefetch); + prefetch(*observed_tma_load_sfb_, 
tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = 
get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. 
+ pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. + if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. 
+ tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. 
+ tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp index 6d0f5a1..458ee1a 100755 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -259,8 +259,8 @@ struct CollectiveMma< struct TensorStorage : cute::aligned_struct<128, _0> { alignas(1024) cute::ArrayEngine> smem_A; alignas(1024) cute::ArrayEngine> smem_B; - cute::ArrayEngine> smem_SFA; - cute::ArrayEngine> smem_SFB; + alignas(16) cute::ArrayEngine> smem_SFA; + alignas(16) cute::ArrayEngine> smem_SFB; } tensors; struct TensorMapStorage : cute::aligned_struct<128, _0> { diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp index 84d1ab1..9cb8051 100755 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -256,8 +256,8 @@ struct CollectiveMma< struct TensorStorage : cute::aligned_struct<128, _0> { alignas(1024) cute::ArrayEngine> smem_A; alignas(1024) cute::ArrayEngine> smem_B; - cute::ArrayEngine> smem_SFA; - cute::ArrayEngine> smem_SFB; + alignas(16) cute::ArrayEngine> smem_SFA; + alignas(16) cute::ArrayEngine> smem_SFB; } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; alignas(16) PipelineStorage pipeline_storage; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp index 6442eb3..f65629c 100755 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -101,7 +101,8 @@ struct CollectiveMma< using StridePairB = StridePairB_; using SmemCopyAtomsA = SmemCopyAtomsA_; using SmemCopyAtomsB = SmemCopyAtomsB_; - + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; using TiledMma = TiledMma_; using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; using DispatchPolicy = MainloopSm120TmaWarpSpecializedSparseBlockScaled; @@ -153,13 +154,13 @@ struct CollectiveMma< // Asymmetric buffering // Tensor A/B could have different buffering, with TILEK, and STAGEs. 
// It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's - // pipeline keep same steps when procude / consume data. + // pipeline keep same steps when produce / consume data. // Currently, AsymmetricKRatio = {1, 2} is the only support. static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1; // Construct TileShape for SFB load from GMEM to SMEM. // It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig. - // So that TileShape for scaling factor needs to be defined as a mutliple of Blk_MN. + // So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN. using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; using TileShapeSF = decltype(make_shape(ceil_div(size<0>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, @@ -295,9 +296,9 @@ struct CollectiveMma< struct TensorStorage : cute::aligned_struct<128> { alignas(1024) cute::ArrayEngine> smem_A; alignas(1024) cute::ArrayEngine> smem_B; - cute::ArrayEngine> smem_SFA; - cute::ArrayEngine> smem_SFB; - cute::ArrayEngine{}> smem_E; + alignas(16) cute::ArrayEngine> smem_SFA; + alignas(16) cute::ArrayEngine> smem_SFB; + alignas(16) cute::ArrayEngine{}> smem_E; } tensors; using PipelineStorageMK = typename MainloopPipelineMK::SharedStorage; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp index 3fc3d58..a7f4812 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp index 65f8333..951b179 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp index 2f77d66..bc22419 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp index 3316308..30587e4 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -136,7 +136,7 @@ struct CollectiveMma< // Asymmetric buffering // Tensor A/B could have different buffering, with TILEK, and STAGEs. // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's - // pipeline keep same steps when procude / consume data. + // pipeline keep same steps when produce / consume data. static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 
2 : 1; using TileShapeB = decltype(make_shape(size<0>(TileShape{}), @@ -253,7 +253,7 @@ struct CollectiveMma< struct TensorStorage : cute::aligned_struct<128, _0> { alignas(1024) cute::ArrayEngine> smem_A; alignas(1024) cute::ArrayEngine> smem_B; - cute::ArrayEngine{}> smem_E; + alignas(16) cute::ArrayEngine{}> smem_E; } tensors; using PipelineStorageMK = typename MainloopPipelineMK::SharedStorage; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp index a0c8f2a..2888417 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -214,14 +214,16 @@ struct CollectiveMma< // Copy Atom retiling // - auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); - Tensor tCsA = thr_copy_A.partition_S(sA); - Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M - auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); - Tensor tCsB = thr_copy_B.partition_S(sB); - Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N // @@ -239,8 +241,8 @@ struct CollectiveMma< __syncthreads(); // Load A, B smem->rmem for k=0 - copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); - copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); + copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0)); // // Mainloop // @@ -266,8 +268,8 @@ struct CollectiveMma< // Load A, B smem->rmem for k+1 int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static - copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); - copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); if (k_block == 0) { // Copy gmem to rmem @@ 
-515,14 +517,16 @@ struct CollectiveMma< // Copy Atom retiling // - auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); - Tensor tCsA = thr_copy_A.partition_S(sA); - Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M - auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); - Tensor tCsB = thr_copy_B.partition_S(sB); - Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N // @@ -536,8 +540,8 @@ struct CollectiveMma< __syncthreads(); // Load A, B smem->rmem for k=0 - copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); - copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); + copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0)); // // Mainloop // @@ -563,8 +567,8 @@ struct CollectiveMma< // Load A, B smem->rmem for k+1 int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static - copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); - copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); if (k_block == 0) { if (k_tile_count <= 0) { diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp 
b/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp new file mode 100644 index 0000000..a6a668a --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_ +> +struct CollectiveMma< + MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>; + using TileShape = TileShape_; + // Follow the change in TestSmall: TileShape switch to CtaShape + // In legacy arch, it should be same + using CtaShape_MNK = TileShape; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + 
using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a collective-scoped matrix multiply-accumulate + 
template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, // (BLK_M, BLK_K, K_TILES) + TensorB gB, // (BLK_N, BLK_K, K_TILES) + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA); + gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB 
gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < 
residue_k (gA shifted) + copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == 
size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. 
+ for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Set all predicates to false if we are going to overshoot bounds + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + ++k_tile_iter; + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 
0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + + cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index 9775840..be488dd 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -100,7 +100,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; // Follow the change in TestSmall: TileShape switch to CtaShape - // For sm80 arch, CtaShape should euqal to TileShape + // For sm80 arch, CtaShape should equal to TileShape using CtaShape_MNK = TileShape; static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index 653db90..1ca5e7c 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -89,15 +89,10 @@ struct CollectiveMma< TransformB_> { public: - enum class ConversionMode { - DirectConvert, - ConvertAndScale, - ConvertAndScaleWithZero - }; - // // Type Aliases // + using ConversionMode = cutlass::detail::ConversionMode; using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; using TileShape = TileShape_; using KernelSchedule = KernelSchedule_; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 6786cec..83a281c 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp index 916c6db..6070810 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -531,8 +531,8 @@ struct CollectiveMma< int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + auto accm_temp = cute::make_fragment_like(accum); + GmmaFP8Accumulation accumulation(accm_temp, mainloop_params.mma_promotion_interval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) @@ -556,7 +556,7 @@ struct CollectiveMma< } warpgroup_commit_batch(); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); ++smem_pipe_read; } @@ -597,7 +597,7 @@ struct CollectiveMma< warpgroup_wait(); warpgroup_fence_operand(accumulation()); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -606,7 +606,7 @@ struct CollectiveMma< ++smem_pipe_release; } - accumulation.promote_residue_if_needed(); + accumulation.promote_residue_if_needed(accum); warpgroup_fence_operand(accumulation()); } diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 67c8268..20c4095 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -73,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling, + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, TileShape_, ElementA_, StridePairA_, @@ -92,7 +92,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling; + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -153,7 +153,15 @@ struct CollectiveMma< static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); - using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig; + static constexpr bool MMajorSFA = size<0,1>(InternalLayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(InternalLayoutSFB{}.stride()) == 1; + + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? 
cute::GMMA::Major::MN : cute::GMMA::Major::K>; using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); @@ -204,6 +212,9 @@ struct CollectiveMma< static_assert(cute::is_same_v, "ElementAccumulator and ElementBlockScale should be same datatype"); + // For TileShapeM < 128, NumSplitsM should be 1 + using NumSplitsM = cute::conditional_t(TileShape_{}) < _128{}, _1, cute::C(TileShape_{}) / 128>>; + static_assert(NumSplitsM{} == 1 || NumSplitsM{} == 2); struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { @@ -374,8 +385,6 @@ struct CollectiveMma< auto [M,N,K,L] = problem_shape_MNKL; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); - // We expect full tiles in K - implementable = implementable && K % size<2>(TileShape{}) == 0; } } @@ -681,35 +690,37 @@ struct CollectiveMma< template< + class AccumSlice, class EngineAccum, class LayoutAccum, class ScaleFactor > CUTLASS_DEVICE - void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + void scale_if_needed(AccumSlice & accum, GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { if constexpr (ScalePromotionInterval != 4) { - accumulation.scale_if_needed(scaleFactor); + accumulation.scale_if_needed(accum, scaleFactor); } else { // avoid unnecessary tests when granularity is the finnest - accumulation.scale(scaleFactor); + accumulation.scale(accum, scaleFactor); } } template< + class AccumSlice, class EngineAccum, class LayoutAccum, class ScaleFactor1, class ScaleFactor2 > CUTLASS_DEVICE - void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + void scale_if_needed(AccumSlice & accum, GmmaFP8Accumulation& accumulation, ScaleFactor1 
scaleFactor1, ScaleFactor2 scaleFactor2) { if constexpr (ScalePromotionInterval != 4) { - accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + accumulation.scale_if_needed(accum, scaleFactor1, scaleFactor2); } else { // avoid unnecessary tests when granularity is the finnest - accumulation.scale(scaleFactor1, scaleFactor2); + accumulation.scale(accum, scaleFactor1, scaleFactor2); } } @@ -815,78 +826,26 @@ struct CollectiveMma< // Prologue GMMAs tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Tile accum - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // fence_operand(); - GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); - - warpgroup_fence_operand(accumulation()); - - { - - int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers - copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA); - copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB); - - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - warpgroup_fence_operand(accumulation()); + using NumSplitsM_Scale = cute::conditional_t; + static constexpr int ScaleMsPerWave = ScaleMsPerTile == 1 ? 
1 : ScaleMsPerTile / NumSplitsM{}; + auto accum_tiled = tiled_divide(accum, cute::tuple<_1, NumSplitsM>{}); + auto tCrA_tiled = tiled_divide(tCrA, cute::tuple<_1, NumSplitsM>{}); + auto tCsSFA_tiled = tiled_divide(tCsSFA, cute::tuple<_1, NumSplitsM_Scale>{}); + auto tCrSFA_tiled = tiled_divide(tCrSFA, cute::tuple<_1, NumSplitsM_Scale>{}); + auto tCrSFB_tiled = tiled_divide(tCrSFB, cute::tuple<_1, NumSplitsM_Scale>{}); + // Temporary accumulator used by MMA + // On promotion, accumulated values are scaled and copied into `accum` + auto accum_temp = cute::make_fragment_like(accum_tiled(_0{}, _, _, _)); - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; - } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; - } - } - - warpgroup_wait<0>(); - ++smem_pipe_read; - barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } - } + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + // Secondary accumulator for 
FP32 accum + GmmaFP8Accumulation accumulation(accum_temp, ScalePromotionInterval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); - // Mainloop GMMAs - k_tile_count--; CUTLASS_PRAGMA_NO_UNROLL for ( ; k_tile_count > 1; --k_tile_count) @@ -900,76 +859,89 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) - copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); - copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); - if constexpr (ScalePromotionInterval != 4) { - if (accumulation.prepare_if_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int m_split = 0; m_split < NumSplitsM{}; ++m_split) { + auto tCrA_local = tCrA_tiled(m_split, _, _, _, _); + auto tCrSFA_local = tCrSFA_tiled(m_split, _, _, _); + auto tCrSFB_local = tCrSFB_tiled(m_split, _, _, _); + auto accum_local = accum_tiled(m_split, _, _, _); + copy(tCsSFA_tiled(m_split, _, _, _, make_coord(_0{}, read_stage)), tCrSFA_local); + bool is_last = (m_split == NumSplitsM{} - 1); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - } - else { - // Always zero out the accumulator for finest granularity - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + 
warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_local(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_fence_operand(accumulation()); + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + tCrSFA_local(_0{}) = tCrSFA_local(_0{}) * tCrSFB(_0{}); } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA_local)); i++) { + filter_zeros(tCrSFA_local)(i) = filter_zeros(tCrSFA_local)(i) * scale_b; + } + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA_local(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB_local)); i++) { + filter_zeros(tCrSFB_local)(i) = filter_zeros(tCrSFB_local)(i) * scale_a; + } } - } - - warpgroup_wait<0>(); - 
pipeline.consumer_release(smem_pipe_release); // Unlock previous tile - ++smem_pipe_read; - barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } + warpgroup_wait<0>(); + if (is_last) { + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + // Block scale the accumulators with reg tensor `tCrSFA_local` and `tCrSFB` + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + scale_if_needed(accum_local, accumulation, scale_ab); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local); + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFB_local); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local, tCrSFB_local); + } - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_release; + if (is_last) { + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_release; + } + } // end for (m_split) } - if (k_tile_count) { + if (k_tile_count > 0) { pipeline.consumer_wait(smem_pipe_read, barrier_token); // @@ -978,95 +950,101 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers (at most 
twice per block along M and/or N) - copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); - - if constexpr (ScalePromotionInterval != 4) { - if (accumulation.prepare_if_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int m_split = 0; m_split < NumSplitsM{}; ++m_split) { + auto tCrA_local = tCrA_tiled(m_split, _, _, _, _); + auto tCrSFA_local = tCrSFA_tiled(m_split, _, _, _); + auto tCrSFB_local = tCrSFB_tiled(m_split, _, _, _); + auto accum_local = accum_tiled(m_split, _, _, _); + copy(tCsSFA_tiled(m_split, _, _, _, make_coord(_0{}, read_stage)), tCrSFA_local); + bool is_last = (m_split == NumSplitsM{} - 1); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - } - else { - // Always zero out the accumulator for finest granularity - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_local(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is 
consumed - warpgroup_fence_operand(accumulation()); + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + tCrSFA_local(_0{}) = tCrSFA_local(_0{}) * tCrSFB(_0{}); } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA_local)); i++) { + filter_zeros(tCrSFA_local)(i) = filter_zeros(tCrSFA_local)(i) * scale_b; + } + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA_local(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB_local)); i++) { + filter_zeros(tCrSFB_local)(i) = filter_zeros(tCrSFB_local)(i) * scale_a; + } + } + warpgroup_wait<0>(); + if (is_last) { + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + } + // Block scale the accumulators with reg tensor `tCrSFA_local` and `tCrSFB` + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + scale_if_needed(accum_local, accumulation, scale_ab); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local); + } + if 
constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFB_local); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local, tCrSFB_local); + } + if constexpr (ScalePromotionInterval != 4) { + // residues only exists when granularity is not the finnest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + accumulation.scale_residue_if_needed(accum_local, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFA_local); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFB_local); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFA_local, tCrSFB_local); + } } - } - warpgroup_wait<0>(); - pipeline.consumer_release(smem_pipe_release); // Unlock previous tile - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } + warpgroup_fence_operand(accumulation()); + } // end for (m_split) } - if constexpr (ScalePromotionInterval != 4) { - // residues only exists when granularity is not the finnest - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - accumulation.scale_residue_if_needed(scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && 
ScaleNsPerTile == 1) { - accumulation.scale_residue_if_needed(tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrSFA, tCrSFB); - } - } - - warpgroup_fence_operand(accumulation()); - } /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // The pipeline is not released in the first iteration - smem_pipe_release.advance(k_tile_count - 1); - pipeline.consumer_release(smem_pipe_release); } // diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index 4289bc8..9bf1c7a 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index fbbe971..6367bca 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index f8e0543..f611f29 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 4e43529..9b9f3ad 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -96,15 +96,11 @@ struct CollectiveMma< TransformB_> { public: - enum class ConversionMode { - DirectConvert, - ConvertAndScale, - ConvertAndScaleWithZero - }; // // Type Aliases // + using ConversionMode = cutlass::detail::ConversionMode; using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; using TileShape = TileShape_; using KernelSchedule = KernelSchedule_; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 228c258..b9c98b6 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 0e64bad..e42caed 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index c7ea65a..7671b07 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -481,8 +481,8 @@ struct CollectiveMma< int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + auto accm_temp = cute::make_fragment_like(accum); + GmmaFP8Accumulation accumulation(accm_temp, mainloop_params.mma_promotion_interval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) @@ -506,7 +506,7 @@ struct CollectiveMma< } warpgroup_commit_batch(); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); ++smem_pipe_read; } @@ -547,7 +547,7 @@ struct CollectiveMma< warpgroup_wait(); warpgroup_fence_operand(accumulation()); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -556,7 +556,7 @@ struct CollectiveMma< ++smem_pipe_release; } - accumulation.promote_residue_if_needed(); + accumulation.promote_residue_if_needed(accum); warpgroup_fence_operand(accumulation()); } diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index ecbd59b..0d4bd9b 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,7 +33,9 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" #include "cutlass/trace.h" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/numeric_types.h" #include "cute/arch/cluster_sm90.hpp" @@ -73,7 +75,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8, TileShape_, ElementA_, StridePairA_, @@ -91,7 +93,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -134,9 +136,12 @@ struct CollectiveMma< static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static constexpr bool MMajorSFA = size<0,1>(LayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(LayoutSFB{}.stride()) == 1; + static constexpr int ScaleTmaThreshold = 32; - static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold; - static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold; + static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold && MMajorSFA; + static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold && NMajorSFB; // Two threads per CTA are producers (1 for operand tile `tma`, and 32 for 
scales `cp.async`) static constexpr int NumProducerThreadEvents = ((IsTmaLoadSFA && IsTmaLoadSFB)? 1 : 33); @@ -151,7 +156,12 @@ struct CollectiveMma< static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); - using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig; + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? cute::GMMA::Major::MN : cute::GMMA::Major::K>; using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); @@ -170,8 +180,8 @@ struct CollectiveMma< using CopyAtomSFA = Copy_Atom, ElementBlockScale>; using CopyAtomSFB = Copy_Atom, ElementBlockScale>; - static constexpr int AlignmentSFA = 1; - static constexpr int AlignmentSFB = 1; + static constexpr int AlignmentSFA = IsTmaLoadSFA ? 128 / cutlass::sizeof_bits::value : 1; + static constexpr int AlignmentSFB = IsTmaLoadSFB ? 
128 / cutlass::sizeof_bits::value : 1; // Block scaling smem layout using SmemLayoutSFA = decltype(make_layout( @@ -195,6 +205,9 @@ struct CollectiveMma< static_assert(cute::is_same_v, "ElementAccumulator and ElementBlockScale should be same datatype"); + using NumSplitsM = cute::C(TileShape_{}) / 128>; + static_assert(NumSplitsM{} == 1 || NumSplitsM{} == 2); + struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { @@ -383,12 +396,6 @@ struct CollectiveMma< implementable = false; CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n"); } - - // We expect full tiles in K - if (K % size<2>(TileShape{}) != 0) { - implementable = false; - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n"); - } return implementable; } @@ -669,7 +676,7 @@ struct CollectiveMma< Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tSFApSFA); ++i) { - tSFApSFA(i) = load_sfa && elem_less(get<0>(tSFAcSFA_compact(i)), get<0>(SFA_shape)); + tSFApSFA(i) = load_sfa && elem_less(tSFAcSFA_compact(i), SFA_shape); } bool load_sfb = thread_idx < ScaleNsPerTile; @@ -677,7 +684,7 @@ struct CollectiveMma< Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tSFBpSFB); ++i) { - tSFBpSFB(i) = load_sfb && elem_less(get<0>(tSFBcSFB_compact(i)), get<0>(SFB_shape)); + tSFBpSFB(i) = load_sfb && elem_less(tSFBcSFB_compact(i), SFB_shape); } int write_stage = smem_pipe_write.index(); // Copy scale tensors from global memory to shared memory @@ -718,34 +725,36 @@ struct CollectiveMma< } template< + class AccumSlice, class EngineAccum, class LayoutAccum, class ScaleFactor > CUTLASS_DEVICE - void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + void scale_if_needed(AccumSlice & accum, GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { if constexpr 
(ScalePromotionInterval != 4) { - accumulation.scale_if_needed(scaleFactor); + accumulation.scale_if_needed(accum, scaleFactor); } else { // avoid unnecessary tests when granularity is the finnest - accumulation.scale(scaleFactor); + accumulation.scale(accum, scaleFactor); } } template< + class AccumSlice, class EngineAccum, class LayoutAccum, class ScaleFactor1, class ScaleFactor2 > CUTLASS_DEVICE - void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + void scale_if_needed(AccumSlice & accum, GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { if constexpr (ScalePromotionInterval != 4) { - accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + accumulation.scale_if_needed(accum, scaleFactor1, scaleFactor2); } else { // avoid unnecessary tests when granularity is the finnest - accumulation.scale(scaleFactor1, scaleFactor2); + accumulation.scale(accum, scaleFactor1, scaleFactor2); } } @@ -850,70 +859,27 @@ struct CollectiveMma< // Prologue GMMAs tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - + // Tile accum + + using NumSplitsM_Scale = cute::conditional_t; + static constexpr int ScaleMsPerWave = ScaleMsPerTile == 1 ? 
1 : ScaleMsPerTile / NumSplitsM{}; + + auto accum_tiled = tiled_divide(accum, cute::tuple<_1, NumSplitsM>{}); + auto tCrA_tiled = tiled_divide(tCrA, cute::tuple<_1, NumSplitsM>{}); + auto tCsSFA_tiled = tiled_divide(tCsSFA, cute::tuple<_1, NumSplitsM_Scale>{}); + auto tCrSFA_tiled = tiled_divide(tCrSFA, cute::tuple<_1, NumSplitsM_Scale>{}); + auto tCrSFB_tiled = tiled_divide(tCrSFB, cute::tuple<_1, NumSplitsM_Scale>{}); + // Temporary accumulator used by MMA + // On promotion, accumulated values are scaled and copied into `accum` + auto accum_temp = cute::make_fragment_like(accum_tiled(_0{}, _, _, _)); // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); - GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - { - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers - copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); - copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); - - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - warpgroup_fence_operand(accumulation()); - - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * 
scale_b; - } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; - } - } - warpgroup_wait<0>(); - ++smem_pipe_read; - barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } - } + // Secondary accumulator for FP32 accum + GmmaFP8Accumulation accumulation(accum_temp, ScalePromotionInterval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); // Mainloop GMMAs - k_tile_count -= 1; CUTLASS_PRAGMA_NO_UNROLL for ( ; k_tile_count > 1; --k_tile_count) @@ -927,71 +893,86 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) - copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); - if constexpr (ScalePromotionInterval != 4) { - if (accumulation.prepare_if_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int m_split = 0; m_split < NumSplitsM{}; ++m_split) { + auto tCrA_local = tCrA_tiled(m_split, _, _, _, _); + auto tCrSFA_local = tCrSFA_tiled(m_split, _, _, _); + auto tCrSFB_local = tCrSFB_tiled(m_split, _, _, _); + auto accum_local = accum_tiled(m_split, _, _, _); + copy(tCsSFA_tiled(m_split, _, _, _, make_coord(_0{}, read_stage)), tCrSFA_local); + bool 
is_last = (m_split == NumSplitsM{} - 1); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - } - else { - // Always zero out the accumulator for finest granularity - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_local(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_fence_operand(accumulation()); + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + 
tCrSFA_local(_0{}) = tCrSFA_local(_0{}) * tCrSFB(_0{}); } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA_local)); i++) { + filter_zeros(tCrSFA_local)(i) = filter_zeros(tCrSFA_local)(i) * scale_b; + } + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA_local(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB_local)); i++) { + filter_zeros(tCrSFB_local)(i) = filter_zeros(tCrSFB_local)(i) * scale_a; + } + } + + warpgroup_wait<0>(); + if (is_last) { + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + // Block scale the accumulators with reg tensor `tCrSFA_local` and `tCrSFB` + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + scale_if_needed(accum_local, accumulation, scale_ab); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local); + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFB_local); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local, tCrSFB_local); } - } - warpgroup_wait<0>(); - pipeline.consumer_release(smem_pipe_release); // Unlock previous tile - ++smem_pipe_read; - barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if constexpr (ScaleMsPerTile == 1 && 
ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_release; + if (is_last) { + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_release; + } + } // end for (m_split) } if (k_tile_count) { pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -1003,93 +984,101 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) - copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); - - if constexpr (ScalePromotionInterval != 4) { - if (accumulation.prepare_if_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int m_split = 0; m_split < NumSplitsM{}; ++m_split) { + auto tCrA_local = tCrA_tiled(m_split, _, _, _, _); + auto tCrSFA_local = tCrSFA_tiled(m_split, _, _, _); + auto tCrSFB_local = tCrSFB_tiled(m_split, _, _, _); + auto accum_local = accum_tiled(m_split, _, _, _); + copy(tCsSFA_tiled(m_split, _, _, _, make_coord(_0{}, read_stage)), tCrSFA_local); + bool is_last = (m_split == NumSplitsM{} - 1); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - } - else { - // Always zero out the accumulator for finest granularity - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // 
Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_local(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_fence_operand(accumulation()); + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_b = tCrSFB(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { - filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + tCrSFA_local(_0{}) = tCrSFA_local(_0{}) * tCrSFB(_0{}); } - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - ElementBlockScale scale_a = tCrSFA(_0{}); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { - filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA_local)); i++) { + 
filter_zeros(tCrSFA_local)(i) = filter_zeros(tCrSFA_local)(i) * scale_b; + } + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA_local(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB_local)); i++) { + filter_zeros(tCrSFB_local)(i) = filter_zeros(tCrSFB_local)(i) * scale_a; + } + } + warpgroup_wait<0>(); + if (is_last) { + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + } + // Block scale the accumulators with reg tensor `tCrSFA_local` and `tCrSFB` + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + scale_if_needed(accum_local, accumulation, scale_ab); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local); + } + if constexpr (ScaleMsPerWave == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFB_local); + } + if constexpr (ScaleMsPerWave > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accum_local, accumulation, tCrSFA_local, tCrSFB_local); + } + if constexpr (ScalePromotionInterval != 4) { + // residues only exists when granularity is not the finnest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA_local(_0{}); + accumulation.scale_residue_if_needed(accum_local, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFA_local); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFB_local); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(accum_local, tCrSFA_local, tCrSFB_local); + } } - } - warpgroup_wait<0>(); - pipeline.consumer_release(smem_pipe_release); // Unlock previous tile - // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` - if 
constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - scale_if_needed(accumulation, scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - scale_if_needed(accumulation, tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - scale_if_needed(accumulation, tCrSFA, tCrSFB); - } - } - if constexpr (ScalePromotionInterval != 4) { - // residues only exists when granularity is not the finnest - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrSFA(_0{}); - accumulation.scale_residue_if_needed(scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - accumulation.scale_residue_if_needed(tCrSFA); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrSFB); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrSFA, tCrSFB); - } - } - warpgroup_fence_operand(accumulation()); + warpgroup_fence_operand(accumulation()); + } // end for (m_split) + } } /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // The pipeline is not released in the first iteration - smem_pipe_release.advance(k_tile_count - 1); - pipeline.consumer_release(smem_pipe_release); } }; diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp index 220e996..be0a400 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp index d993d9a..d0fdb5b 100644 --- a/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/3rd/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -586,8 +586,8 @@ struct CollectiveMma< int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + auto accm_temp = cute::make_fragment_like(accum); + GmmaFP8Accumulation accumulation(accm_temp, mainloop_params.mma_promotion_interval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) @@ -614,7 +614,7 @@ struct CollectiveMma< warpgroup_commit_batch(); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); ++smem_pipe_read; } @@ -652,7 +652,7 @@ struct CollectiveMma< warpgroup_wait(); warpgroup_fence_operand(accumulation()); - accumulation.promote_if_needed(); + accumulation.promote_if_needed(accum); // UNLOCK smem_pipe_release, done _computing_ on it pipeline.consumer_release(smem_pipe_release); @@ -662,7 +662,7 @@ struct CollectiveMma< ++smem_pipe_release; } - accumulation.promote_residue_if_needed(); + accumulation.promote_residue_if_needed(accum); warpgroup_fence_operand(accumulation()); } diff --git a/3rd/cutlass/include/cutlass/gemm/device/base_grouped.h b/3rd/cutlass/include/cutlass/gemm/device/base_grouped.h index d9c2423..2562bd0 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/base_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/device/base_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h b/3rd/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h index 75edf2f..15799db 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/3rd/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/ell_gemm.h b/3rd/cutlass/include/cutlass/gemm/device/ell_gemm.h index 4261496..b6cabaf 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/ell_gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/device/ell_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -99,7 +99,7 @@ namespace device { Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: a_rows - Rows in the sparse matrix. - a_cols - Colums in the sparse matrix. + a_cols - Columns in the sparse matrix. 
BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in consecutive blocks, whose size is (a_rows * a_ell_num_columns) ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is @@ -715,7 +715,7 @@ class EllGemm gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, {B, ldb}, // + TensorRef ref_B, {C, ldc}, // TensorRef ref_C, {D, ldd}, // + TensorRef ref_D, {alpha, beta} // + EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm89, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Element Type for for the scalesl + typename ElementScale_ = float, + /// Layout for the scales. + typename LayoutScale_ = cutlass::layout::RowMajor, + /// Scale Block Size. 
+ int ScaleBlockSize_ = 128, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class GemmBlockwise { +public: + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using ElementScale = 
ElementScale_; + using LayoutScale = LayoutScale_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static int const kScaleBlockSize = ScaleBlockSize_; + + + static_assert(kScaleBlockSize == 128, "Scale block size has to be 128 for now."); + // Ensure the threadblock K-dimension is 128 + static_assert(ThreadblockShape::kK == kScaleBlockSize, + "GemmBlockwise requires ThreadblockShape::kK equale to Scale Block Size"); + + static_assert(cutlass::platform::is_same::value, + "Scales have to be row major for now."); + + static_assert(cutlass::platform::is_same::value, + "Scales have to be float."); + + // Tensor reference type for the FP8 scale tensors + using TensorRefScale = cutlass::TensorRef; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using GemmKernel = typename kernel::GemmBlockwise< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, + LayoutC, ElementAccumulator, ElementScale, LayoutScale, + OperatorClass, ArchTag, ThreadblockShape, + WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, + kStages, kSplitKSerial, Operator, SharedMemoryClearOption::kNone, GatherA, + GatherB, ScatterD, PermuteDLayout>::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // Dequantization scale tensors (row-major) + TensorRefScale scale_A; + TensorRefScale scale_B; + + // + // Methods + // + + /// 
Default ctor + CUTLASS_HOST_DEVICE + Arguments() : problem_size(0, 0, 0), split_k_slices(1) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments(GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + TensorRefScale scale_A_, + TensorRefScale scale_B_, + typename EpilogueOutputOp::Params epilogue_ = typename EpilogueOutputOp::Params(), + int split_k_slices = 1, int const *gather_A_indices_ = nullptr, + int const *gather_B_indices_ = nullptr, + int const *scatter_D_indices_ = nullptr) + : problem_size(problem_size_), ref_A(ref_A_), ref_B(ref_B_), + ref_C(ref_C_), ref_D(ref_D_), scale_A(scale_A_), + scale_B(scale_B_), epilogue(epilogue_), + split_k_slices(split_k_slices), gather_A_indices(gather_A_indices_), + gather_B_indices(gather_B_indices_), + scatter_D_indices(scatter_D_indices_) {} + }; + +private: + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + /// Constructs the GEMM. + GemmBlockwise() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + // Require the problem K dimension to be an exact multiple of the Threadblock K tile. + if (args.problem_size.k() % ThreadblockShape::kK != 0) { + return Status::kErrorInvalidProblem; + } + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + // ------------------------------------------------------------------ + // Validate scale tensor leading dimensions. + // Row-major layout implies stride(0) equals number of columns (kBlocks). + // Both scale_A (mBlocks × kBlocks) and scale_B (nBlocks × kBlocks) must + // therefore have stride(0) == kBlocks where kBlocks = ceil_div(K, 128). 
+ // ------------------------------------------------------------------ + int const kBlocks = (args.problem_size.k() + ThreadblockShape::kK - 1) / ThreadblockShape::kK; + + if (args.scale_A.stride(0) != kBlocks || args.scale_B.stride(0) != kBlocks) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), args.ref_C.non_const_ref(), args.ref_D); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.scale_A, + args.scale_B, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices}; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.scale_A.reset(args.scale_A.data()); + params_.scale_B.reset(args.scale_B.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and +/// operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Element Type for for the scalesl + typename ElementScale_, + /// Layout for the scales. + typename LayoutScale_, + /// Scale Block Size. + int ScaleBlockSize_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator_, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout> +class GemmBlockwise { +public: + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = 
TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using ElementScale = ElementScale_; + using LayoutScale = LayoutScale_; + + static int const kScaleBlockShape = ScaleBlockSize_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + // Alias for per-tile FP8 dequantization scale tensors + using TensorRefScale = cutlass::TensorRef; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kSplitKSerial = SplitKSerial; + + using UnderlyingOperator = + GemmBlockwise::type, + ElementA, typename layout::LayoutTranspose::type, + ElementC, layout::RowMajor, ElementAccumulator, + OperatorClass, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, ElementScale, LayoutScale, kScaleBlockShape, + EpilogueOutputOp, ThreadblockSwizzle, + Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, + GatherB, GatherA, ScatterD, PermuteDLayout>; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // Dequantization scale tensors (row-major) + TensorRefScale 
scale_A; + TensorRefScale scale_B; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments(GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + TensorRefScale scale_A_, + TensorRefScale scale_B_, + typename EpilogueOutputOp::Params epilogue_ = typename EpilogueOutputOp::Params(), + int split_k_slices = 1, int const *gather_A_indices_ = nullptr, + int const *gather_B_indices_ = nullptr, + int const *scatter_D_indices_ = nullptr) + : problem_size(problem_size_), ref_A(ref_A_), ref_B(ref_B_), + ref_C(ref_C_), ref_D(ref_D_), scale_A(scale_A_), + scale_B(scale_B_), epilogue(epilogue_), + split_k_slices(split_k_slices), gather_A_indices(gather_A_indices_), + gather_B_indices(gather_B_indices_), + scatter_D_indices(scatter_D_indices_) {} + }; + +private: + UnderlyingOperator underlying_operator_; + +public: + /// Constructs the GEMM. + GemmBlockwise() {} + + /// Helper to construct a transposed equivalent for the underying GEMM + /// operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + {args.scale_B.data(), args.scale_B.stride(0)}, + {args.scale_A.data(), args.scale_A.stride(0)}, + args.epilogue, + args.split_k_slices, args.gather_B_indices, args.gather_A_indices, + args.scatter_D_indices); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return UnderlyingOperator::get_workspace_size( + to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + return underlying_operator_.initialize(to_underlying_arguments(args), + workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return underlying_operator_.update(to_underlying_arguments(args), + workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Runs the kernel using initialized state. 
+ Status operator()(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_complex.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_complex.h index 3596501..418f366 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -645,7 +645,7 @@ class GemmComplex< /// Constructs the GEMM. 
GemmComplex() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_grouped.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_grouped.h index 3c1c9bc..c7d66ce 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h index bdc2e5f..8c633a9 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse.h index 57f345f..6e2d3e1 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h index 2c92030..b072f30 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h index c42c82b..e24daa6 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h index 5b86f12..daa245d 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h index c700733..1167a8d 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h index e059981..97a3d81 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -561,7 +561,7 @@ class GemmSplitKParallel +struct IsDistGemmKernel : cute::false_type { }; + +template +struct IsDistGemmKernel> + : cute::true_type { }; + } // namespace detail template @@ -393,10 +400,16 @@ class GemmUniversalAdapter< [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 103 ) { if constexpr (!cute::is_static_v) { - fallback_cluster = params.hw_info.cluster_shape_fallback; - cluster = params.hw_info.cluster_shape; + if constexpr (detail::IsDistGemmKernel::value) { + fallback_cluster = params.base.hw_info.cluster_shape_fallback; + cluster = params.base.hw_info.cluster_shape; + } else { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } } } @@ -473,6 +486,7 @@ class GemmUniversalAdapter< if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || GemmKernel::ArchTag::kMinComputeCapability == 101 || GemmKernel::ArchTag::kMinComputeCapability == 120 + || GemmKernel::ArchTag::kMinComputeCapability == 103 ) { if constexpr (is_static_1x1x1) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) @@ -693,7 +707,7 @@ class GemmUniversalAdapter< /// Constructs the GEMM. 
GemmUniversalAdapter() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { if (kInternalTranspose) { return args.transposed_problem(); diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_base.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_base.h index 6f010c1..0998735 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_base.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,14 +34,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(limits) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/device_kernel.h" @@ -167,6 +166,7 @@ class GemmUniversalBase { } } +#ifndef __QNX__ // Update SM occupancy member cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( &sm_occupancy_, @@ -178,6 +178,7 @@ class GemmUniversalBase { CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } +#endif // Update device ordinal member on success device_ordinal_ = current_ordinal; diff --git a/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h b/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h index 7de048b..8c3cc18 
100644 --- a/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -311,7 +311,7 @@ class GemmUniversalStreamkWithBroadcast +class GemvBlockScaled { +public: + + using GemvKernel = GemvKernel_; + + + using ElementA = typename GemvKernel::ElementA; + using LayoutA = typename GemvKernel::LayoutA; + using ElementB = typename GemvKernel::ElementB; + using ElementC = typename GemvKernel::ElementC; + + using ElementSFA = typename GemvKernel::ElementSFA; + using ElementSFB = typename GemvKernel::ElementSFB; + + using ElementAccumulator = typename GemvKernel::ElementAccumulator; + using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; + + static ComplexTransform const kTransformA = GemvKernel::kTransformA; + static ComplexTransform const kTransformB = GemvKernel::kTransformB; + + static int const kThreadCount = GemvKernel::kThreadCount; + static int const kThreadsPerRow = GemvKernel::kThreadsPerRow; + + using Arguments = typename GemvKernel::Arguments; + using Params = typename GemvKernel::Params; + +private: + + Params params_; + +public: + + /// Constructs the GemvBlockScaled. + GemvBlockScaled() = default; + + /// Determines whether the GemvBlockScaled can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return GemvKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return 0; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args, dim3 const &block) { + if(platform::is_same::value) { + return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536); + } + else { + return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536); + } + } + + /// Computes the block shape + static dim3 get_block_shape() { + if(platform::is_same::value) { + return dim3(kThreadCount, 1, 1); + } + else { + return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1); + } + } + + /// Initializes GemvBlockScaled state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + params_ = Params(args); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return params_.update(args); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + const dim3 block = get_block_shape(); + const dim3 grid = get_grid_shape(params_, block); + + int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + if (result == cudaSuccess) { + return Status::kSuccess; + } else { + return Status::kErrorInternal; + } + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/device/rank_2k.h b/3rd/cutlass/include/cutlass/gemm/device/rank_2k.h index 8e7f436..c010fc9 100644 --- a/3rd/cutlass/include/cutlass/gemm/device/rank_2k.h +++ b/3rd/cutlass/include/cutlass/gemm/device/rank_2k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -473,7 +473,7 @@ class Rank2K::value; // Kernel schedule policies (the base class tags, one for each kernel layer file) // struct KernelMultistage { }; +struct KernelPtrArrayMultistage { }; struct KernelCpAsyncWarpSpecialized { }; struct KernelCpAsyncWarpSpecializedPingpong { }; struct KernelCpAsyncWarpSpecializedCooperative { }; @@ -125,10 +127,15 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedPingpong { }; // FP8 related policies (including Blocked Scaled Accumulation) -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; -struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { }; -struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { }; -struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperativeFP8Blockwise: KernelTmaWarpSpecializedCooperative { }; +struct KernelTmaWarpSpecializedPingpongFP8Blockwise: KernelTmaWarpSpecializedPingpong { }; +struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise: KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise: KernelPtrArrayTmaWarpSpecializedPingpong { }; + +using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelTmaWarpSpecializedPingpongFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; // Policies to opt into mixed 
type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -198,6 +205,17 @@ struct MainloopSm80CpAsync { using ClusterShape = ClusterShape_; }; +// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads for SM100 Simt Ptr-Array +template +> +struct MainloopSm80ArrayCpAsync { + constexpr static int Stages = Stages_; + using ArchTag = cute::conditional_t<(size(ClusterShape_{}) > 1), arch::Sm90, arch::Sm80>; + using Schedule = KernelPtrArrayMultistage; + using ClusterShape = ClusterShape_; +}; + // n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule template< int Stages_, @@ -306,17 +324,17 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule -// For FP8 kernels with Block Scaling +// For FP8 kernels with Blockwise (Software) Scaling template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8 +struct MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8 : MainloopSm90TmaGmmaWarpSpecialized { static_assert( - cute::is_same_v || - cute::is_same_v, + cute::is_same_v || + cute::is_same_v, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; @@ -398,19 +416,43 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput { template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling +struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise : MainloopSm90ArrayTmaGmmaWarpSpecialized { 
static_assert( cute::is_any_of_v< KernelSchedule, - KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum, - KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum + KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise, + KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise >, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; +////////////////////////////////////////////////////////////////////////////// + +// +// Kernel Scheduler Tag +// + +// Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelMixedTmaCpAsyncWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; template< int SchedulerPipelineStageCount_, @@ -449,6 +491,24 @@ struct KernelPtrArrayTmaWarpSpecializedMmaTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedBlockScaledSm103 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm103 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int 
AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Sparse Gemm template< int SchedulerPipelineStageCount_, @@ -479,6 +539,16 @@ struct KernelTmaWarpSpecializedInputTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +// Mixed Input Transform GEMM +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedMixedInputTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms template< int SchedulerPipelineStageCount_, @@ -597,7 +667,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling< + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise< Stages, ClusterShape, KernelSchedule @@ -610,7 +680,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8< + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8< Stages, ClusterShape, KernelSchedule @@ -643,6 +713,9 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy // Dense GEMM: Specialize for 1SM vs 2SM struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder +struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA +struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; + 
/////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array Dense GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -699,6 +772,8 @@ struct KernelScheduleSm100MixedInputGemm : KernelScheduleSm100 {}; struct KernelTmaWarpSpecializedMixedInputSmemSm100 : KernelScheduleSm100MixedInputGemm { }; struct KernelTmaWarpSpecialized1SmMixedInputSm100 final : KernelSchedule1Sm, KernelScheduleSm100MixedInputGemm { }; struct KernelTmaWarpSpecialized1SmMixedInputSmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedMixedInputSmemSm100 { }; +struct KernelTmaWarpSpecialized2SmMixedInputSm100 final : KernelSchedule2Sm, KernelScheduleSm100MixedInputGemm { }; +struct KernelTmaWarpSpecialized2SmMixedInputSmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedMixedInputSmemSm100 { }; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array FastF32 (9xBF16) GEMM Dispatch Policies @@ -712,6 +787,24 @@ struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32Sm100 final : KernelSchedu struct KernelPtrArrayTmaWarpSpecialized1SmFastFP32SmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { }; struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32SmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { }; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 Interleaved Complex GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleSm100InterleavedComplexTF32Gemm : KernelScheduleSm100 {}; +// Transform GEMM: Specialize for Interleaved Complex GEMMs +struct KernelTmaWarpSpecialized1SmInterleavedComplexTF32Sm100 final : KernelSchedule1Sm, 
KernelScheduleSm100InterleavedComplexTF32Gemm { }; +struct KernelTmaWarpSpecialized2SmInterleavedComplexTF32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100InterleavedComplexTF32Gemm { }; + +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 Ptr-Array Interleaved Complex GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// Interleaved Complex GEMM + (Ptr array or Group GEMM) +struct KernelScheduleSm100PtrArrayInterleavedComplexTF32Gemm : KernelScheduleSm100InterleavedComplexTF32Gemm {}; +// Ptr-Array Transform GEMM: Specialize for 1SM vs 2SM Complex GEMM +// Transform GEMM: Specialize for Interleaved Complex GEMMs +struct KernelPtrArrayTmaWarpSpecialized1SmInterleavedComplexTF32Sm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayInterleavedComplexTF32Gemm { }; +struct KernelPtrArrayTmaWarpSpecialized2SmInterleavedComplexTF32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayInterleavedComplexTF32Gemm { }; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Sparse GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -735,6 +828,8 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1 struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; +struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {}; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 BlockScaled Ptr 
Array Dense GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -767,6 +862,54 @@ struct KernelSparseTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2 struct KernelSparseTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelScheduleSparseMxNvf4Sm100 { }; struct KernelSparseTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleSparseMxNvf4Sm100 { }; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// +// SM103 Dispatch Policies +// +/////////////////////////////////////////////////////////////////////////////////////////////////////// + +struct KernelScheduleSm103 {}; +struct KernelScheduleSm103BlockScaledGemm : KernelScheduleSm103 {}; +struct KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch : KernelScheduleSm103BlockScaledGemm {}; +struct KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch : KernelScheduleSm103BlockScaledGemm {}; + +// Blockscaled Gemm: Specialized for instruction type, scale factor vector size, and 1SM vs. 
2SM +// These are the public dispatch policy name +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; + +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; + +using KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 = KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch; +using KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 = KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch; +using KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 = KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch; +using KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 = KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch; + + +struct 
KernelSchedulePtrArraySm103BlockScaledGemm : KernelScheduleSm103 {}; +struct KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch : KernelSchedulePtrArraySm103BlockScaledGemm {}; +struct KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch : KernelSchedulePtrArraySm103BlockScaledGemm {}; + +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; + +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; + +using KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 = KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch; +using 
KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 = KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch; +using KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 = KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch; +using KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 = KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // // SM120 Dispatch Policies @@ -822,6 +965,53 @@ struct KernelSparseTmaWarpSpecializedMxf4Sm120 final : KernelScheduleS struct KernelSparseTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120 { }; struct KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120, KernelScheduleAcc2x4Sm120 { }; +////////////////////////////////////////////////////////////////////////////// + +// +// Collective Mainloop Dispatch Policies +// + +// n-buffer in smem, pipelined with Blackwell UMMA and CPASYNC, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaCpAsyncWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelWarpSpecializedSm100; +}; + +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + +template< + int Stages_, + int 
SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< @@ -961,6 +1151,78 @@ struct MainloopSm100TmaUmmaWarpSpecializedFastF32 { }; +template< + // Number of Pipeline stages for + // MainloopLoad <-> Transformation + int ComputationPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // Number of Pipeline stages for + // Transformation <-> MMA + int TransformationPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x +> +struct MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32 { + constexpr static int ComputationPipelineStageCount = ComputationPipelineStageCount_; + constexpr static int TransformationPipelineStageCount = TransformationPipelineStageCount_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::InterleavedComplexTF32; + using ClusterShape = ClusterShape_; + using AccumulatorCopyAtom = AccumulatorCopyAtom_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedInputTransformSm100; + + // For backwards compatibility with GemmUniversalAdapter. 
+ constexpr static int Stages = ComputationPipelineStageCount; +}; + +// n-buffer in smem, pipelined with Blackwell Mixed Input kernel with UMMA (HwScaled) and TMA, +template< + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Load2TransformPipelineStageCount_, + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Transform2MmaPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // ClusterShape for the kernel + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecializedMixedInput { + constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Load2MmaPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::MixedInput; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedMixedInputTransformSm100; + + // For backwards compatibility with GemmUniversalAdapter. 
+ constexpr static int Stages = Load2TransformPipelineStageCount; +}; + + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecializedPlanarComplex { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< @@ -977,6 +1239,36 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecialized { using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = false; + using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; +}; + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100; +}; + // n-buffer in 
smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -994,6 +1286,22 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled { +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100ArrayTmaUmmaWarpSpecializedPlanarComplex { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = false; + using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; +}; + + // n-buffer in smem, pipelined with Blackwell Fast FP32 kernel with UMMA (HwScaled) and TMA, // Warp specialized dynamic schedule template< @@ -1042,9 +1350,77 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { }; +template< + // Number of Pipeline stages for + // MainloopLoad <-> Transformation + int ComputationPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // Number of Pipeline stages for + // Transformation <-> MMA + int TransformationPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x +> +struct MainloopSm100ArrayTmaUmmaWarpSpecializedInterleavedComplexTF32 { + constexpr static int ComputationPipelineStageCount = ComputationPipelineStageCount_; + constexpr static int TransformationPipelineStageCount = TransformationPipelineStageCount_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::InterleavedComplexTF32; + using ClusterShape = ClusterShape_; + using AccumulatorCopyAtom = AccumulatorCopyAtom_; + using ArchTag = arch::Sm100; + using Schedule = KernelPtrArrayTmaWarpSpecializedInputTransformSm100; + + // For 
backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = ComputationPipelineStageCount; +}; + + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int LoadABPipelineStageCount_, + int LoadSFPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch +> +struct MainloopSm103TmaUmmaWarpSpecializedBlockScaled { + constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; + constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm103; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelTmaWarpSpecializedBlockScaledSm103; + // For backwards compatibility with GemmUniversalAdapter. 
+ constexpr static int Stages = LoadABPipelineStageCount; + constexpr static cutlass::sm103::detail::KernelPrefetchType PrefetchType = PrefetchType_; +}; // Mainloop schedule for array-based TMA +template< + int LoadABPipelineStageCount_, + int LoadSFPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch +> +struct MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled { + constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; + constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm103; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm103; + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = LoadABPipelineStageCount; + constexpr static cutlass::sm103::detail::KernelPrefetchType PrefetchType = PrefetchType_; +}; + template< int Stages_, int SchedulerPipelineStageCount_, diff --git a/3rd/cutlass/include/cutlass/gemm/gemm.h b/3rd/cutlass/include/cutlass/gemm/gemm.h index 5137bfa..96e9970 100644 --- a/3rd/cutlass/include/cutlass/gemm/gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/gemm_enumerated_types.h b/3rd/cutlass/include/cutlass/gemm/gemm_enumerated_types.h index 8961735..3bd8adb 100644 --- a/3rd/cutlass/include/cutlass/gemm/gemm_enumerated_types.h +++ b/3rd/cutlass/include/cutlass/gemm/gemm_enumerated_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/group_array_problem_shape.hpp b/3rd/cutlass/include/cutlass/gemm/group_array_problem_shape.hpp index fe5e4c5..f0bef44 100644 --- a/3rd/cutlass/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/3rd/cutlass/include/cutlass/gemm/group_array_problem_shape.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -44,6 +44,7 @@ #include #endif + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm { @@ -57,6 +58,7 @@ struct GroupProblemShape { UnderlyingProblemShape* problem_shapes = nullptr; UnderlyingProblemShape const* host_problem_shapes = nullptr; + CUTLASS_HOST_DEVICE int32_t groups() const { return num_groups; } @@ -79,6 +81,55 @@ struct GroupProblemShape { } }; +template +struct MoEProblemShape { + + using UnderlyingProblemShape = ProblemShape_; + static_assert(rank(UnderlyingProblemShape{}) == 3, "ProblemShape{} should be "); + + int32_t max_m = 0; + int32_t max_n = 0; + int32_t max_k = 0; + int32_t num_groups = 0; + int32_t* tokens_per_expert = nullptr; + int32_t* tokens_per_expert_host = nullptr; + + CUTLASS_HOST_DEVICE + int32_t groups() const { return num_groups; } + + CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_problem_shape(int32_t group_idx=0) const { + + UnderlyingProblemShape expert_problem_dims; + assert(tokens_per_expert != nullptr); //tokens_per_expert should not be null + if (group_idx < num_groups) { // add check on the can_implement + expert_problem_dims = {max_m, tokens_per_expert[group_idx], max_k}; + } + + return expert_problem_dims; + } + + // Function returns max problem shape if tokens_per_expert host is unavailable. + // Returns host problem shape if tokens_per_expert host is available. 
+ CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_host_problem_shape(int32_t group_idx=0) const { + UnderlyingProblemShape expert_problem_dims = {max_m, max_n, max_k}; + if (group_idx < num_groups && tokens_per_expert_host != nullptr) { + expert_problem_dims = {max_m, tokens_per_expert_host[group_idx], max_k}; + } + return expert_problem_dims; + } + + CUTLASS_HOST_DEVICE + bool + is_host_problem_shape_available() const { + return tokens_per_expert_host != nullptr; + } + +}; + template class ArrayProblemShape { public: @@ -120,4 +171,15 @@ class ArrayProblemShape { UnderlyingProblemShape problem_shape_{}; }; + +namespace detail { + +template +struct is_moe_problem_shape : cute::false_type {}; + +template +struct is_moe_problem_shape> : cute::true_type {}; + +} + } // namespace cutlass::gemm diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h index 561508c..708fc32 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm.h index da41c3e..7833db0 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h index 438769f..77ce5ae 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h index 1481465..546c56a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h index 2ace212..4c7151e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h index 7ad2f90..a744ef2 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h index d06a2a2..128f638 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h index 5c50d00..a809657 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h index 8bc5ca0..ed39767 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h index 6096552..a7dbb92 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h index 15d9d79..ae31854 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_absmax.h index 2f8a2f2..5ef396a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h index eb2167f..d387618 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h index c4aed55..cf9941c 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h index 683fc51..966cfca 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h index 29ff219..5be161f 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h index 0ec473e..5094ffa 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_absmax.h index b27a078..37d960f 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h index e53f31f..48c832a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h index 01019cf..8e84b7e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h index e24dd92..d047b26 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemv.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemv.h index a574dab..f6e31c9 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_gemv.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_gemv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h index f52e5d7..4512309 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h index 7b6e329..0abb468 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h index 7f5efe3..80b0abf 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h index a27be8d..0f93a27 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k.h index 5001b33..41a528e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h index 21ccc33..3c77805 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h index 503040a..95a110a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm.h index 435e46b..caa0e29 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h index 028296c..8a7f037 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -160,7 +160,7 @@ struct DefaultSymmComplex< Operator, SplitKSerial, BlasMode::kSymmetric> { static BlasMode const kBlasMode = BlasMode::kSymmetric; - // Complex Transform don't appply to A or B for SYMM + // Complex Transform don't apply to A or B for SYMM static ComplexTransform const TransformA = ComplexTransform::kNone; static ComplexTransform const TransformB = ComplexTransform::kNone; @@ -353,7 +353,7 @@ struct DefaultSymmComplex< Operator, SplitKSerial, BlasMode::kSymmetric> { static BlasMode const kBlasMode = BlasMode::kSymmetric; - // Complex Transform don't appply to A or B for SYMM + // Complex Transform don't apply to A or B for SYMM static ComplexTransform const TransformA = ComplexTransform::kNone; static ComplexTransform const TransformB = ComplexTransform::kNone; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h index 8915df6..8517c18 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm.h index 8e004d0..424966c 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h index d8eeee1..2a596d5 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h index fef1fcd..539d0d4 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/ell_gemm.h b/3rd/cutlass/include/cutlass/gemm/kernel/ell_gemm.h index 16010fd..452c602 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/ell_gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/ell_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm.h index 22b5f48..4bd37ac 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_array.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_array.h index 8812806..1b23ada 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_array.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_batched.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_batched.h index efd5b84..5c51b4e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_batched.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_batched.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_blockwise.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_blockwise.h new file mode 100644 index 0000000..f8ee9d0 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_blockwise.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. 
+*/ + +#pragma once + +#include "cutlass/arch/wmma.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal_blockwise.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_multistage_blockwise.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename 
ElementAccumulator, + /// Element Type for for the scalesl + typename ElementScale, + /// Layout for the scales. + typename LayoutScale, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void> +struct GemmBlockwise; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ada Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int 
kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout> +struct GemmBlockwise { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMmaBlockwise< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, float, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm89, + ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, false, + SharedMemoryClear, GatherA, GatherB, PermuteALayout, + PermuteBLayout>::ThreadblockMma; + + static_assert(ThreadblockShape::kK % WarpShape::kK == 0, "ThreadblockShape::kK must be divisible by WarpShape::kK."); + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + 
ThreadblockShape, typename Mma::Operator, kPartitionsK, + EpilogueOutputOp, EpilogueOutputOp::kCount, ScatterD, + PermuteDLayout>::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = + kernel::GemmUniversalBlockwise; +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h index 3a4098c..e870dfc 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h index 65325e5..e15889b 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h index dc37d56..b7b6014 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h index f6fc222..fd5e789 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h index c862cc0..e0f6efe 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_params.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_params.h index a3b0eb8..b9681f6 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_params.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h index 4d19982..3f97b45 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h index 0f8cd33..42daaf8 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 1685f23..b6e4e41 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal.h index 035caf7..0bdd588 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h index 6251c38..34dba69 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h index a21f081..d31e37a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h index 473819a..ac068c9 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h index 98bc227..93446c4 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.h index be1e1d8..652b730 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp index 08605e0..d0c84d3 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -54,6 +54,7 @@ struct IsCutlass3ArrayKernel +struct GemmUniversalBlockwise { + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + // Added aliases for per-tile FP8 dequantisation scale tensors + using LayoutScale = cutlass::layout::RowMajor; + using TensorRefScale = cutlass::TensorRef; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + TensorRefScale scale_A; + TensorRefScale scale_B; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(cutlass::gemm::GemmCoord const &problem_size, + cutlass::gemm::GemmCoord const &grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename 
Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + TensorRefScale scale_A, + TensorRefScale scale_B, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr) + : problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), ref_A(ref_A), params_B(ref_B.layout()), + ref_B(ref_B), params_C(ref_C.layout()), ref_C(ref_C), + params_D(ref_D.layout()), ref_D(ref_D), scale_A(scale_A), + scale_B(scale_B), output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = + (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / + grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmUniversalBlockwise() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status + can_implement(cutlass::gemm::GemmCoord const &problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + static int const kAlignmentA = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) + ? 
32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + min(params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA 
iterator_A( + params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, params.ref_B.data(), + {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B, + params.gather_B_indices); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, + params.scale_A, params.scale_B); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * + Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_decl.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_decl.h index 9465234..8a3ae30 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_decl.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_decl.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h index 96a0956..32cc012 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -406,8 +406,6 @@ struct GemmUniversalStreamk { // Zero-initialize barrier workspace if (barrier_workspace) { - size_t barrier_workspace_bytes = get_barrier_workspace_size(); - CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); cudaError_t result = cudaMemsetAsync( diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor.h index e8fdea7..03e3221 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h index 3fd9d60..7d9a19c 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_absmax.h index f1a3ec8..6856620 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index b27c167..47aeca9 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index c8b24ee..7c7db70 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemv.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemv.h index eb5da1a..4d6b42f 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemv.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -247,7 +247,7 @@ struct Gemv < Status update(Arguments const &args) { output_op = args.output_op; - ref_A = ref_A; + ref_A = args.ref_A; ptr_B = args.ptr_B; ptr_C = args.ptr_C; ptr_D = args.ptr_D; @@ -480,7 +480,7 @@ struct Gemv < problem_size = args.problem_size; batch_count = args.batch_count; output_op = args.output_op; - ref_A = ref_A; + ref_A = args.ref_A; ptr_B = args.ptr_B; ptr_C = args.ptr_C; ptr_D = args.ptr_D; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h index 3b22c11..ddc0e47 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -70,7 +70,7 @@ namespace detail using CDType = typename FragmentCD::value_type; static_assert(FragmentCD::kElements == FragmentAccumulator::kElements, - "Mistmatch in fragment sizes."); + "Mismatch in fragment sizes."); for (int i = 0; i < FragmentCD::kElements; ++i) { diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/gemv_blockscaled.h b/3rd/cutlass/include/cutlass/gemm/kernel/gemv_blockscaled.h new file mode 100644 index 0000000..9a09de9 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/gemv_blockscaled.h @@ -0,0 +1,885 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/arch/cache_operation.h" /// cutlass::arch::CacheOperation +#include "cutlass/arch/memory.h" // cutlass::arch::global_load +#include "cutlass/arch/memory_sm80.h" // cp.async helpers, ldsm, cp_async_wait +#include "cutlass/complex.h" // cutlass::ComplexTransform: +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" // cutlass::fast_max +#include "cutlass/layout/matrix.h" // cutlass::layout::RowMajor +#include "cutlass/matrix_coord.h" // cutlass::MatrixCoord +#include "cutlass/numeric_conversion.h" // cutlass::FloatRoundStyle, cutlass::NumericConverter +#include "cutlass/numeric_types.h" // cutlass::float_e4m3_t +#include "cutlass/platform/platform.h" // cutlass::is_same_v +#include "cutlass/tensor_ref.h" // cutlass::TensorRef +#include "cutlass/semaphore.h" // split-k + +#include "cute/algorithm/functional.hpp" // cute::for_each +#include "cute/numeric/arithmetic_tuple.hpp" // cute::make_int_sequence + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename ElementC_, + typename ElementAccumulator_, + typename EpilogueOutputOp_, + int kElementsPerAccess_ = 1, ///< Number of elements involved in a global access. + int kThreadCount_ = 0, ///< Number of threads in the thread block. + /// It will be calculated automatically if set to 0. + int kThreadsPerRow_ = 0, ///< Number of threads in the k dimension. + /// It will be calculated automatically if set to 0. 
+ typename ElementSFA_ = cutlass::float_e4m3_t, + typename ElementSFB_ = cutlass::float_e4m3_t, + int kSFVecSize_ = 16 +> +struct GemvBlockScaled; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GEMV for row-major A matrix +template +struct GemvBlockScaled +{ +public: + using ElementA = ElementA_; + using ElementSFA = ElementSFA_; + using LayoutA = cutlass::layout::RowMajor; + using TensorRefA = cutlass::TensorRef; + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFA should be FP8 type"); + + using ElementB = ElementB_; + using ElementSFB = ElementSFB_; + using LayoutB = cutlass::layout::ColumnMajor; + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFB should be FP8 type"); + + using ElementC = ElementC_; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementAccumulator = ElementAccumulator_; + + static constexpr cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; + static constexpr cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; + + static constexpr FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest; + + // number of return elements in a global access + static constexpr int kElementsPerAccess = kElementsPerAccess_; + static constexpr int kSFVecSize = kSFVecSize_; + static constexpr int kSFPerAccess = cutlass::const_max(1, kElementsPerAccess / kSFVecSize); + + static_assert(kSFVecSize == 16, "Only SFVecSize = 16 is supported"); + // Hardcode some check for easier debug + static_assert(kElementsPerAccess == 32, "for fp4 kernel, 32 elt per access"); + static_assert(kSFPerAccess == 2, "fpr fp4 kernel, 2 sf read per thread"); + + static constexpr bool kDequantizeA = cutlass::sizeof_bits::value 
== 4; + static constexpr bool kDequantizeB = cutlass::sizeof_bits::value == 4; + static constexpr int kPackedElementsA = cutlass::sizeof_bits::value == 4 ? 2 : 1; + static constexpr int kPackedElementsB = cutlass::sizeof_bits::value == 4 ? 2 : 1; + static constexpr int kPackedElements = cutlass::const_max(kPackedElementsA, kPackedElementsB); + + static_assert(kDequantizeA == true, "kDequantizeA should be true"); + static_assert(kDequantizeB == true, "kDequantizeB should be true"); + + using FragmentA = cutlass::Array; + using FragmentB = cutlass::Array; + using FragmentCompute = cutlass::Array; + using FragmentSFA = cutlass::Array; + using FragmentSFB = cutlass::Array; + using FragmentPackedA = cutlass::Array; + using FragmentPackedB = cutlass::Array; + + static_assert(sizeof_bits::value == 128, "FragmentA should be 128 bits"); + static_assert(sizeof_bits::value == 128, "FragmentB should be 128 bits"); + + // // thread block shape (kThreadsPerRow, kThreadCount / kThreadsPerRow, 1) + static constexpr int kThreadCount = (kThreadCount_ <= 0) ? 128 : kThreadCount_; + static constexpr int kThreadsPerRow = (kThreadsPerRow_ <= 0) ? + cutlass::const_min(static_cast(kThreadCount / cutlass::bits_to_bytes(kElementsPerAccess * cutlass::sizeof_bits::value)), 16) : + kThreadsPerRow_; + static constexpr int kThreadsPerCol = kThreadCount / kThreadsPerRow; + + static constexpr int kStageCount = 4; + static constexpr int kBufferCount = 2; + + // Number of elements stored in shared memory per stage for operands A and B. + // Each thread contributes `kElementsPerAccess / kPackedElements{A,B}` packed + // values. 
+ static constexpr int kSmemPerStageA = kThreadCount * kElementsPerAccess / kPackedElementsA; + // B is uniform across all threads in the same k-column, so only store it once per k-thread + static constexpr int kSmemPerStageB = kThreadsPerRow * kElementsPerAccess / kPackedElementsB; + + using EpilogueOutputOp = EpilogueOutputOp_; + + // Ensure epilogue and mainloop have same thread layout + static_assert(kThreadCount == EpilogueOutputOp::kThreadCount, "mainloop, epilogue thread count mismatch"); + static_assert(kThreadsPerRow == EpilogueOutputOp::kThreadsPerRow, "mainloop, epilogue thread per row mismatch"); + static_assert(kThreadsPerCol == EpilogueOutputOp::kThreadsPerCol, "mainloop, epilogue thread per col mismatch"); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + MatrixCoord problem_size; + int32_t batch_count{0}; + typename EpilogueOutputOp::Params epilogue; + + TensorRefA ref_A; + + ElementB const *ptr_B{nullptr}; + ElementC const *ptr_C{nullptr}; + ElementC *ptr_D{nullptr}; + + ElementSFA const *ptr_SFA{nullptr}; + ElementSFB const *ptr_SFB{nullptr}; + + int64_t stride_A{0}; + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int64_t batch_stride_SFA{0}; + int64_t batch_stride_SFB{0}; + int64_t batch_stride_SFD{0}; + }; + + using Params = Arguments; + + /// Shared memory storage structure + struct SharedStorage + { + using EpilogueStorage = typename EpilogueOutputOp::SharedStorage; + EpilogueStorage epilogue; + + alignas(16) ElementA smem_A[kBufferCount][kStageCount][kSmemPerStageA]; + alignas(16) ElementB smem_B[kBufferCount][kStageCount][kSmemPerStageB]; + alignas(16) ElementSFA smem_SFA[kBufferCount][kStageCount][kThreadCount * kSFPerAccess]; + alignas(16) ElementSFB smem_SFB[kBufferCount][kStageCount][kThreadsPerRow * kSFPerAccess]; + }; + +public: + // + // Methods + // + /// Determines whether kernel satisfies alignment + static Status 
can_implement(cutlass::MatrixCoord const &problem_size) + { + if (problem_size.column() % kElementsPerAccess != 0) { + return Status::kErrorMisalignedOperand; + } + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) + { + return can_implement(args.problem_size); + } + + /// Executes one GEMV + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) + { + EpilogueOutputOp epilogue(params.epilogue, shared_storage.epilogue); + + // Converters only needed for regular GEMV fallback case + NumericConverter A_converter; + NumericConverter B_converter; + NumericConverter SFA_converter; + NumericConverter SFB_converter; + + const int32_t gemm_m = params.problem_size.row(); + [[maybe_unused]] static constexpr int32_t gemm_n = 1; + const int32_t gemm_k = params.problem_size.column(); + const int32_t gemm_batch = params.batch_count; + + // Loop over batch indices + for (int batch_idx = blockIdx.z; batch_idx < gemm_batch; batch_idx += gridDim.z) { + + int idx_col_k = threadIdx.x; + int idx_row_m = blockIdx.x * blockDim.y + threadIdx.y; + + if (idx_row_m < gemm_m) { + // problem_size (row = m, column = k) + // matrix A (batch, m, k) + // vector B (batch, k, 1) + // vector C (batch, m, 1) + // vector D (batch, m, 1) + // move in the batch dimension + ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A / kPackedElementsA; + ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B / kPackedElementsB; + ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C; + ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D; + + // move in the k dimension + ptr_A += idx_col_k * kElementsPerAccess / kPackedElementsA; + ptr_B += idx_col_k * kElementsPerAccess / kPackedElementsB; + + // move in the m dimension + ptr_A += idx_row_m * params.stride_A / kPackedElementsA; + ptr_C += idx_row_m; + ptr_D += idx_row_m; + + ElementSFA const *ptr_SF_A{nullptr}; + ElementSFB 
const *ptr_SF_B{nullptr}; + int global_k{0}; + + int SF_blocks_by_M = (gemm_m + 127) >> 7; + int SF_blocks_by_K = (gemm_k / kSFVecSize + 3) >> 2; + + // move in the batch dimension + ptr_SF_A = params.ptr_SFA + batch_idx * SF_blocks_by_M * SF_blocks_by_K * 512; + ptr_SF_B = params.ptr_SFB + batch_idx * SF_blocks_by_K * 512; + + // move in the m dimension + ptr_SF_A += (((idx_row_m >> 7) * SF_blocks_by_K) << 9) + ((idx_row_m & 0x1f) << 4) + ((idx_row_m & 0x7f) >> 5 << 2); + + global_k = idx_col_k * kElementsPerAccess; + + ElementAccumulator accum = ElementAccumulator(0); + + // Local aliases + const int tileA_k_local = kThreadsPerRow * kElementsPerAccess; + const int total_tiles = gemm_k / tileA_k_local; + + int unroll_col_k = 0; // total K elements consumed so far by this thread + const int thread_id = threadIdx.y * kThreadsPerRow + threadIdx.x; + const bool is_even_thread = (threadIdx.x % 2 == 0); + const bool load_b = (threadIdx.y == 0); + const int smem_sf_write_offset = (thread_id / 2) * 4; // 4 FP8 per even thread + const int smem_sf_offset = thread_id * kSFPerAccess; + + // Fast path: if the problem fits entirely in the tail path, skip SMEM + if (total_tiles == 0) { + accum += process_tail_elements(0, idx_col_k, gemm_k, + ptr_A, ptr_B, + ptr_SF_A, ptr_SF_B, + A_converter, B_converter, + SFA_converter, SFB_converter); + } else { + + // Scaling factors are now loaded from shared memory, no register pipeline needed + + // Thread-local SMEM line offset + const int thread_linear = threadIdx.y * kThreadsPerRow + threadIdx.x; + const int smem_offset_A = thread_linear * (kElementsPerAccess / kPackedElementsA); + // Only one row of threads (threadIdx.y == 0) loads B + const int smem_offset_B = threadIdx.x * (kElementsPerAccess / kPackedElementsB); + + // PROLOGUE – prime first kStageCount-1 stages into buffer 0 + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBufferCount - 1; ++b) { + // Load all stages using the helper function + load_stages_gmem_to_smem( + b, // 
buffer_idx + kStageCount, // num_stages + unroll_col_k, // passed by reference + global_k, // passed by reference + tileA_k_local, + smem_offset_A, + smem_offset_B, + smem_sf_write_offset, + is_even_thread, + load_b, + true, // valid_tile = true for prologue + ptr_A, + ptr_B, + ptr_SF_A, + ptr_SF_B, + shared_storage); + } + cutlass::arch::cp_async_fence(); + + // Ensure first stage committed + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Register double buffering for A/B fragments and SFA/SFB like SM80 + FragmentA fragA_reg[2]; + FragmentB fragB_reg[2]; + FragmentSFA fragSFA_reg[2]; + FragmentSFB fragSFB_reg[2]; + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = kBufferCount - 1; + + // PREFETCH register pipeline - load first kblock (stage 0) into register bank 0 + if constexpr (kStageCount > 1) + { + int frag_idx = 0; + + // Load fragments using the helper function + load_smem_fragments( + fragA_reg[frag_idx], + fragB_reg[frag_idx], + fragSFA_reg[frag_idx], + fragSFB_reg[frag_idx], + smem_pipe_read, + 0, // k_block = 0 + smem_offset_A, + smem_offset_B, + smem_sf_offset, + shared_storage); + + } + + // Mainloop + int tile_idx = 0; + while (tile_idx < total_tiles) { + int smem_pipe_read_curr = smem_pipe_read; + + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == kStageCount - 1) + { + cutlass::arch::cp_async_wait(); + __syncthreads(); + + smem_pipe_read_curr = smem_pipe_read; + } + + // Load A/B/SFA/SFB smem->regs for k_block_next + auto k_block_next = (k_block + Int<1>{}) % kStageCount; + int frag_idx_next = (k_block + 1) & 1; + + // Prefetch next kblock data using saved pipe index + load_smem_fragments( + fragA_reg[frag_idx_next], + fragB_reg[frag_idx_next], + fragSFA_reg[frag_idx_next], + fragSFB_reg[frag_idx_next], + smem_pipe_read_curr, + k_block_next, + smem_offset_A, + smem_offset_B, + smem_sf_offset, + shared_storage); + // Copy gmem to 
smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Use predicate instead of branch for cp_async + bool valid_tile = (global_k < gemm_k); + + // Load all stages using the helper function + load_stages_gmem_to_smem( + smem_pipe_write, // buffer_idx + kStageCount, // num_stages + unroll_col_k, // passed by reference + global_k, // passed by reference + tileA_k_local, + smem_offset_A, + smem_offset_B, + smem_sf_write_offset, + is_even_thread, + load_b, + valid_tile, + ptr_A, + ptr_B, + ptr_SF_A, + ptr_SF_B, + shared_storage); + + cutlass::arch::cp_async_fence(); + + // Advance the pipe indices + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == kBufferCount) ? 0 : smem_pipe_read; + } + + { + int frag_idx = k_block & 1; + + // Compute using current fragments + accum += blockscaled_multiply_add( + fragA_reg[frag_idx], fragB_reg[frag_idx], + fragSFA_reg[frag_idx], + fragSFB_reg[frag_idx]); + } + }); + + tile_idx += kStageCount; + } + + // Drain outstanding async copies + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + // Tail elements that don't fill a full tile + if (unroll_col_k + idx_col_k * kPackedElementsA < gemm_k) { + accum += process_tail_elements(unroll_col_k, idx_col_k, gemm_k, + ptr_A, ptr_B, + ptr_SF_A, ptr_SF_B, + A_converter, B_converter, + SFA_converter, SFB_converter); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int mask = (kThreadsPerRow >> 1); mask > 0; mask >>= 1) { + accum += ElementAccumulator(__shfl_xor_sync(0xFFFFFFFF, static_cast(accum), mask, 32)); + } + + auto frag_acc = static_cast(accum); + auto frag_c = static_cast(*(ptr_C)); + + // Applying blockscaled epilogue + epilogue(frag_acc, frag_c, batch_idx); + } + } + } //end of operator() + +private: + // Load multiple stages from global to shared memory + CUTLASS_DEVICE + void load_stages_gmem_to_smem( + int buffer_idx, + int num_stages, + int& unroll_col_k, + int& global_k, + int tileA_k_local, + int smem_offset_A, + int 
smem_offset_B, + int smem_sf_write_offset, + bool is_even_thread, + bool load_b, + bool valid_tile, + ElementA const* ptr_A, + ElementB const* ptr_B, + ElementSFA const* ptr_SF_A, + ElementSFB const* ptr_SF_B, + SharedStorage& shared_storage) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < num_stages; ++s) { + // Load scaling factors using cp.async - only even threads participate + // Calculate SF indices for this thread + int SF_idx = global_k / kSFVecSize; + int SF_offset_by_k = ((SF_idx >> 2) << 9) + (SF_idx & 0x3); + + void *smem_ptr_SFA = &shared_storage.smem_SFA[buffer_idx][s][smem_sf_write_offset]; + const void *gmem_ptr_SFA = ptr_SF_A + SF_offset_by_k; + // Load 4 FP8 values (32 bits) - for this thread and next thread + cutlass::arch::cp_async(smem_ptr_SFA, gmem_ptr_SFA, valid_tile && is_even_thread); + + void *smem_ptr_SFB = &shared_storage.smem_SFB[buffer_idx][s][(threadIdx.x / 2) * 4]; + const void *gmem_ptr_SFB = ptr_SF_B + SF_offset_by_k; + // Load 4 FP8 values (32 bits) - for this thread and next thread, only if threadIdx.y == 0 + cutlass::arch::cp_async(smem_ptr_SFB, gmem_ptr_SFB, valid_tile && load_b && is_even_thread); + + void *smem_ptr_A = &shared_storage.smem_A[buffer_idx][s][smem_offset_A]; + const void *gmem_ptr_A = ptr_A + unroll_col_k / kPackedElementsA; + cutlass::arch::cp_async(smem_ptr_A, gmem_ptr_A, valid_tile); + + void *smem_ptr_B = &shared_storage.smem_B[buffer_idx][s][smem_offset_B]; + const void *gmem_ptr_B = ptr_B + unroll_col_k / kPackedElementsB; + cutlass::arch::cp_async(smem_ptr_B, gmem_ptr_B, valid_tile && load_b); + + unroll_col_k += tileA_k_local; + global_k += tileA_k_local; + } + } + + /// Fused blockscaled GEMV computation using PTX + CUTLASS_DEVICE + ElementAccumulator blockscaled_multiply_add( + FragmentA const& fragA, + FragmentB const& fragB, + FragmentSFA const& fragSFA, + FragmentSFB const& fragSFB) { + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint16_t const& src_fragSFA_packed = reinterpret_cast(fragSFA); + 
uint16_t const& src_fragSFB_packed = reinterpret_cast(fragSFB); + + uint32_t const* src_fragA_packed = reinterpret_cast(&fragA); + uint32_t const* src_fragB_packed = reinterpret_cast(&fragB); + + ElementAccumulator out; + uint16_t* out_fp16 = reinterpret_cast(&out); + + asm volatile( \ + "{\n" \ + // declare registers for A / B tensors + ".reg .b8 byte0_0, byte0_1, byte0_2, byte0_3;\n" \ + ".reg .b8 byte0_4, byte0_5, byte0_6, byte0_7;\n" \ + ".reg .b8 byte1_0, byte1_1, byte1_2, byte1_3;\n" \ + ".reg .b8 byte1_4, byte1_5, byte1_6, byte1_7;\n" \ + ".reg .b8 byte2_0, byte2_1, byte2_2, byte2_3;\n" \ + ".reg .b8 byte2_4, byte2_5, byte2_6, byte2_7;\n" \ + ".reg .b8 byte3_0, byte3_1, byte3_2, byte3_3;\n" \ + ".reg .b8 byte3_4, byte3_5, byte3_6, byte3_7;\n" \ + + // declare registers for accumulators + ".reg .f16x2 accum_0_0, accum_0_1, accum_0_2, accum_0_3;\n" \ + ".reg .f16x2 accum_1_0, accum_1_1, accum_1_2, accum_1_3;\n" \ + ".reg .f16x2 accum_2_0, accum_2_1, accum_2_2, accum_2_3;\n" \ + ".reg .f16x2 accum_3_0, accum_3_1, accum_3_2, accum_3_3;\n" \ + + // declare registers for scaling factors + ".reg .f16x2 sfa_f16x2;\n" \ + ".reg .f16x2 sfb_f16x2;\n" \ + ".reg .f16x2 sf_f16x2;\n" \ + + // declare registers for conversion + ".reg .f16x2 cvt_0_0, cvt_0_1, cvt_0_2, cvt_0_3;\n" \ + ".reg .f16x2 cvt_0_4, cvt_0_5, cvt_0_6, cvt_0_7;\n" \ + ".reg .f16x2 cvt_1_0, cvt_1_1, cvt_1_2, cvt_1_3;\n" \ + ".reg .f16x2 cvt_1_4, cvt_1_5, cvt_1_6, cvt_1_7;\n" \ + ".reg .f16x2 cvt_2_0, cvt_2_1, cvt_2_2, cvt_2_3;\n" \ + ".reg .f16x2 cvt_2_4, cvt_2_5, cvt_2_6, cvt_2_7;\n" \ + ".reg .f16x2 cvt_3_0, cvt_3_1, cvt_3_2, cvt_3_3;\n" \ + ".reg .f16x2 cvt_3_4, cvt_3_5, cvt_3_6, cvt_3_7;\n" \ + ".reg .f16 result_f16, lane0, lane1;\n" \ + ".reg .f16x2 mul_f16x2_0, mul_f16x2_1;\n" \ + + // convert scaling factors from fp8 to f16x2 + "cvt.rn.f16x2.e4m3x2 sfa_f16x2, %1;\n" \ + "cvt.rn.f16x2.e4m3x2 sfb_f16x2, %2;\n" \ + + // clear accumulators + "mov.b32 accum_0_0, 0;\n" \ + "mov.b32 accum_0_1, 0;\n" \ + 
"mov.b32 accum_0_2, 0;\n" \ + "mov.b32 accum_0_3, 0;\n" \ + "mov.b32 accum_1_0, 0;\n" \ + "mov.b32 accum_1_1, 0;\n" \ + "mov.b32 accum_1_2, 0;\n" \ + "mov.b32 accum_1_3, 0;\n" \ + "mov.b32 accum_2_0, 0;\n" \ + "mov.b32 accum_2_1, 0;\n" \ + "mov.b32 accum_2_2, 0;\n" \ + "mov.b32 accum_2_3, 0;\n" \ + "mov.b32 accum_3_0, 0;\n" \ + "mov.b32 accum_3_1, 0;\n" \ + "mov.b32 accum_3_2, 0;\n" \ + "mov.b32 accum_3_3, 0;\n" \ + + // multiply, unpacking and permuting scale factors + "mul.rn.f16x2 sf_f16x2, sfa_f16x2, sfb_f16x2;\n" \ + "mov.b32 {lane0, lane1}, sf_f16x2;\n" \ + "mov.b32 mul_f16x2_0, {lane0, lane0};\n" \ + "mov.b32 mul_f16x2_1, {lane1, lane1};\n" \ + + // unpacking A and B tensors + "mov.b32 {byte0_0, byte0_1, byte0_2, byte0_3}, %3;\n" \ + "mov.b32 {byte0_4, byte0_5, byte0_6, byte0_7}, %4;\n" \ + "mov.b32 {byte1_0, byte1_1, byte1_2, byte1_3}, %5;\n" \ + "mov.b32 {byte1_4, byte1_5, byte1_6, byte1_7}, %6;\n" \ + "mov.b32 {byte2_0, byte2_1, byte2_2, byte2_3}, %7;\n" \ + "mov.b32 {byte2_4, byte2_5, byte2_6, byte2_7}, %8;\n" \ + "mov.b32 {byte3_0, byte3_1, byte3_2, byte3_3}, %9;\n" \ + "mov.b32 {byte3_4, byte3_5, byte3_6, byte3_7}, %10;\n" \ + + // convert A and B tensors from fp4 to f16x2 + + // A[0 - 7] and B[0 - 7] + "cvt.rn.f16x2.e2m1x2 cvt_0_0, byte0_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_1, byte0_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_2, byte0_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_3, byte0_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_4, byte0_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_5, byte0_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_6, byte0_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_7, byte0_7;\n" \ + + // A[8 - 15] and B[8 - 15] + "cvt.rn.f16x2.e2m1x2 cvt_1_0, byte1_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_1, byte1_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_2, byte1_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_3, byte1_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_4, byte1_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_5, byte1_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_6, byte1_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_7, byte1_7;\n" 
\ + + // A[16 - 23] and B[16 - 23] + "cvt.rn.f16x2.e2m1x2 cvt_2_0, byte2_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_1, byte2_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_2, byte2_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_3, byte2_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_4, byte2_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_5, byte2_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_6, byte2_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_7, byte2_7;\n" \ + + // A[24 - 31] and B[24 - 31] + "cvt.rn.f16x2.e2m1x2 cvt_3_0, byte3_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_1, byte3_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_2, byte3_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_3, byte3_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_4, byte3_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_5, byte3_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_6, byte3_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_7, byte3_7;\n" \ + + // fma for A[0 - 7] and B[0 - 7] + "fma.rn.f16x2 accum_0_0, cvt_0_0, cvt_0_4, accum_0_0;\n" \ + "fma.rn.f16x2 accum_0_1, cvt_0_1, cvt_0_5, accum_0_1;\n" \ + "fma.rn.f16x2 accum_0_2, cvt_0_2, cvt_0_6, accum_0_2;\n" \ + "fma.rn.f16x2 accum_0_3, cvt_0_3, cvt_0_7, accum_0_3;\n" \ + + // fma for A[8 - 15] and B[8 - 15] + "fma.rn.f16x2 accum_1_0, cvt_1_0, cvt_1_4, accum_1_0;\n" \ + "fma.rn.f16x2 accum_1_1, cvt_1_1, cvt_1_5, accum_1_1;\n" \ + "fma.rn.f16x2 accum_1_2, cvt_1_2, cvt_1_6, accum_1_2;\n" \ + "fma.rn.f16x2 accum_1_3, cvt_1_3, cvt_1_7, accum_1_3;\n" \ + + // fma for A[16 - 23] and B[16 - 23] + "fma.rn.f16x2 accum_2_0, cvt_2_0, cvt_2_4, accum_2_0;\n" \ + "fma.rn.f16x2 accum_2_1, cvt_2_1, cvt_2_5, accum_2_1;\n" \ + "fma.rn.f16x2 accum_2_2, cvt_2_2, cvt_2_6, accum_2_2;\n" \ + "fma.rn.f16x2 accum_2_3, cvt_2_3, cvt_2_7, accum_2_3;\n" \ + + // fma for A[24 - 31] and B[24 - 31] + "fma.rn.f16x2 accum_3_0, cvt_3_0, cvt_3_4, accum_3_0;\n" \ + "fma.rn.f16x2 accum_3_1, cvt_3_1, cvt_3_5, accum_3_1;\n" \ + "fma.rn.f16x2 accum_3_2, cvt_3_2, cvt_3_6, accum_3_2;\n" \ + "fma.rn.f16x2 accum_3_3, cvt_3_3, cvt_3_7, accum_3_3;\n" \ + + // tree reduction for accumulators + "add.rn.f16x2 
accum_0_0, accum_0_0, accum_0_1;\n" \ + "add.rn.f16x2 accum_0_2, accum_0_2, accum_0_3;\n" \ + "add.rn.f16x2 accum_1_0, accum_1_0, accum_1_1;\n" \ + "add.rn.f16x2 accum_1_2, accum_1_2, accum_1_3;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_2_1;\n" \ + "add.rn.f16x2 accum_2_2, accum_2_2, accum_2_3;\n" \ + "add.rn.f16x2 accum_3_0, accum_3_0, accum_3_1;\n" \ + "add.rn.f16x2 accum_3_2, accum_3_2, accum_3_3;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_0_2;\n" \ + "add.rn.f16x2 accum_1_0, accum_1_0, accum_1_2;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_2_2;\n" \ + "add.rn.f16x2 accum_3_0, accum_3_0, accum_3_2;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_1_0;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_3_0;\n" \ + + // apply scaling factors and final reduction + "mul.rn.f16x2 accum_0_0, mul_f16x2_0, accum_0_0;\n" \ + "mul.rn.f16x2 accum_2_0, mul_f16x2_1, accum_2_0;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_2_0;\n" \ + + "mov.b32 {lane0, lane1}, accum_0_0;\n" \ + "add.rn.f16 result_f16, lane0, lane1;\n" \ + + "mov.b16 %0, result_f16;\n" \ + + "}\n" + : "=h"(out_fp16[0]) // 0 + : "h"(src_fragSFA_packed), "h"(src_fragSFB_packed), // 1, 2 + "r"(src_fragA_packed[0]), "r"(src_fragB_packed[0]), // 3, 4 + "r"(src_fragA_packed[1]), "r"(src_fragB_packed[1]), // 5, 6 + "r"(src_fragA_packed[2]), "r"(src_fragB_packed[2]), // 7, 8 + "r"(src_fragA_packed[3]), "r"(src_fragB_packed[3]) // 9, 10 + : "memory" + ); + + return out; + + #else + NumericArrayConverter srcA_converter; + NumericArrayConverter srcB_converter; + NumericConverter SFA_converter; + NumericConverter SFB_converter; + + FragmentCompute fragA_Compute = srcA_converter(fragA); + FragmentCompute fragB_Compute = srcB_converter(fragB); + ElementAccumulator accum = ElementAccumulator(0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kSFPerAccess; i++) { + ElementAccumulator accum_SF_block = ElementAccumulator(0); + + int local_k_offset = i * kSFVecSize; + ElementAccumulator multiplier{1}; + 
+ multiplier = SFA_converter(fragSFA.at(i)) * SFB_converter(fragSFB.at(i)); + + + CUTLASS_PRAGMA_UNROLL + for (int e = 0; e < kSFVecSize; e++) { + accum_SF_block += fragA_Compute.at(e + local_k_offset) * fragB_Compute.at(e + local_k_offset); + } + + accum_SF_block *= multiplier; + accum += accum_SF_block; + } + + return accum; + + #endif + } + + CUTLASS_DEVICE + ElementAccumulator process_tail_elements( + int unroll_col_k, + int idx_col_k, + int gemm_k, + ElementA const *ptr_A, + ElementB const *ptr_B, + ElementSFA const *ptr_SF_A, + ElementSFB const *ptr_SF_B, + NumericConverter const &A_converter, + NumericConverter const &B_converter, + NumericConverter const &SFA_converter, + NumericConverter const &SFB_converter) { + + ElementAccumulator accum = ElementAccumulator(0); + + // calculate the rest of K elements + // each thread fetch 1 element each time + for (int k = unroll_col_k + idx_col_k * kPackedElementsA; k < gemm_k; k += kThreadsPerRow * kPackedElementsA) { + // blockscaled GEMV + int SF_idx = k / kSFVecSize; + int SF_offset_by_k = ((SF_idx >> 2) << 9) + (SF_idx & 0x3); + + ElementSFA sfa = *(ptr_SF_A + SF_offset_by_k); + ElementSFB sfb = *(ptr_SF_B + SF_offset_by_k); + + FragmentPackedA fragA; + FragmentPackedB fragB; + + // fetch from matrix A + arch::global_load( + fragA, + ptr_A - (idx_col_k * kElementsPerAccess - k) / kPackedElementsA, + true); + + // fetch from vector B + arch::global_load( + fragB, + ptr_B - (idx_col_k * kElementsPerAccess - k) / kPackedElementsB, + true); + + ElementAccumulator accum_SF_packed = ElementAccumulator(0); + + CUTLASS_PRAGMA_UNROLL + for (int e = 0; e < kPackedElements; e++) { + accum_SF_packed += A_converter(fragA.at(e)) * B_converter(fragB.at(e)); + } + + accum_SF_packed *= SFA_converter(sfa) * SFB_converter(sfb); + + accum += accum_SF_packed; + + } + + return accum; + } + + // Load fragments from shared memory + template + CUTLASS_DEVICE + void load_smem_fragments( + FragmentA& fragA, + FragmentB& fragB, + 
FragmentSFA& fragSFA, + FragmentSFB& fragSFB, + int smem_pipe_idx, + int k_block, + int smem_offset_A, + int smem_offset_B, + int smem_sf_offset, + SharedStorage& shared_storage) const { + + // Load A/B fragments + arch::shared_load(fragA, &shared_storage.smem_A[smem_pipe_idx][k_block][smem_offset_A]); + arch::shared_load(fragB, &shared_storage.smem_B[smem_pipe_idx][k_block][smem_offset_B]); + + // Load SF fragments + uint32_t smem_ptr = cutlass::arch::cutlass_get_smem_pointer(&shared_storage.smem_SFA[smem_pipe_idx][k_block][smem_sf_offset]); + arch::shared_load<2>(&fragSFA, smem_ptr); + smem_ptr = cutlass::arch::cutlass_get_smem_pointer(&shared_storage.smem_SFB[smem_pipe_idx][k_block][threadIdx.x * kSFPerAccess]); + arch::shared_load<2>(&fragSFB, smem_ptr); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h index 7aaaa09..ab15055 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/params_sparse_base.h b/3rd/cutlass/include/cutlass/gemm/kernel/params_sparse_base.h index 3b1d2c9..d3ba4d3 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/params_sparse_base.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/params_sparse_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/params_universal_base.h b/3rd/cutlass/include/cutlass/gemm/kernel/params_universal_base.h index 46933d9..ff4a20e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/params_universal_base.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/params_universal_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h index 41165cf..23c8ae9 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h index c9fcf0c..eaae822 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h index 349cd25..39b96bc 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h index f304d06..5e36530 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h index 9609143..cd3aba5 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp index 738f460..18d1cd6 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -132,7 +132,10 @@ class GemmUniversal< using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + //For the case of RCGroupedGemm we are still GroupedGemm but our StrideA will not match with InternalStrideA + // Hence it's better to take this decision based upon StrideB + static constexpr bool IsGroupedGemmKernel = !(cute::is_same_v); using TileSchedulerTag = cute::conditional_t; using TileScheduler = typename detail::TileSchedulerSelector< @@ -140,25 +143,40 @@ class GemmUniversal< using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsTensorMapUpdateAsync = not IsSchedDynamicPersistent; static constexpr bool IsDynamicCluster = not cute::is_static_v; static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; // Warp specialization thread count per threadblock - 
static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; - static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; - - static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumTensorMapUpdaterThreads = IsTensorMapUpdateAsync ? NumThreadsPerWarp * 4 : 0; // Four warps to update tensor maps and plumb updated tileId. + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static_assert( + SchedulerPipelineStageCount % (IsTensorMapUpdateAsync ? NumTensorMapUpdaterThreads / NumThreadsPerWarp : 1) == 0, + "SchedulerPipelineStageCount for async tensor map update kernels must be divisible by the number of asynchronous tensor map updater warps." + ); + + static_assert( + (!IsTensorMapUpdateAsync) + || CollectiveEpilogue::NumMaxSchedulerPipelineStageCount >= SchedulerPipelineStageCount, + "The epilog collective expected a less scheduler stage count. Consider relaxing its NumMaxSchedulerPipelineStageCount parameter." 
+ ); + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + NumTensorMapUpdaterThreads + NumMainloopLoadThreads + NumMMAThreads + NumEpilogueLoadThreads + NumEpilogueThreads; static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static constexpr uint32_t NumFixupBarriers = 1; static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); - - static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + static constexpr uint32_t GenericRegisterRequirement = 136; + static constexpr uint32_t AccumRegisterRequirement = 232; // Pipeline and pipeline state types using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; @@ -179,17 +197,59 @@ class GemmUniversal< cutlass::PipelineCLCFetchAsync, cutlass::PipelineAsync>; using CLCPipelineState = typename CLCPipeline::PipelineState; + using TensorMapReadyPipeline = cute::conditional_t, + CLCPipeline + >; + using TensorMapReadyPipelineState = typename TensorMapReadyPipeline::PipelineState; using CLCThrottlePipeline = cute::conditional_t, cutlass::PipelineEmpty>; using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + template + struct WithTensorMapUpdateInfo : public BaseResponse { + uint16_t batch_changed = 0; + uint16_t TMA_stage = 0; + WithTensorMapUpdateInfo() = default; + CUTLASS_DEVICE WithTensorMapUpdateInfo(BaseResponse const& response) : BaseResponse(response) {} + }; + + using CLCResponseWithAdditionalInformation = cute::conditional_t< + IsTensorMapUpdateAsync, + WithTensorMapUpdateInfo, + typename TileScheduler::CLCResponse + >; + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; // Kernel level shared memory storage struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { + // The PipelineStorageImplWithoutAsyncUpdate and PipelineStorageImplWithAsyncUpdate only differ in the + // presence of the 
TensorMapReadyPipelineStorage. + // We could use some other technique to avoid duplication for the common members, but any technique + // we tried would break the MSVC build. + // As a workaround, we just copied the code. + + struct PipelineStorageImplWithoutAsyncUpdate : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + }; + + struct PipelineStorageImplWithAsyncUpdate : cute::aligned_struct<16, _1> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; @@ -204,9 +264,17 @@ class GemmUniversal< alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; - } pipelines; - alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + // Below is the only difference between PipelineStorageImpl and PipelineStorageImpl + using TensorMapReadyPipelineStorage = typename TensorMapReadyPipeline::SharedStorage; + alignas(16) TensorMapReadyPipelineStorage tensor_map_ready; + }; + + using PipelineStorage = 
cute::conditional_t; + + PipelineStorage pipelines; + + alignas(16) CLCResponseWithAdditionalInformation clc_response[IsTensorMapUpdateAsync ? 2 : 1][SchedulerPipelineStageCount]; uint32_t tmem_base_ptr; struct TensorMapStorage : cute::aligned_struct<128, _1> { @@ -214,7 +282,7 @@ class GemmUniversal< using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; alignas(128) EpilogueTensorMapStorage epilogue; alignas(128) MainloopTensorMapStorage mainloop; - } tensormaps; + } tensormaps[(NumTensorMapUpdaterThreads/NumThreadsPerWarp)+1]; struct TensorStorage : cute::aligned_struct<128, _1> { using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -226,7 +294,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -253,7 +320,9 @@ class GemmUniversal< Sched = 1, MainloopLoad = 2, EpilogueLoad = 3, - Epilogue = 4 + Epilogue = 4, + // TensorMapUpdater starts at 256 thread alignment + TensorMapUpdater = 8 }; struct IsParticipant { @@ -262,6 +331,7 @@ class GemmUniversal< uint32_t main_load = false; uint32_t epi_load = false; uint32_t epilogue = false; + uint32_t tensor_map_updater = false; }; // @@ -468,8 +538,7 @@ class GemmUniversal< return grid_shape; } - static constexpr - dim3 + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -481,18 +550,20 @@ class GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for more than one epilogue warp int warp_idx = canonical_warp_idx_sync(); - WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? 
WarpCategory(warp_idx) - : WarpCategory::Epilogue; + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::TensorMapUpdater) ? WarpCategory::Epilogue + : WarpCategory::TensorMapUpdater; uint32_t lane_predicate = cute::elect_one_sync(); auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); int cluster_size = size(cluster_shape); uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); - bool is_first_cta_in_cluster = IsSchedDynamicPersistent ? (cta_rank_in_cluster == 0) : true; + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); bool is_mma_leader_cta = cta_coord_v == 0; constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; @@ -509,12 +580,57 @@ class GemmUniversal< bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); IsParticipant is_participant = { (warp_category == WarpCategory::MMA), // mma - (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::Sched) && (IsSchedDynamicPersistent ? 
is_first_cta_in_cluster : true), // sched (warp_category == WarpCategory::MainloopLoad), // main_load (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load - (warp_category == WarpCategory::Epilogue) // epilogue + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::TensorMapUpdater) && IsTensorMapUpdateAsync // tensor_map_updater + }; + + int32_t sm_id = static_cast(cutlass::arch::SmId()); + if constexpr (IsGroupedGemmKernel) { + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + auto tensormaps_init_main_load = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + return collective_mainloop.template tensormaps_init( + params.mainloop, + shared_storage.tensormaps[0].mainloop, + params.hw_info.sm_count, + sm_id + ); + } + else { + return nullptr; + } + }; + + auto tensormaps_init_epi_load = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + return collective_epilogue.template tensormaps_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + } + else { + return nullptr; + } }; + decltype(tensormaps_init_main_load()) pre_init_main_load_tensormaps; + decltype(tensormaps_init_epi_load()) pre_init_epi_load_tensormaps; + + + if (is_participant.main_load) { + pre_init_main_load_tensormaps = tensormaps_init_main_load(); + } + if (is_participant.epi_load) { + pre_init_epi_load_tensormaps = tensormaps_init_epi_load(); + } + // Mainloop Load pipeline typename MainloopPipeline::Params mainloop_pipeline_params; if (WarpCategory::MainloopLoad == warp_category) { @@ -582,9 +698,14 @@ class GemmUniversal< clc_pipeline_params.transaction_bytes = CLCResponseSize; } else { - clc_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; - if (is_epi_load_needed) { - clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + if constexpr 
(IsTensorMapUpdateAsync) { + clc_pipeline_params.consumer_arv_count = NumThreadsPerWarp; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } } } // Now declare the pipeline outside the if constexpr @@ -597,6 +718,32 @@ class GemmUniversal< } }(); + auto tensor_map_ready_pipeline = [&] () { + if constexpr (IsGroupedGemmKernel) { + // TMA update ready pipeline + typename TensorMapReadyPipeline::Params tensor_map_ready_pipeline_params; + + if (WarpCategory::TensorMapUpdater == warp_category) { + tensor_map_ready_pipeline_params.role = TensorMapReadyPipeline::ThreadCategory::Producer; + } + else { + tensor_map_ready_pipeline_params.role = TensorMapReadyPipeline::ThreadCategory::Consumer; + } + + tensor_map_ready_pipeline_params.initializing_warp = 8; + tensor_map_ready_pipeline_params.producer_arv_count = NumThreadsPerWarp; + + tensor_map_ready_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + tensor_map_ready_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + return TensorMapReadyPipeline(shared_storage.pipelines.tensor_map_ready, tensor_map_ready_pipeline_params); + } + else { + return clc_pipeline; + } + }(); + // Mainloop-Epilogue pipeline typename AccumulatorPipeline::Params accumulator_pipeline_params; if (WarpCategory::MMA == warp_category) { @@ -673,19 +820,65 @@ class GemmUniversal< CLCPipelineState clc_pipe_consumer_state; CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + TensorMapReadyPipelineState tensor_map_ready_pipe_consumer_state; + TensorMapReadyPipelineState tensor_map_ready_pipe_producer_state = cutlass::make_producer_start_state(); + AccumulatorPipelineState accumulator_pipe_consumer_state; AccumulatorPipelineState accumulator_pipe_producer_state = 
cutlass::make_producer_start_state(); dim3 block_id_in_cluster = cute::block_id_in_cluster(); - int32_t sm_id = static_cast(cutlass::arch::SmId()); // Calculate mask after cluster barrier arrival mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction. + // For the static grouped scheduler, the problem shapes + // might be produced by a previous kernel in global memory. + cutlass::arch::wait_on_dependent_grids(); + // TileID scheduler - TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); - typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + TileScheduler scheduler( + (!IsTensorMapUpdateAsync || is_participant.sched || is_participant.tensor_map_updater) + ? &shared_storage.clc_response[0][0] + : &shared_storage.clc_response[1][0], + params.scheduler, + block_id_in_cluster + ); + + auto work_tile_info = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + return scheduler.initial_work_tile_info(cluster_shape, [] (typename TileScheduler::CLCResponse response) { + CLCResponseWithAdditionalInformation response_with_additional_info = response; + response_with_additional_info.TMA_stage = 0; + response_with_additional_info.batch_changed = 1; + return response_with_additional_info; + }); + } + else { + return scheduler.initial_work_tile_info(cluster_shape); + } + } (); + + auto get_tma_desc_offset = [] ([[maybe_unused]] const auto& tile_info) { + if constexpr (IsTensorMapUpdateAsync) { + return tile_info.TMA_stage; + } + else { + return 0; + } + }; + + auto get_tensormap = [] (auto& tensormaps, [[maybe_unused]] auto tma_desc_offset) { + if constexpr (IsTensorMapUpdateAsync) { + return tensormaps[tma_desc_offset]; + } + else { + return tensormaps; + } + }; + auto cta_coord_mnkl = 
scheduler.work_tile_to_cta_coord(work_tile_info); // @@ -699,44 +892,67 @@ class GemmUniversal< // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups return; } - // In case user wants to engage less SMs than available on device - sm_id = blockIdx.x + (blockIdx.y * gridDim.x); } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); if (is_participant.main_load) { - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, - shared_storage.tensors.mainloop, - shared_storage.tensormaps.mainloop, - params.hw_info.sm_count, sm_id, work_tile_info.L_idx); - // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, + shared_storage.tensors.mainloop, + shared_storage.tensormaps[0].mainloop, + params.hw_info.sm_count, sm_id, problem_shape.groups(), work_tile_info.L_idx); + bool do_load_order_arrive = is_epi_load_needed; Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = get(load_inputs); + auto input_tensormaps = [&] ([[maybe_unused]] auto inputs) { + if constexpr (IsTensorMapUpdateAsync) { + return pre_init_main_load_tensormaps; + } + else { + static constexpr size_t idx = rank(inputs) - 1; + return get(inputs); + } + } (load_inputs); + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + + auto pad_inputs = [] (auto& inputs, [[maybe_unused]] auto tensormaps) { + if constexpr (IsTensorMapUpdateAsync) { + return cute::tuple_cat(inputs, cute::make_tuple(tensormaps)); + } + else { + return inputs; + } + }; // Initial batch's tensor address update // Even 
the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; bool requires_clc_query = true; do { + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); } - if (did_batch_change) { + if (IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) { collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, + shared_storage.tensormaps[0].mainloop, params.mainloop, - input_tensormaps, + get_tensormap(input_tensormaps, tma_desc_offset), problem_shape, curr_batch ); @@ -764,10 +980,11 @@ class GemmUniversal< params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, + pad_inputs(load_inputs, get_tensormap(input_tensormaps, tma_desc_offset)), cta_coord_mnk, k_tile_iter, k_tile_prologue, - did_batch_change + IsTensorMapUpdateAsync ? 
is_first_iteration : did_batch_change, // did_batch_change + curr_batch ); mainloop_pipe_producer_state = mainloop_producer_state_next; @@ -780,10 +997,11 @@ class GemmUniversal< params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, + pad_inputs(load_inputs, get_tensormap(input_tensormaps, tma_desc_offset)), cta_coord_mnk, k_tile_iter_next, k_tile_count - k_tile_prologue, - false /* did_batch_change - prologue loads handle tensormap acquire */ + false, /* did_batch_change - prologue loads handle tensormap acquire */ + curr_batch ); mainloop_pipe_producer_state = mainloop_producer_state_next_; @@ -792,16 +1010,17 @@ class GemmUniversal< auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, - clc_pipeline, - clc_pipe_consumer_state + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); requires_clc_query = increment_pipe; if (increment_pipe) { - ++clc_pipe_consumer_state; + ++tensor_map_ready_pipe_consumer_state; } // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); } while (work_tile_info.is_valid()); collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); @@ -809,6 +1028,12 @@ class GemmUniversal< } else if (is_participant.sched) { + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + // Grouped GEMM uses static tile scheduler if constexpr (IsSchedDynamicPersistent) { // Whether a new CLC query must be performed. 
@@ -854,20 +1079,135 @@ class GemmUniversal< } else { - cutlass::arch::wait_on_dependent_grids(); + static_assert(IsTensorMapUpdateAsync || IsSchedDynamicPersistent, "We only support async tensor map update with static persistent scheduler"); + + auto update_tensor_map_stages = [&] (typename TileScheduler::CLCResponse next_work_tile_info_from_scheduler) { + if constexpr (IsTensorMapUpdateAsync) { + CLCResponseWithAdditionalInformation next_work_tile_info = next_work_tile_info_from_scheduler; + auto tensor_map_buffer_stage = work_tile_info.TMA_stage; + next_work_tile_info.batch_changed = work_tile_info.L_idx != next_work_tile_info.L_idx; + if (next_work_tile_info.batch_changed) { + ++tensor_map_buffer_stage; + } + next_work_tile_info.TMA_stage = tensor_map_buffer_stage; + return next_work_tile_info; + } + }; do { - auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, 1, update_tensor_map_stages); work_tile_info = next_work_tile_info; if (increment_pipe) { ++clc_pipe_producer_state; } } while (work_tile_info.is_valid()); - clc_pipeline.producer_tail(clc_pipe_producer_state); + + // Push additional invalid work items for all tensormap updater threads + for (int i = 0; i < NumTensorMapUpdaterThreads / NumThreadsPerWarp;) { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, 1, update_tensor_map_stages); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + ++i; + } + } + } + } + + else if (is_participant.tensor_map_updater) { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + 
if constexpr (IsTensorMapUpdateAsync) { + auto updater_id = canonical_warp_idx_sync() - static_cast(WarpCategory::TensorMapUpdater); + + clc_pipe_consumer_state += updater_id; + tensor_map_ready_pipe_producer_state += updater_id; + + auto tensormaps_mainloop = collective_mainloop.tensormaps_init( + params.mainloop,shared_storage.tensormaps[updater_id+1].mainloop, params.hw_info.sm_count, sm_id); + auto tensormaps_epilogue_load = collective_epilogue.template tensormaps_init( + params.epilogue, shared_storage.tensormaps[updater_id+1].epilogue, params.hw_info.sm_count, sm_id); + auto tensormaps_epilogue_store = collective_epilogue.template tensormaps_init( + params.epilogue, shared_storage.tensormaps[updater_id+1].epilogue, params.hw_info.sm_count, sm_id); + + auto update_tensor_map_and_increment_pipe_if_needed = [&] (auto &next_work_tile_info, auto &increment_pipe) { + auto next_batch = next_work_tile_info.L_idx; + auto did_batch_change = next_work_tile_info.batch_changed; + + if (increment_pipe) { + tensor_map_ready_pipeline.producer_acquire(tensor_map_ready_pipe_producer_state); + if (next_work_tile_info.is_valid() && did_batch_change) { + auto tma_desc_offset = get_tma_desc_offset(next_work_tile_info); + collective_mainloop.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].mainloop, + params.mainloop, + tensormaps_mainloop[tma_desc_offset], + problem_shape, + next_batch + ); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].epilogue, + params.epilogue, + tensormaps_epilogue_load[tma_desc_offset], + problem_shape, + next_batch + ); + } + + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].epilogue, + params.epilogue, + tensormaps_epilogue_store[tma_desc_offset], + problem_shape, + next_batch + ); + + collective_mainloop.tensormaps_fence_acquire(tensormaps_mainloop[tma_desc_offset]); + 
+ if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.template tensormaps_fence_acquire(tensormaps_epilogue_load[tma_desc_offset]); + } + collective_epilogue.template tensormaps_fence_acquire(tensormaps_epilogue_store[tma_desc_offset]); + } + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + shared_storage.clc_response[1][tensor_map_ready_pipe_producer_state.index()] = next_work_tile_info; + cutlass::arch::fence_view_async_shared(); + cute::tma_desc_wait_group(); + } + + // Signal the other warps that the TMA update is complete + tensor_map_ready_pipeline.producer_commit(tensor_map_ready_pipe_producer_state); + tensor_map_ready_pipe_producer_state += (NumTensorMapUpdaterThreads / NumThreadsPerWarp); + clc_pipe_consumer_state += (NumTensorMapUpdaterThreads / NumThreadsPerWarp); + } + }; + + do { + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + update_tensor_map_and_increment_pipe_if_needed(next_work_tile_info, increment_pipe); + work_tile_info = next_work_tile_info; + + } while (work_tile_info.is_valid()); } } else if (is_participant.mma) { + // Tmem allocation sequence tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); @@ -876,18 +1216,12 @@ class GemmUniversal< collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); auto mma_inputs = collective_mainloop.mma_init(tmem_storage, shared_storage.tensors.mainloop); - do { - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } - if (increment_pipe) { - ++clc_pipe_consumer_state; - } + do { if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = 
append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); @@ -903,7 +1237,7 @@ class GemmUniversal< } }(); auto accumulator = collective_mainloop.slice_accumulator(tmem_storage, acc_stage); - if (is_mma_leader_cta) { + if (is_mma_leader_cta && k_tile_count > 0) { mainloop_pipe_consumer_state = collective_mainloop.mma( cute::make_tuple(mainloop_pipeline, accumulator_pipeline), cute::make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state), @@ -914,7 +1248,20 @@ class GemmUniversal< ); accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); } - ++accumulator_pipe_producer_state; + if (k_tile_count > 0) { + ++accumulator_pipe_producer_state; + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state + ); + + if (increment_pipe) { + ++tensor_map_ready_pipe_consumer_state; + } work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); @@ -950,6 +1297,12 @@ class GemmUniversal< } else if (is_participant.epi_load) { + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); @@ -959,37 +1312,42 @@ class GemmUniversal< int current_wave = 0; // Fetch a copy of tensormaps for the CTA from Params - auto epi_load_tensormap = get<0>(collective_epilogue.load_init( - params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + auto epi_load_tensormap = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + collective_epilogue.template load_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + return pre_init_epi_load_tensormaps; + } + else { + return get<0>(collective_epilogue.template 
load_init( + params.epilogue, shared_storage.tensormaps[0].epilogue, params.hw_info.sm_count, sm_id)); + } + } (); + // Initial batch's tensor address update // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; constexpr bool IsEpiLoad = true; do { int32_t curr_batch = work_tile_info.L_idx; - if (did_batch_change) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); + if (IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[0].epilogue, params.epilogue, - epi_load_tensormap, + get_tensormap(epi_load_tensormap, tma_desc_offset), problem_shape, curr_batch ); } bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); - // Get current work tile and fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } if (compute_epilogue) { if (do_load_order_wait) { @@ -1010,17 +1368,35 @@ class GemmUniversal< TileShape{}, TiledMma{}, shared_storage.tensors.epilogue, - cute::make_tuple(epi_load_tensormap, did_batch_change), + cute::make_tuple(get_tensormap(epi_load_tensormap, tma_desc_offset), IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change), reverse_epi_n ); do_tail_load = true; } - current_wave++; + // Relevant only for OverlappingAccum cases. + // Only increment the wave if the problem shape K dimension is not 0, otherwise accumulator will be skipped. 
+ if constexpr (IsOverlappingAccum) { + if (size<2>(problem_shape_MNKL) > 0) { + current_wave++; + } + } + + // Fetch the next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++tensor_map_ready_pipe_consumer_state; + } // Calculate the cta coordinates of the next work tile cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != work_tile_info.L_idx; } while (work_tile_info.is_valid()); @@ -1036,6 +1412,11 @@ class GemmUniversal< } else if (is_participant.epilogue) { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + } + // Wait for tmem allocate here tmem_allocation_result_barrier.arrive_and_wait(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; @@ -1044,20 +1425,43 @@ class GemmUniversal< auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); bool do_tail_store = false; // Fetch a copy of tensormaps for the CTA from Params - auto epi_store_tensormap = get<0>(collective_epilogue.store_init( - params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + auto epi_store_tensormap = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + collective_epilogue.template store_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + + return collective_epilogue.template tensormaps_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id, + warp_idx_in_epi == 0 + ); + } + else { + return get<0>(collective_epilogue.template store_init( + params.epilogue, shared_storage.tensormaps[0].epilogue, 
params.hw_info.sm_count, sm_id)); + } + } (); + // Initial batch's tensor address update // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; constexpr bool IsEpiLoad = false; do { int32_t curr_batch = work_tile_info.L_idx; - if (did_batch_change && warp_idx_in_epi == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); + if ((IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[0].epilogue, params.epilogue, - epi_store_tensormap, + get_tensormap(epi_store_tensormap, tma_desc_offset), problem_shape, curr_batch ); @@ -1065,12 +1469,12 @@ class GemmUniversal< // Fetch next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, - clc_pipeline, - clc_pipe_consumer_state + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state ); if (increment_pipe) { - ++clc_pipe_consumer_state; + ++tensor_map_ready_pipe_consumer_state; } // Accumulator stage slice @@ -1091,7 +1495,11 @@ class GemmUniversal< // // Epilogue and write to gD // - auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + auto [ + load_state_next, + store_state_next, + acc_state_next + ] = collective_epilogue.template store( epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, @@ -1105,7 +1513,7 @@ class GemmUniversal< TiledMma{}, accumulator, shared_storage.tensors.epilogue, - cute::make_tuple(epi_store_tensormap, did_batch_change) + cute::make_tuple(get_tensormap(epi_store_tensormap, tma_desc_offset), IsTensorMapUpdateAsync ? 
is_first_iteration : did_batch_change) ); epi_load_pipe_consumer_state = load_state_next; epi_store_pipe_producer_state = store_state_next; @@ -1115,6 +1523,7 @@ class GemmUniversal< work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != work_tile_info.L_idx; } while (work_tile_info.is_valid()); @@ -1139,6 +1548,11 @@ class GemmUniversal< } else { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + } } }; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp index 76432e1..25d5c6e 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -148,9 +148,10 @@ class GemmUniversal< static constexpr uint32_t NumFixupBarriers = 1; static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); - // Transfer registers from regular warps to Accum warps - static constexpr uint32_t GenericRegisterRequirement = 152; - static constexpr uint32_t AccumRegisterRequirement = 200; + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = CollectiveMainloop::GenericRegisterRequirement; + static constexpr uint32_t TransformRegisterRequirement = CollectiveMainloop::TransformRegisterRequirement; + static constexpr uint32_t AccumRegisterRequirement = CollectiveMainloop::AccumRegisterRequirement; // Pipeline and pipeline state types using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; @@ -222,7 +223,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -413,6 +413,22 @@ class GemmUniversal< return dim3(MaxThreadsPerBlock, 1, 1); } + // Register alloc/dealloc behavior might change according to the underlying collective used + template + CUTLASS_DEVICE + static constexpr void + warpgroup_reg_reconfig() { + // Compute default-allocated registers per thread: round_down((512 / NumWG), 8) + constexpr int32_t MaxWarpGroupsPerBlock = ceil_div(MaxThreadsPerBlock, NumThreadsPerWarpGroup); + constexpr int32_t NumRegsPerThread = (512 / MaxWarpGroupsPerBlock) / 8 * 8; + if constexpr (NReg < NumRegsPerThread) { + arch::warpgroup_reg_dealloc(); + } + else if constexpr (NReg > NumRegsPerThread) { + arch::warpgroup_reg_alloc(); + } + } + CUTLASS_DEVICE void operator() (Params const& params, char* smem_buf) { @@ -420,6 +436,7 @@ class 
GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for multiple epilogue and transformation warps @@ -649,18 +666,21 @@ class GemmUniversal< transform2mma_pipeline.init_masks(cluster_shape); mma2accum_pipeline.init_masks(cluster_shape); + // Allocate accumulators + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); - typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - // Allocate accumulators - auto acc_shape = collective_mainloop.partition_accumulator_shape(); - // NOTE: we can assume the tmem buf starts at zero since we allocate all tmem in this kernel auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, Int{})); @@ -674,7 +694,7 @@ class GemmUniversal< if (is_participant.main_load) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction @@ -788,7 +808,7 @@ class GemmUniversal< else if (is_participant.transformation) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Signal the epilogue warps to proceed once the prologue is complete 
epilogue_throttle_barrier.arrive(); @@ -830,7 +850,7 @@ class GemmUniversal< else if (is_participant.sched) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Signal the epilogue warps to proceed once the prologue is complete epilogue_throttle_barrier.arrive(); @@ -895,7 +915,7 @@ class GemmUniversal< else if (is_participant.mma) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Allocate all tmem tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); @@ -963,7 +983,7 @@ class GemmUniversal< else if (is_participant.epi_load) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction @@ -1048,7 +1068,7 @@ class GemmUniversal< else if (is_participant.epilogue) { // Register reconfiguration - arch::warpgroup_reg_alloc(); + warpgroup_reg_reconfig(); // Throttle the epilogue warps to improve prologue performance static constexpr int epilogue_throttle_phase_bit = 0; @@ -1143,7 +1163,9 @@ class GemmUniversal< // support fixup operations needed by split-/stream-K. These operations are pushed // to the collective layer so that they can reuse the TMEM -> RF copy performed // at the collective layer. 
- auto [mma2accum_pipeline_state_next] = collective_epilogue( + auto [mma2accum_pipeline_state_next, epi_load_pipe_consumer_state_next] = collective_epilogue( + epi_load_pipeline, + epi_load_pipe_consumer_state, mma2accum_pipeline, mma2accum_pipeline_consumer_state, problem_shape_MNKL, @@ -1154,6 +1176,7 @@ class GemmUniversal< ); // Advance the mm2accum pipe mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; } work_tile_info = next_work_tile_info; @@ -1176,7 +1199,7 @@ class GemmUniversal< else { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); } } }; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp index 2ec1049..83ae76a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -234,7 +234,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -493,6 +492,7 @@ class GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for more than one epilogue warp @@ -726,11 +726,6 @@ class GemmUniversal< mainloop_ab_pipeline.init_masks(cluster_shape, block_id_in_cluster); accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); - // TileID scheduler - TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); - typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); - auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - // // TMEM "Allocation" // @@ -740,6 +735,15 @@ class GemmUniversal< Tensor accumulators = cutlass::detail::make_sm100_accumulator( tiled_mma, acc_shape, EpilogueTile{}); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. 
+ cutlass::arch::wait_on_dependent_grids(); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + pipeline_init_wait(cluster_size); if constexpr (IsGroupedGemmKernel) { diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp new file mode 100644 index 0000000..d0fd31a --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp @@ -0,0 +1,794 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = false; + // Mainloop derived types 
+ using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename detail::TileSchedulerSelector< + 
TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = CollectiveMainloop::NumLoadThreads; // 4 warps + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using 
AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams 
epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + EpilogueLoad = 3, + Epilogue = 4, + MainloopLoad = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_load = false; + }; + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = 
round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + static constexpr int MaxClusterSize = 16; + implementable &= size(ClusterShape{}) <= MaxClusterSize; + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = 
nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = cutlass::Status::kSuccess; + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto blk_shape = CtaShape_MNK{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info + ); + + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + +public: + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + // Separate out problem shape for convenience + // Optionally append 1s 
until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::MainloopLoad) ? WarpCategory::Epilogue + : WarpCategory::MainloopLoad; + uint32_t lane_predicate = cute::elect_one_sync(); + auto tile_shape = TileShape{}; + auto cluster_shape = ClusterShape{}; + constexpr int cluster_size = size(ClusterShape{}); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == 
WarpCategory::MainloopLoad) // main_load + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + + mainloop_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + mainloop_pipeline_params.consumer_arv_count = 1; // Only UMMA consumes the A and B buffers + mainloop_pipeline_params.dst_blockid = cta_rank_in_cluster; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, cluster_shape); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 3; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { 
+ clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState 
epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // + // END PROLOGUE + // + + // Synchronization call. Blocks until barriers are initialized in shared memory. 
+ pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + cutlass::arch::warpgroup_reg_dealloc(); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(load_inputs); + + do { + // Get current work tile and fetch next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + auto [mainloop_producer_state_next, unused_] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + + else if (is_participant.sched) { + cutlass::arch::warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. 
+ bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + // Pass the acc with tuple type since the bgrad kernel change the mma_init API + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, cute::make_tuple(bulk_tmem, bulk_tmem), shared_storage.tensors.mainloop); + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = 
scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + int acc_stage = accumulator_pipe_producer_state.index(); + Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + mainloop_pipe_consumer_state = collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + // Pass the acc with tuple type since the bgrad kernel change the mma API + cute::make_tuple(accumulators, accumulators), + mma_inputs, + k_tile_count + ); + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + cutlass::arch::warpgroup_reg_dealloc(); + + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = 
true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + cutlass::arch::warpgroup_reg_alloc(); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // Accumulator stage slice + int acc_stage = accumulator_pipe_consumer_state.index(); + Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = 
scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + cutlass::arch::warpgroup_reg_dealloc(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000..8b0fc43 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,1000 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/gemm/kernel/gemm_universal_decl.h" + +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + 
ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + using ProblemShape = ProblemShape_; + + static constexpr bool IsGroupedGemmKernel = cutlass::gemm::detail::is_moe_problem_shape::value; + static constexpr bool IsMoEScheduler = false; // stub for MoE scheduler, which accepts a MoEProblemShape instead of GroupProblemShape + + CUTLASS_HOST_DEVICE + static auto get_problem_shape_gemm(ProblemShape const& shape) { + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNK = shape.get_host_problem_shape(0); //gets the maximum problem shape here + auto problem_shape_MNKL = append<4>(problem_shape_MNK, shape.groups()); //appends num_groups to it + return problem_shape_MNKL; + } + else { + return shape; + } + } + + CUTLASS_HOST_DEVICE + static auto get_problem_shape_scheduler(ProblemShape const& shape) { + if constexpr (IsMoEScheduler) { + return shape; + } + else { + return shape; + } + } + + template + CUTLASS_HOST_DEVICE + static auto get_effective_shape(ProblemShape const& shape, WorkTileInfo const& work_tile_info) { + if constexpr (IsGroupedGemmKernel) { + return append<4>(shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + else { + return append<4>(shape, Int<1>{}); + } + } + + using ProblemShapeGemm = decltype(get_problem_shape_gemm(ProblemShape{})); + + static_assert(rank(ProblemShapeGemm{}) == 3 or rank(ProblemShapeGemm{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = false; + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = 
typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); + using TileSchedulerTag = cute::conditional_t; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShape>::Scheduler; + using TileSchedulerArguments = typename 
TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = 0; + static constexpr uint32_t NumMainloopTMALoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopCpAsyncLoadThreads = CollectiveMainloop::NumLoadThreadsCpAsync; // 4 warps + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); + + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + // Pipeline and pipeline state types + using MainloopPipelineTMA = typename CollectiveMainloop::MainloopPipelineTMA; + using MainloopPipelineTMAState = typename CollectiveMainloop::MainloopPipelineTMAState; + using MainloopPipelineCpAsync = typename CollectiveMainloop::MainloopPipelineCpAsync; + using MainloopPipelineCpAsyncState = typename CollectiveMainloop::MainloopPipelineCpAsyncState; + + using EpiLoadPipeline = typename 
CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + // using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipeline = cute::conditional_t, + cutlass::PipelineAsync>; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + 
GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ProblemShapeGemm problem_shape_gemm{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoadTMA = 2, + EpilogueLoad = 3, + Epilogue = 4, + MainloopLoadCpAsync = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load_tma = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_load_cpasync = false; + }; + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (IsGroupedGemmKernel && sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + else if (!IsGroupedGemmKernel && sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace 
pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + args.problem_shape, + problem_shape_gemm, + CollectiveMainloop::to_underlying_arguments(problem_shape_gemm, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape_gemm, args.epilogue, epilogue_workspace), + hw_info, + scheduler + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + + if constexpr (IsGroupedGemmKernel) { + implementable &= args.mode == GemmUniversalMode::kGrouped; + implementable &= rank(typename ProblemShape::UnderlyingProblemShape{}) == 3; + } + else { + implementable &= (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShapeGemm{}) == 4); + } + + if (!implementable) { + 
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + implementable &= CollectiveMainloop::can_implement(problem_shape_gemm, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(problem_shape_gemm, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + static constexpr int MaxClusterSize = 16; + implementable &= size(ClusterShape{}) <= MaxClusterSize; + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(problem_shape_gemm, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + status = cutlass::Status::kSuccess; + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if 
(status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + auto problem_shape_MNKL = append<4>(params.problem_shape, 1); + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape_gemm, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? 
WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::MainloopLoadCpAsync) ? WarpCategory::Epilogue + : WarpCategory::MainloopLoadCpAsync; + uint32_t lane_predicate = cute::elect_one_sync(); + auto tile_shape = TileShape{}; + auto cluster_shape = ClusterShape{}; + constexpr int cluster_size = size(ClusterShape{}); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoadTMA), // main_load_tma + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopLoadCpAsync) // main_load_cpasync + }; + + // Mainloop Load pipeline (TMA) + typename MainloopPipelineTMA::Params mainloop_pipeline_tma_params; + if (WarpCategory::MainloopLoadTMA == warp_category) { + mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + 
mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Consumer; + } + + mainloop_pipeline_tma_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load_tma; + mainloop_pipeline_tma_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_tma_params.initializing_warp = 0; + MainloopPipelineTMA mainloop_pipeline_tma(shared_storage.pipelines.mainloop.tma, + mainloop_pipeline_tma_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop Load pipeline (CpAsync) + typename MainloopPipelineCpAsync::Params mainloop_pipeline_cpasync_params; + if (WarpCategory::MainloopLoadCpAsync == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Consumer; + } + + mainloop_pipeline_cpasync_params.producer_arv_count = NumMainloopCpAsyncLoadThreads; + mainloop_pipeline_cpasync_params.consumer_arv_count = 1; // Only UMMA consumes the A and B buffers + mainloop_pipeline_cpasync_params.dst_blockid = cta_rank_in_cluster; + mainloop_pipeline_cpasync_params.initializing_warp = 0; + MainloopPipelineCpAsync mainloop_pipeline_cpasync(shared_storage.pipelines.mainloop.cpasync, mainloop_pipeline_cpasync_params, cluster_shape); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = 
NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 3; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = IsSchedDynamicPersistent ? CLCPipeline::ThreadCategory::ProducerConsumer : CLCPipeline::ThreadCategory::Producer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_arv_count = 1; + + if constexpr (IsSchedDynamicPersistent) { + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads); + clc_pipeline_params.transaction_bytes = CLCResponseSize; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads; + } + + clc_pipeline_params.initializing_warp = 1; + // CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + // Now declare the pipeline outside the if constexpr + CLCPipeline clc_pipeline = [&]() { + if constexpr (IsSchedDynamicPersistent) { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + } + else { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params); + } + }(); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = 
AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + MainloopPipelineTMAState mainloop_pipe_tma_consumer_state; + MainloopPipelineTMAState mainloop_pipe_tma_producer_state = cutlass::make_producer_start_state(); + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_consumer_state; + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = 
cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); + + // + // END PROLOGUE + // + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + + if (is_participant.main_load_tma) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // bool do_load_order_arrive = is_epi_load_needed; + bool requires_clc_query = true; + + auto load_inputs = collective_mainloop.load_init_tma( + problem_shape_MNKL, shared_storage.tensors.mainloop); + auto k_tiles = cute::get<0>(load_inputs); + + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_tma( + mainloop_pipeline_tma, + mainloop_pipe_tma_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count // - k_tile_prologue + ); + mainloop_pipe_tma_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail_tma(mainloop_pipeline_tma, mainloop_pipe_tma_producer_state); + + } + + else if (is_participant.main_load_cpasync) { + auto load_inputs = collective_mainloop.load_init_cpasync( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + scheduler, work_tile_info); + Tensor gA_mkl = get<0>(load_inputs); + + do { + // Get current work tile and fetch next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + auto [mainloop_producer_state_next, unused_] = collective_mainloop.load_cpasync( + params.mainloop, + mainloop_pipeline_cpasync, + mainloop_pipe_cpasync_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count, + effective_shape + ); + mainloop_pipe_cpasync_producer_state = mainloop_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail_cpasync(mainloop_pipeline_cpasync, mainloop_pipe_cpasync_producer_state); + + } + + else if (is_participant.sched) { + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. 
+ // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else { + + cutlass::arch::wait_on_dependent_grids(); + + do { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + } + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + // bulk_tmem.data() = tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + + // Pass the acc with tuple type since the bgrad kernel change the mma_init API + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + tmem_storage, + shared_storage.tensors.mainloop); + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped 
phase + // accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + int acc_stage = accumulator_pipe_producer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto [mainloop_pipe_tma_consumer_state_next_, mainloop_pipe_cpasync_consumer_state_next_] = collective_mainloop.mma( + cute::make_tuple(mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline), + cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state), + // Pass the acc with tuple type since the bgrad kernel change the mma API + // cute::make_tuple(accumulators, accumulators), + collective_mainloop.slice_accumulator(tmem_storage, acc_stage), + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + mainloop_pipe_tma_consumer_state = mainloop_pipe_tma_consumer_state_next_; + mainloop_pipe_cpasync_consumer_state = mainloop_pipe_cpasync_consumer_state_next_; + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if 
(compute_epilogue) { + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + // bulk_tmem.data() = tmem_base_ptr; + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // Accumulator stage slice + int acc_stage = accumulator_pipe_consumer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto accumulator = get<0>(collective_mainloop.slice_accumulator(tmem_storage, acc_stage)); + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if 
(scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp index 85f87af..4ea3f02 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -210,7 +210,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -396,8 +395,7 @@ class GemmUniversal< params.hw_info); } - static constexpr - dim3 + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -409,6 +407,7 @@ class GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp index fcaae85..c82e084 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -143,9 +143,10 @@ class GemmUniversal< static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; - // Transfer registers from regular warps to Accum warps - static constexpr uint32_t GenericRegisterRequirement = 152; - static constexpr uint32_t AccumRegisterRequirement = 200; + // Register reconfiguration + static constexpr uint32_t GenericRegisterRequirement = CollectiveMainloop::GenericRegisterRequirement; + static constexpr uint32_t TransformRegisterRequirement = CollectiveMainloop::TransformRegisterRequirement; + static constexpr uint32_t AccumRegisterRequirement = CollectiveMainloop::AccumRegisterRequirement; // Pipeline and pipeline state types using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; @@ -205,7 +206,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Device side arguments struct Arguments { @@ -389,6 +389,22 @@ class GemmUniversal< return dim3(MaxThreadsPerBlock, 1, 1); } + // Register alloc/dealloc behavior might change according to the underlying collective used + template + CUTLASS_DEVICE + static constexpr void + warpgroup_reg_reconfig() { + // Compute default-allocated registers per thread: round_down((512 / NumWG), 8) + constexpr int32_t MaxWarpGroupsPerBlock = ceil_div(MaxThreadsPerBlock, NumThreadsPerWarpGroup); + constexpr int32_t NumRegsPerThread = (512 / MaxWarpGroupsPerBlock) / 8 * 8; + if constexpr (NReg < NumRegsPerThread) { + arch::warpgroup_reg_dealloc(); + } + else if constexpr (NReg > NumRegsPerThread) { + arch::warpgroup_reg_alloc(); + } + } + CUTLASS_DEVICE void operator() (Params const& params, char* smem_buf) { @@ -396,6 +412,7 @@ class GemmUniversal< using namespace cute; using X = Underscore; + 
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); @@ -637,7 +654,7 @@ class GemmUniversal< if (is_participant.main_load) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction @@ -715,7 +732,7 @@ class GemmUniversal< else if (is_participant.sched) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Signal the epilogue warps to proceed once the prologue is complete epilogue_throttle_barrier.arrive(); @@ -769,7 +786,7 @@ class GemmUniversal< else if (is_participant.transformation) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Signal the epilogue warps to proceed once the prologue is complete epilogue_throttle_barrier.arrive(); @@ -781,18 +798,19 @@ class GemmUniversal< auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - auto [load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( - load2transform_pipeline, - load2transform_pipeline_consumer_state, - transform2mma_pipeline, - transform2mma_pipeline_producer_state, - bulk_tmem, - transform_inputs, - k_tile_iter, k_tile_count - ); - transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; - load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; - + { + auto 
[load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( + load2transform_pipeline, + load2transform_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_producer_state, + bulk_tmem, + transform_inputs, + k_tile_iter, k_tile_count + ); + transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; + load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; + } // Fetch next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, @@ -811,7 +829,7 @@ class GemmUniversal< else if (is_participant.mma) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Tmem allocation sequence tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); @@ -878,7 +896,7 @@ class GemmUniversal< else if (is_participant.epi_load) { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction @@ -941,7 +959,7 @@ class GemmUniversal< else if (is_participant.epilogue) { // Register reconfiguration - arch::warpgroup_reg_alloc(); + warpgroup_reg_reconfig(); // Throttle the epilogue warps to improve prologue performance static constexpr int epilogue_throttle_phase_bit = 0; @@ -950,7 +968,9 @@ class GemmUniversal< // Wait for tmem allocation tmem_allocation_result_barrier.arrive_and_wait_unaligned(); - auto accum_inputs = collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + auto accum_inputs = [&]() { + return collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + }(); bool do_tail_store = false; do { // Fetch next work tile @@ -967,12 +987,12 @@ class 
GemmUniversal< auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); if constexpr (InputTransformType == cutlass::gemm::detail::KernelInputTransformType::FastF32) { - auto [mma2accum_pipeline_consumer_state_next,tTR_rGlobAcc] = collective_mainloop.accum( - accum_inputs, - mma2accum_pipeline, - mma2accum_pipeline_consumer_state, - k_tile_count); - + auto [mma2accum_pipeline_consumer_state_next,tTR_rGlobAcc] = + collective_mainloop.accum( + accum_inputs, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + k_tile_count); mma2accum_pipeline_consumer_state_next = scheduler.template fixup( TiledMma{}, work_tile_info, @@ -1008,39 +1028,6 @@ class GemmUniversal< // Advance the mm2accum pipe mma2accum_pipeline_consumer_state = mma2accum_pipeline_consumer_state_next; } - else if constexpr (InputTransformType == cutlass::gemm::detail::KernelInputTransformType::MixedInput) { - - mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); - - // Accumulators - Tensor accumulators = bulk_tmem(_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) - - mma2accum_pipeline_consumer_state = scheduler.template fixup( - TiledMma{}, - work_tile_info, - accumulators, - mma2accum_pipeline, - mma2accum_pipeline_consumer_state, - typename CollectiveEpilogue::CopyOpT2R{} - ); - - // - // Epilogue and write to gD - // - if (scheduler.compute_epilogue(work_tile_info)) { - auto [mma2accum_pipeline_state_next] = collective_epilogue( - mma2accum_pipeline, - mma2accum_pipeline_consumer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - accumulators, - shared_storage.tensors.epilogue - ); - // Advance the mma2accum pipe - mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; - } - } // Complex kernels use a collective epilogue else { mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); @@ -1061,7 +1048,9 @@ class GemmUniversal< // Epilogue and 
write to gD // if (scheduler.compute_epilogue(work_tile_info)) { - auto [mma2accum_pipeline_state_next] = collective_epilogue( + auto [mma2accum_pipeline_state_next, epi_load_pipe_consumer_state_next] = collective_epilogue( + epi_load_pipeline, + epi_load_pipe_consumer_state, mma2accum_pipeline, mma2accum_pipeline_consumer_state, problem_shape_MNKL, @@ -1072,6 +1061,7 @@ class GemmUniversal< ); // Advance the mm2accum pipe mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; } } @@ -1093,7 +1083,7 @@ class GemmUniversal< else { // Register reconfiguration - arch::warpgroup_reg_dealloc(); + warpgroup_reg_reconfig(); } } }; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp new file mode 100644 index 0000000..5533f10 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp @@ -0,0 +1,1090 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class 
GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool IsComplex = DispatchPolicy::InputTransformType == cutlass::gemm::detail::KernelInputTransformType::InterleavedComplexTF32; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename 
CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + // TileID scheduler + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveMainloop::NumAccumThreads; // 4 warps + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumTransformationThreads = CollectiveMainloop::NumTransformationThreads; // 4 warps + static constexpr uint32_t NumMainloopLoadBThreads = NumThreadsPerWarp; // 1 warp + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + + NumEpilogueThreads + NumTransformationThreads + NumMainloopLoadBThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr cutlass::gemm::detail::KernelInputTransformType InputTransformType = DispatchPolicy::InputTransformType; + static 
constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipeline and pipeline state types + using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; + using Load2TransformPipelineState = typename CollectiveMainloop::Load2TransformPipelineState; + + using Load2MmaPipeline = typename CollectiveMainloop::Load2MmaPipeline; + using Load2MmaPipelineState = typename CollectiveMainloop::Load2MmaPipelineState; + + using Transform2MmaPipeline = typename CollectiveMainloop::Transform2MmaPipeline; + using Transform2MmaPipelineState = typename CollectiveMainloop::Transform2MmaPipelineState; + + using Mma2AccumPipeline = typename CollectiveMainloop::Mma2AccumPipeline; + using Mma2AccumPipelineState = typename CollectiveMainloop::Mma2AccumPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using 
LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, + // Transformation starts at 256 thread alignment + Transformation = 8, + MainloopLoadB = 12, + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t main_loadA = false; + uint32_t main_loadB = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + 
uint32_t transformation = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, 
args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + 
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + auto blk_shape = CtaShape_MNK{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { 
+ + using namespace cute; + using X = Underscore; + + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for multiple epilogue and transformation warps + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::Transformation) ? WarpCategory::Epilogue + : warp_idx < static_cast(WarpCategory::MainloopLoadB) ? WarpCategory::Transformation + : WarpCategory::MainloopLoadB; + + int thread_idx = int(threadIdx.x); + int thread_idx_in_warp = thread_idx % 32; + uint32_t lane_predicate = cute::elect_one_sync(); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + bool is_first_cta_in_cluster = (cta_rank_in_cluster == 0); + bool is_mma_leader_cta = (cta_rank_in_cluster % size<0>(TiledMma{}) == 0); + // Even if this variable is unused, shape_div still performs useful compile-time checks. + [[maybe_unused]] auto mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + uint32_t mma_peer_cta_rank = has_mma_peer_cta ? 
cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && (is_first_cta_in_cluster), // sched + (warp_category == WarpCategory::MainloopLoad || warp_category == WarpCategory::MainloopLoadB), // main_load + (warp_category == WarpCategory::MainloopLoad), // main_loadA + (warp_category == WarpCategory::MainloopLoadB), // main_loadB + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::Transformation) // transformation + }; + + // MainloopLoad <--> Transformation Pipeline + typename Load2TransformPipeline::Params load2transform_pipeline_params; + if (warp_category == WarpCategory::MainloopLoad) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Transformation) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + } + load2transform_pipeline_params.is_leader = (thread_idx_in_warp == 0); + load2transform_pipeline_params.num_consumers = NumTransformationThreads; + load2transform_pipeline_params.transaction_bytes = 
CollectiveMainloop::TmaTransactionBytes_A; + load2transform_pipeline_params.initializing_warp = 0; + Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.mainloop.load2transform_pipeline, + load2transform_pipeline_params, + cluster_shape, + McastDirection::kRow, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2TransformPipelineState load2transform_pipeline_consumer_state; + Load2TransformPipelineState load2transform_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MainloopLoad <--> MMA Pipeline + typename Load2MmaPipeline::Params load2mma_pipeline_params; + if (warp_category == WarpCategory::MainloopLoadB) { + load2mma_pipeline_params.role = Load2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + load2mma_pipeline_params.role = Load2MmaPipeline::ThreadCategory::Consumer; + } + load2mma_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_loadB; + load2mma_pipeline_params.num_consumers = NumMMAThreads; + load2mma_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes_B; + load2mma_pipeline_params.initializing_warp = 8; + Load2MmaPipeline load2mma_pipeline(shared_storage.pipelines.mainloop.load2mma_pipeline, + load2mma_pipeline_params, + cluster_shape, + McastDirection::kCol, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2MmaPipelineState load2mma_pipeline_consumer_state; + Load2MmaPipelineState load2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + + // Transformation <--> MMA pipeline + typename Transform2MmaPipeline::Params transform2mma_pipeline_params; + if (warp_category == WarpCategory::Transformation) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + transform2mma_pipeline_params.role = 
Transform2MmaPipeline::ThreadCategory::Consumer; + } + transform2mma_pipeline_params.consumer_arv_count = 1; + transform2mma_pipeline_params.producer_arv_count = size(AtomThrShapeMNK{}) * NumTransformationThreads; + transform2mma_pipeline_params.initializing_warp = 2; + Transform2MmaPipeline transform2mma_pipeline(shared_storage.pipelines.mainloop.transform2mma_pipeline, + transform2mma_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Transform2MmaPipelineState transform2mma_pipeline_consumer_state; + Transform2MmaPipelineState transform2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MMA <--> Accumulator pipeline + typename Mma2AccumPipeline::Params mma2accum_pipeline_params; + if (warp_category == WarpCategory::MMA) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Epilogue) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Consumer; + } + mma2accum_pipeline_params.producer_arv_count = 1; + mma2accum_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + mma2accum_pipeline_params.initializing_warp = 6; + Mma2AccumPipeline mma2accum_pipeline(shared_storage.pipelines.mainloop.mma2accum_pipeline, + mma2accum_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Mma2AccumPipelineState mma2accum_pipeline_consumer_state; + Mma2AccumPipelineState mma2accum_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = 
EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = 1; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // CLC pipeline + // Operates Scheduling Warp <--> All Warps + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumMainloopLoadBThreads + NumEpilogueThreads + + NumMMAThreads + 
NumTransformationThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between transform, MMA, and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumTransformationThreads + NumMMAThreads + NumEpilogueThreads, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + 
tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + NumMainloopLoadBThreads + + (is_epi_load_needed ? NumEpilogueLoadThreads : 0) + + NumTransformationThreads); + } + + + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + load2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster, cutlass::McastDirection::kRow); + load2mma_pipeline.init_masks(cluster_shape, cutlass::McastDirection::kCol); + transform2mma_pipeline.init_masks(cluster_shape); + mma2accum_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // Allocate accumulators + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // Tile transform inputs now to get the k tile count + auto transform_inputs = collective_mainloop.transform_init(params.mainloop, problem_shape_MNKL, bulk_tmem, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(transform_inputs); + + // Synchronization call. Blocks wait until barriers are initialized in shared memory. 
+ pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(Load2TransformPipeline::Stages, k_tile_count); + + if(is_participant.main_loadA){ + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + } + + if (lane_predicate) { + if(is_participant.main_loadA){ + auto [load2transform_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load_A( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [load2transform_pipeline_producer_state_next_, unused_] = collective_mainloop.load_A( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + 
k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next_; + } + + if(is_participant.main_loadB){ + auto [load2mma_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load_B( + params.mainloop, + load2mma_pipeline, + load2mma_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + load2mma_pipeline_producer_state = load2mma_pipeline_producer_state_next; + + auto [load2mma_pipeline_producer_state_next_, unused_] = collective_mainloop.load_B( + params.mainloop, + load2mma_pipeline, + load2mma_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2mma_pipeline_producer_state = load2mma_pipeline_producer_state_next_; + + } + } + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + + if(is_participant.main_loadA){ + if (lane_predicate) { + load2transform_pipeline.producer_tail(load2transform_pipeline_producer_state); + } + } + if(is_participant.main_loadB){ + if (lane_predicate) { + load2mma_pipeline.producer_tail(load2mma_pipeline_producer_state); + } + } + + } + + else if (is_participant.sched) { + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. 
+ bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipeline_producer_state = scheduler.advance_to_next_work( + clc_pipeline, + clc_pipeline_producer_state + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. 
+ requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipeline_producer_state); + } + } + + else if (is_participant.transformation) { + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + auto [load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( + load2transform_pipeline, + load2transform_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_producer_state, + bulk_tmem, + transform_inputs, + k_tile_iter, k_tile_count + ); + transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; + load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + } while (work_tile_info.is_valid()); + + transform2mma_pipeline.producer_tail(transform2mma_pipeline_producer_state); + } + + else if (is_participant.mma) { + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + + auto 
mma_input_operands = collective_mainloop.mma_init(bulk_tmem, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (is_mma_leader_cta) { + auto [load2mma_pipeline_consumer_state_next, transform2mma_pipeline_consumer_state_next, mma2accum_pipeline_producer_state_next] = collective_mainloop.mma( + load2mma_pipeline, + load2mma_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_consumer_state, + mma2accum_pipeline, + mma2accum_pipeline_producer_state, + bulk_tmem, + mma_input_operands, + k_tile_count + ); + // Advance the mm2accum pipe + load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state_next; + transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state_next; + mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state_next; + } + } while (work_tile_info.is_valid()); + + // leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + mma2accum_pipeline.producer_tail(mma2accum_pipeline_producer_state); + } + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. 
+ cutlass::arch::launch_dependent_grids(); + + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. 
An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + auto accum_inputs = collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); + + // Accumulators + Tensor accumulators = bulk_tmem(_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + mma2accum_pipeline_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, mma2accum_pipeline_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + mma2accum_pipeline, + 
mma2accum_pipeline_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + do_tail_store = true; + + // Advance the mma2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + else { + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp index 180bda3..9a6b10a 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -216,7 +216,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -390,8 +389,7 @@ class GemmUniversal< params.hw_info); } - static constexpr - dim3 + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -403,6 +401,7 @@ class GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp index a5f6eb9..d87ac8f 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -216,7 +216,6 @@ class GemmUniversal< }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -456,6 +455,7 @@ class GemmUniversal< using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp index 6a1e6a8..776026f 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp index 806d902..739010c 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -621,7 +621,7 @@ class PersistentTileSchedulerSm100 { , "r"(clc_response.data[1]) , "r"(clc_response.data[2]) , "r"(clc_response.data[3])); - cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_view_shared(); #endif } diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index b5538a7..2d8728a 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -52,7 +52,7 @@ namespace cutlass::gemm::kernel::detail { // Therefore, we don't how many tiles there will be for the scheduler to hand out. 
// Hence, we have a SM90 style static group scheduler that launches the largest grid possible. // If we had access to host-side problem shapes, one could to use it to figure out the grid shape -// and thereafter use CLC query (which can then be linearized and mapped to an approriate tile coord). +// and thereafter use CLC query (which can then be linearized and mapped to an appropriate tile coord). template class PersistentTileSchedulerSm100Group { @@ -88,9 +88,7 @@ class PersistentTileSchedulerSm100Group { static_assert(cute::is_static::value); auto selected_cluster_shape = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); - auto cta_shape = cute::conditional_return>( - shape_div(tile_shape_mnk, atom_thr_shape_mnk), // Dynamic Cluster: For 2SM kernels, use CTA tile shape for the underlying scheduler - shape_div(tile_shape_mnk, selected_cluster_shape)); // Static Cluster: Blackwell builders expects TileShape to be Cluster's Tile Shape, Hopper doesn't + auto cta_shape = shape_div(tile_shape_mnk, atom_thr_shape_mnk); // For 2SM kernels, use CTA tile shape for the underlying scheduler dim3 problem_blocks = get_tiled_cta_shape_mnl( problem_shapes, @@ -118,23 +116,26 @@ class PersistentTileSchedulerSm100Group { CUTLASS_DEVICE PersistentTileSchedulerSm100Group() { } - + + // Note: constructing this tile scheduler can touch global memory that was + // written to by the prior kernel. CUTLASS_DEVICE PersistentTileSchedulerSm100Group(CLCResponse* clc_response_ptr, Params const& params) : scheduler_params(params), scheduler_sm90(params.params_sm90_, clc_response_ptr) { } - + // Note: constructing this tile scheduler can touch global memory that was + // written to by the prior kernel. 
CUTLASS_DEVICE PersistentTileSchedulerSm100Group(CLCResponse* clc_response_ptr, Params const& params, dim3 /* block_id_in_cluster */) : scheduler_params(params), scheduler_sm90(params.params_sm90_, clc_response_ptr) { } // Returns the initial work tile info that will be computed over - template + template CUTLASS_DEVICE auto - initial_work_tile_info(ClusterShape cluster_shape) { - return scheduler_sm90.initial_work_tile_info(cluster_shape); + initial_work_tile_info(ClusterShape cluster_shape, CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { + return scheduler_sm90.initial_work_tile_info(cluster_shape, callback_before_commit); } template @@ -163,9 +164,6 @@ class PersistentTileSchedulerSm100Group { // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently Arguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.params_sm90_.log_swizzle_size_; - } args.raster_order = params.params_sm90_.raster_order_ == RasterOrder::AlongN ? 
RasterOrderOptions::AlongN : RasterOrderOptions::AlongM; return Params::get_grid_shape( @@ -191,15 +189,16 @@ class PersistentTileSchedulerSm100Group { ); } - template + template CUTLASS_DEVICE auto advance_to_next_work( CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_producer_state, - uint32_t advance_count = 1) { + uint32_t advance_count = 1, + CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { - return scheduler_sm90.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, advance_count); + return scheduler_sm90.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, advance_count, callback_before_commit); } // @@ -242,6 +241,27 @@ class PersistentTileSchedulerSm100Group { void fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { } + template < + bool IsComplex, + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + fixup( + TiledMma const& , + WorkTileInfo const&, + cute::Tensor&, + AccumulatorPipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R) const { + return acc_pipe_consumer_state; + } + template static size_t get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) { @@ -285,11 +305,11 @@ class PersistentTileSchedulerSm100Group { } // Kernel helper function to get next CLC ID - template + template CUTLASS_DEVICE auto fetch_next_work( - WorkTileInfo work_tile_info, + WorkTileWithCallbackInfo work_tile_info, CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_consumer_state) { @@ -301,7 +321,7 @@ class PersistentTileSchedulerSm100Group { // Methods // [[nodiscard]] CUTLASS_DEVICE - static CLCResponse + static auto load_query_response(uint32_t smem_ptr) { return UnderlyingScheduler::load_query_response(smem_ptr); } diff --git 
a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp index ca853cd..54d81ca 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,6 +37,8 @@ #include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" #include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/conv/detail.hpp" + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel::detail { @@ -177,6 +179,44 @@ class PersistentTileSchedulerSm100StreamK { return params; } + template + static Params + to_underlying_arguments( + cutlass::conv::ConvProblemShape problem_shape, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr + ) { + + auto problem_shape_mnkl = [&] () { + // Infer im2col linearization from ConvOp and TileShape + constexpr bool is_linearized_M = (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) + && cute::depth<0>(TileShapeMNK{}) == _0{}; + constexpr bool is_linearized_K = ConvOp == conv::Operator::kWgrad && cute::depth<2>(TileShapeMNK{}) == _1{}; + if constexpr (is_linearized_M || is_linearized_K) { + // transformation + im2col linearization + return 
cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + }(); + + return to_underlying_arguments( + problem_shape_mnkl, + tile_shape_mnk, + atom_thr_shape_mnk, + cluster_shape_mnk, + hw_info, + args, + workspace + ); + } + static bool can_implement(Arguments const& args) { return UnderlyingStreamKScheduler::can_implement(args); @@ -186,7 +226,7 @@ class PersistentTileSchedulerSm100StreamK { PipelineState advance_to_next_work(Pipeline& clc_pipeline, PipelineState clc_pipe_producer_state) const { return sm100_scheduler_.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); - } + } // Given the inputs, computes the total number of output blocks this problem will compute over template @@ -728,7 +768,7 @@ class PersistentTileSchedulerSm100StreamK { auto cluster_start_linear_id = sm_count * wave_idx + cluster_idx; // Determine the offset of this CTA in the preferred cluster shape. - // This calculation aims to accomodate both cases in which this CTA is part of a preferred cluster + // This calculation aims to accommodate both cases in which this CTA is part of a preferred cluster // and those in which it is part of a fallback cluster. // // The calculation is performed by computing the starting M and N index of the preferred cluster that diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp new file mode 100644 index 0000000..7416f41 --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -0,0 +1,1330 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename 
CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using LayoutSFA = typename CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename CollectiveMainloop::LayoutSFB; + using ElementSF = typename CollectiveMainloop::ElementSF; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using 
CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = cute::conditional_t::Scheduler, + typename detail::TileSchedulerSelector< + TileSchedulerTag_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler>; + + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = 3 * NumThreadsPerWarp; // 3 warps + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopABPipeline = typename
CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; + + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename CollectiveMainloop::MainloopSFPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cute::conditional_t, + cutlass::PipelineAsync>; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cute::conditional_t, + cutlass::PipelineEmpty>; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename 
CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(8) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorMapStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + alignas(128) EpilogueTensorMapStorage epilogue; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Host-facing arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + MainloopSFLoad = 3, + Epilogue = 4, // Warps [4-8) + EpilogueLoad = 8, + Unused = 9 + }; + + struct IsParticipant { + uint32_t mma = false, + uint32_t sched = false; + uint32_t main_ab_load
= false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + constexpr uint32_t NumEpilogueSubTiles = 1; + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + ProblemShape problem_shapes = args.problem_shape; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (IsGroupedGemmKernel && sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + else if (!IsGroupedGemmKernel && sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += 
TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + scheduler, + args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, 
Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Dynamic Cluster or Preferred Cluster don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Cluster Shapes cannot be greater than 4.\n"); + return implementable; + } + } + else { + // Special cluster check for scale factor multicasts. 
Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Cluster Shapes cannot be greater than 4.\n"); + return implementable; + } + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + // Mainloop + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + constexpr uint32_t NumEpilogueSubTiles = 1; + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status 
!= Status::kSuccess) { + return status; + } + + // Mainloop + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE: cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape.get_host_problem_shape(), + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 
1); + } + +private: + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_dealloc() { + cutlass::arch::warpgroup_reg_dealloc(); + } + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_alloc() { + cutlass::arch::warpgroup_reg_alloc(); + } + +public: + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + auto problem_shape = params.problem_shape; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = (warp_idx >= static_cast(WarpCategory::Epilogue) && warp_idx < static_cast(WarpCategory::EpilogueLoad)) ? WarpCategory::Epilogue : + WarpCategory(warp_idx); + if (warp_idx > static_cast(WarpCategory::EpilogueLoad)) { + warp_category = WarpCategory::Unused; + } + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = IsSchedDynamicPersistent ? (cta_rank_in_cluster == 0) : true; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? 
cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // empty + }; + + // Mainloop Load pipeline + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; + } + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::ABTmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop SF load pipeline + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if 
(WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; + } + mainloop_sf_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_sf_load; + mainloop_sf_pipeline_params.transaction_bytes = CollectiveMainloop::SFTransactionBytes; + mainloop_sf_pipeline_params.initializing_warp = 0; + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad || warp_category == WarpCategory::MainloopSFLoad) ? 
0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads + NumMainloopSFLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = IsSchedDynamicPersistent ? + CLCPipeline::ThreadCategory::ProducerConsumer : + CLCPipeline::ThreadCategory::Producer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + + clc_pipeline_params.initializing_warp = 1; + clc_pipeline_params.producer_arv_count = 1; + + if constexpr (IsSchedDynamicPersistent) { + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + } + // Now declare the pipeline outside the if constexpr + CLCPipeline clc_pipeline = [&]() { + if constexpr (IsSchedDynamicPersistent) { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + } + else { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params); + } + }(); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + 
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if constexpr (IsSchedDynamicPersistent) { + if (WarpCategory::MainloopABLoad == warp_category || WarpCategory::MainloopSFLoad== warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopSFLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + } + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + 
[[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); + + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + int32_t sm_id = static_cast(cutlass::arch::SmId()); + + // Calculate mask after cluster barrier arrival + mainloop_ab_pipeline.init_masks(cluster_shape); + mainloop_sf_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape); + // + // TMEM "Allocation" + // + 
// ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + ThrMMA cta_mma = tiled_mma.get_slice(cta_coord_v); + auto acc_shape = partition_shape_C(tiled_mma, take<0,2>(TileShape{})); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + pipeline_init_wait(cluster_size); + + if constexpr (IsGroupedGemmKernel) { + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + + if (is_participant.main_ab_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id); + Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA 
can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize + bool did_batch_change = true; + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update_ab( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false, /* did_batch_change - prologue loads handle tensormap acquire */ + enable_prefetch ? 
k_tile_count - k_tile_prologue : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_ab_pipeline, mainloop_ab_pipe_producer_state); + + } + + else if (is_participant.sched) { + set_warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. 
An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else { + do { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + } + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.main_sf_load) { + set_warpgroup_reg_dealloc(); + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id, work_tile_info.L_idx); + + auto gA_mkl = collective_mainloop.get_mkl_shape_tensor(problem_shape_MNKL); + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. 
+ bool did_batch_change = true; + + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update_sf( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopSFPipeline::Stages/2, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); // maybe we could use ceil_div(gSFA_mkl, 2); + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + 
mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false, /* did_batch_change - prologue loads handle tensormap acquire */ + enable_prefetch ? k_tile_count - k_tile_prologue : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_sf_pipeline, mainloop_sf_pipe_producer_state); + + } + + else if (is_participant.mma) { + set_warpgroup_reg_dealloc(); + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto mma_inputs 
= collective_mainloop.mma_init(params.mainloop, + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + do { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + } + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + int stage_idx = (IsOverlappingAccum) ? (accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); + Tensor accumulator = accumulators(_,_,_, stage_idx); + + if (is_mma_leader_cta) { + auto [mainloop_ab_pipe_consumer_state_next, mainloop_sf_pipe_consumer_state_next] = collective_mainloop.mma( + cute::make_tuple(mainloop_ab_pipeline, mainloop_sf_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_ab_pipe_consumer_state, mainloop_sf_pipe_consumer_state, accumulator_pipe_producer_state), + accumulator, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_next; + mainloop_sf_pipe_consumer_state = mainloop_sf_pipe_consumer_state_next; + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + + + ++accumulator_pipe_producer_state; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. 
+ // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Fetch a copy of tensormaps for the CTA from Params + auto epi_load_tensormap = get<0>(collective_epilogue.load_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape, + curr_batch + ); + } + + bool compute_epilogue = 
TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + cute::make_tuple(epi_load_tensormap, did_batch_change), + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + set_warpgroup_reg_alloc(); + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); + bool do_tail_store = false; + // Fetch a copy of tensormaps for the CTA from Params + auto epi_store_tensormap = get<0>(collective_epilogue.store_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. 
+ bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + do { + int32_t curr_batch = work_tile_info.L_idx; + + + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape, + curr_batch + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice after making sure allocation has been performed + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + // Fusions may need problem shape for the current group + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + + // Epilogue and write to gD + // + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + collective_mainloop.slice_accumulator(accumulators, acc_stage), + shared_storage.tensors.epilogue, + cute::make_tuple(epi_store_tensormap, did_batch_change) + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + + do_tail_store |= TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, 
we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + + } + + else { + set_warpgroup_reg_dealloc(); + } + + } + + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000..455464e --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp @@ -0,0 +1,1121 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using 
LayoutSFA = typename CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename CollectiveMainloop::LayoutSFB; + using ElementSF = typename CollectiveMainloop::ElementSF; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsNoSmemEpilogue = is_same_v; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, 
SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumEpilogueLoadThreads = IsNoSmemEpilogue ? 0 : NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = IsNoSmemEpilogue ? 
0 : 3 * NumThreadsPerWarp; // 3 warp + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopABPipeline = typename CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; + + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename CollectiveMainloop::MainloopSFPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : 
cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(8) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + MainloopSFLoad = 3, + Epilogue = 4, // Warps [4-8) + EpilogueLoad = 8, 
+ Unused = 9 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_ab_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + constexpr int NumEpilogueSubTiles = 1; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + 
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster shape check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Cluster Shapes cannot be greater than 4.\n"); + return implementable; + } + } + else { + // Special cluster check for scale factor multicasts. 
Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Cluster Shapes cannot be greater than 4.\n"); + return implementable; + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr int NumEpilogueSubTiles = 1; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr int NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset 
+= TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + +private: + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_dealloc() { + if constexpr (not IsNoSmemEpilogue) { + cutlass::arch::warpgroup_reg_dealloc(); + } + } + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_alloc() { + if constexpr (not IsNoSmemEpilogue) { + cutlass::arch::warpgroup_reg_alloc(); + } + } + +public: + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = (warp_idx >= 
static_cast(WarpCategory::Epilogue) && warp_idx < static_cast(WarpCategory::EpilogueLoad)) ? WarpCategory::Epilogue : + WarpCategory(warp_idx); + if (warp_idx > static_cast(WarpCategory::EpilogueLoad)) { + warp_category = WarpCategory::Unused; + } + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load + (warp_category == WarpCategory::EpilogueLoad) && 
is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // empty + }; + + // Mainloop Load pipeline + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; + // Initialize the barrier for TMA load prefetch + + } + if (WarpCategory::MMA == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; + } + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::ABTmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop SF load pipeline + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if (WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; + } + mainloop_sf_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_sf_load; + mainloop_sf_pipeline_params.transaction_bytes = CollectiveMainloop::SFTransactionBytes; + mainloop_sf_pipeline_params.initializing_warp = 0; + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load 
pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad || warp_category == WarpCategory::MainloopSFLoad) ? 
0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads + NumMainloopSFLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category || WarpCategory::MainloopSFLoad== warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + + clc_throttle_pipeline_params.producer_arv_count = NumMainloopSFLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + 
tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); + + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_ab_pipeline.init_masks(cluster_shape); + mainloop_sf_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = 
scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + ThrMMA cta_mma = tiled_mma.get_slice(cta_coord_v); + auto acc_shape = partition_shape_C(tiled_mma, take<0,2>(TileShape{})); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + +#if 1 + pipeline_init_wait(cluster_size); + + if (is_participant.main_ab_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(load_inputs); + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + + do { + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next; + + if constexpr (not IsNoSmemEpilogue) { + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue, + enable_prefetch ? 
k_tile_count - k_tile_prologue : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_ab_pipeline, mainloop_ab_pipe_producer_state); + + } + + else if (is_participant.sched) { + set_warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. 
+ // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + + else if (is_participant.main_sf_load) { + set_warpgroup_reg_dealloc(); + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + auto tmp = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(tmp); // just to get k_tile_count or maybe we could use ceil_div(shape<3>(gSFA_mkl), 2); + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + do { + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopSFPipeline::Stages/2, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); // maybe we could use ceil_div(gSFA_mkl, 2); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next; + + if constexpr (not IsNoSmemEpilogue) { + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue, + enable_prefetch ? 
k_tile_count - k_tile_prologue :0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_sf_pipeline, mainloop_sf_pipe_producer_state); + + } + + + else if (is_participant.mma) { + set_warpgroup_reg_dealloc(); + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + int stage_idx = (IsOverlappingAccum) ? 
(accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); + Tensor accumulator = accumulators(_,_,_, stage_idx); + + if (is_mma_leader_cta) { + auto [mainloop_ab_pipe_consumer_state_next, mainloop_sf_pipe_consumer_state_next] = collective_mainloop.mma( + cute::make_tuple(mainloop_ab_pipeline, mainloop_sf_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_ab_pipe_consumer_state, mainloop_sf_pipe_consumer_state, accumulator_pipe_producer_state), + accumulator, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_next; + mainloop_sf_pipe_consumer_state = mainloop_sf_pipe_consumer_state_next; + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. 
+ cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (not IsNoSmemEpilogue and is_participant.epi_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.load( + 
epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + set_warpgroup_reg_alloc(); + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + int stage_idx = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + // Accumulator + Tensor accumulator = accumulators(_,_,_,stage_idx); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { 
+ auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + set_warpgroup_reg_dealloc(); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp index 610dfc6..5f4e193 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -120,7 +120,7 @@ class GemmUniversal< // Tensor A/B could have different buffering, with number of KBLOCK, aka TILEK, // and STAGEs. It let AsymmetricKRatio, equals KBLOCK_A / KBLOCK_B, to control // the balance of A/B loading, make sure A/B's pipeline keep same cadence - // when procude / consume data. + // when produce / consume data. // Currently, AsymmetricKRatio = {1, 2} is the only support. static constexpr bool isAsymmetric = DispatchPolicy::Schedule::isAsymmetric; static constexpr uint32_t AsymmetricKRatio = isAsymmetric ? 
2 : 1; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp index 18c7960..faa9b1c 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp new file mode 100644 index 0000000..409ecda --- /dev/null +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp @@ -0,0 +1,279 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename 
CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + static constexpr bool IsGdcEnabled = false; + + static constexpr bool is_valid_tile_scheduler = + cute::is_void_v or cute::is_same_v; +static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler."); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + // MSVC requires the cast to fix a warning-as-error. 
+ static constexpr int SharedStorageSize = static_cast(cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage))); + + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + typename ProblemShape::UnderlyingProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); + + KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count}; + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + + bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); + implementable &= TileScheduler::can_implement(args.scheduler); + 
return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + return workspace_size; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + cutlass::Status status = Status::kSuccess; + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + int batch_count = cute::size<3>(params.problem_shape); + return dim3( + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + batch_count + ); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto [M,N,K,L] = problem_shape_MNKL; + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + int thread_idx = int(threadIdx.x); + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); + auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B[l_coord]), make_shape(N,K,1), params.mainloop.dB); //(n,k,l) + + // Get batch slice + Tensor mA_mk = mA_mkl(_,_,0); // (m,k) + Tensor mB_nk = mB_nkl(_,_,0); // (n,k) + + // Slice to get the tiles this thread block is responsible for + Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + int k_tile_count = size<2>(gA); + + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + collective_mma( + accumulators, + gA, + gB, + accumulators, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + smem_buf + ); + + // Epilogue and write to 
gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + blk_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 4f5723d..3a5149d 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -48,6 +48,7 @@ #include "cutlass/trace.h" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" +#include "cutlass/arch/grid_dependency_control.h" /////////////////////////////////////////////////////////////////////////////// @@ -386,9 +387,6 @@ class GemmUniversal< get_grid_shape(Params const& params) { // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? 
TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; dim3 grid_shape; if constexpr (IsGroupedGemmKernel) { @@ -411,12 +409,14 @@ class GemmUniversal< using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -451,16 +451,6 @@ class GemmUniversal< // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - auto scheduler = [&] () { - // Group scheduler requires a different constructor that takes a response ptr - if constexpr (cute::is_same_v) { - return TileScheduler{params.scheduler, shared_storage.scheduler_response}; - } - else { - return TileScheduler{params.scheduler}; - } - } (); - // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); @@ -579,6 +569,19 @@ class GemmUniversal< // Wait for all thread blocks in the Cluster cluster_wait_fn(); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. 
+ cutlass::arch::wait_on_dependent_grids(); + + auto scheduler = [&] () { + // Group scheduler requires a different constructor that takes a response ptr + if constexpr (cute::is_same_v) { + return TileScheduler{params.scheduler, shared_storage.scheduler_response}; + } + else { + return TileScheduler{params.scheduler}; + } + } (); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); if (not work_tile_info.is_valid()) { @@ -829,8 +832,6 @@ class GemmUniversal< collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -841,8 +842,7 @@ class GemmUniversal< lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } @@ -991,6 +991,11 @@ class GemmUniversal< // Get next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index f33f468..e1fa1c8 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -48,6 +48,7 @@ #include "cutlass/trace.h" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" +#include "cutlass/arch/grid_dependency_control.h" /////////////////////////////////////////////////////////////////////////////// @@ -398,9 +399,6 @@ class GemmUniversal< get_grid_shape(Params const& params) { // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; dim3 grid_shape; if constexpr (IsGroupedGemmKernel) { @@ -423,12 +421,14 @@ class GemmUniversal< using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. 
Aborting.\n"); #else // Preconditions @@ -460,16 +460,6 @@ class GemmUniversal< // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - auto scheduler = [&] () { - // Group scheduler requires a different constructor that takes a response ptr - if constexpr (cute::is_same_v) { - return TileScheduler{params.scheduler, shared_storage.scheduler_response}; - } - else { - return TileScheduler{params.scheduler}; - } - } (); - // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); @@ -594,6 +584,19 @@ class GemmUniversal< // Wait for all thread blocks in the Cluster cluster_wait_fn(); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + + auto scheduler = [&] () { + // Group scheduler requires a different constructor that takes a response ptr + if constexpr (cute::is_same_v) { + return TileScheduler{params.scheduler, shared_storage.scheduler_response}; + } + else { + return TileScheduler{params.scheduler}; + } + } (); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); if (not work_tile_info.is_valid()) { @@ -867,8 +870,6 @@ class GemmUniversal< collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -879,8 +880,7 @@ class GemmUniversal< lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } @@ -1035,6 +1035,11 @@ class GemmUniversal< // Get next work tile auto [next_work_tile_info, increment_pipe] = 
scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 2292d7e..899ad01 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,6 +42,8 @@ #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/trace.h" +#include "cutlass/arch/grid_dependency_control.h" + #include "cute/tensor.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -204,7 +206,7 @@ class GemmUniversal< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); #else // Preconditions @@ -261,6 +263,9 @@ class GemmUniversal< auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); auto k_tile_count = size<2>(gA); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. 
+ cutlass::arch::wait_on_dependent_grids(); + // Perform the collective scoped MMA CollectiveMainloop collective_mma; collective_mma( diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 5bdaba1..52904ed 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -270,12 +270,14 @@ class GemmUniversal< using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. 
Aborting.\n"); #else enum class WarpGroupRole { diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 0b12aac..587a5f4 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -132,7 +132,21 @@ class GemmUniversal< static constexpr int RegsPerThread = size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * sizeof(ElementAccumulator) / sizeof(uint32_t); - static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; + + // Detect if this is SM120 blockscaled kernel which hits high register pressure + // on smaller tiles (e.g. 256x128 registers per thread) + template + struct IsSm120BlockScaled : cute::false_type {}; + + template + struct IsSm120BlockScaled> + : cute::true_type {}; + + static constexpr bool IsSm120Family = cute::is_same_v; + + static constexpr bool HeavyRegisterPressure = + IsSm120BlockScaled::value ? (RegsPerThread >= 128) : (RegsPerThread >= 208); + static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 
232 : 240; @@ -342,12 +356,14 @@ class GemmUniversal< using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -788,15 +804,16 @@ class GemmUniversal< // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(work_k_tile_count); } - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + if (scheduler.is_last_tile(work_tile_info)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. 
+ cutlass::arch::launch_dependent_grids(); + + } } - #endif // Index of warp group within consumer warp groups int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index fc4f5fc..1734c16 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -147,6 +147,8 @@ class GemmUniversal< static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 
232 : 240; + static constexpr bool IsSm120Family = cute::is_same_v; + // 1 stage ordered sequence between mainloop and epilogue producer load threads using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; @@ -215,7 +217,6 @@ class GemmUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); - (void) workspace; auto problem_shape = args.problem_shape; if constexpr (detail::Has_SwapAB_v) { // swap M/N @@ -354,12 +355,14 @@ class GemmUniversal< using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -798,18 +801,18 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - // It is possible to have work tiles start off invalid, - // so we have to check that first. - if (not work_tile_info.is_valid()) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. 
- cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + // It is possible to have work tiles start off invalid, + // so we have to check that first. + if (not work_tile_info.is_valid()) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); - return; + return; + } } - #endif if constexpr (IsSchedDynamicPersistent) { // Consumer0's initial tile is static. It starts consuming the 2nd tile. @@ -866,15 +869,15 @@ class GemmUniversal< // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + } } - #endif // Order two Math WG's Epilogue one after the other math_wg_order_barrier.wait(); diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index e7cafde..584178d 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,6 +42,8 @@ #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" +#include "cutlass/arch/grid_dependency_control.h" + /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -227,7 +229,7 @@ class GemmUniversal< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); #else enum class WarpGroupRole { @@ -343,6 +345,9 @@ class GemmUniversal< auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + collective_mainloop.load( mainloop_pipeline, mainloop_pipe_producer_state, diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 1d35ff2..26e1b09 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,6 +42,7 @@ #include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" +#include "cutlass/arch/grid_dependency_control.h" /////////////////////////////////////////////////////////////////////////////// @@ -265,7 +266,7 @@ class GemmUniversal< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); #else static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); @@ -386,6 +387,9 @@ class GemmUniversal< auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + collective_mainloop.load( mainloop_pipeline, mainloop_pipe_producer_state, diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index be086f0..379d4cb 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -45,6 +45,8 @@ #include "cutlass/trace.h" #include "cute/tensor.hpp" +#include "cutlass/arch/grid_dependency_control.h" + /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -271,7 +273,7 @@ class GemmUniversal< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + CUTE_INVALID_CONTROL_PATH("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); #else // Preconditions @@ -409,6 +411,10 @@ class GemmUniversal< auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + // Ensure memory ops in this kernel are not done prior to completion of dependent grids. + cutlass::arch::wait_on_dependent_grids(); + + collective_mainloop.load( mainloop_pipeline, mainloop_pipe_producer_state, diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index dd90d48..7f76663 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index 92749b1..d746ca7 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -59,6 +59,7 @@ class PersistentTileSchedulerSm90Group { uint64_t start_linear_idx = 0; uint64_t total_tiles = 0; uint64_t problem_blocks_along_raster_order = 0; + int32_t log_swizzle_size = 0; } current_group_info_; public: @@ -135,7 +136,7 @@ class PersistentTileSchedulerSm90Group { // Sink scheduler params as a member Params scheduler_params; - SchedulerResponse *response_ptr_ = nullptr; + void *response_ptr_ = nullptr; ProblemShape cached_problem_shapes_[2]; // @@ -225,6 +226,8 @@ class PersistentTileSchedulerSm90Group { for (int group = 0; group < groups; group++) { auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape))); auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape))); + if(ctas_along_m <= 0) ctas_along_m = 1; + if(ctas_along_n <= 0) ctas_along_n = 1; auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape)); auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape)); total_ctas += problem_blocks_m * problem_blocks_n; @@ -242,8 
+245,29 @@ class PersistentTileSchedulerSm90Group { return true; } + // Calculate the log of the swizzle size based on the problem CTAs and the max swizzle size + CUTLASS_DEVICE + static int32_t + get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { + int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); + if (max_swizzle_size >= 8 && min_cta_dim >= 6) { + return 3; + } + else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { + return 2; + } + else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { + return 1; + } + else { + return 0; + } + } + PersistentTileSchedulerSm90Group() = default; + // Note: constructing this tile scheduler can touch global memory that was + // written to by the prior kernel. CUTLASS_DEVICE explicit PersistentTileSchedulerSm90Group(Params const& params_, SchedulerResponse* response_ptr) : scheduler_params(params_), response_ptr_(response_ptr) { // MSVC requires protecting use of CUDA-specific nonstandard syntax, // like blockIdx and gridDim, with __CUDA_ARCH__. 
@@ -272,8 +296,9 @@ class PersistentTileSchedulerSm90Group { ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(problem_shape) + scheduler_params.divmod_cta_shape_m_.divisor - 1); ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(problem_shape) + scheduler_params.divmod_cta_shape_n_.divisor - 1); } - auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m()); - auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n()); + current_group_info_.log_swizzle_size = get_log_swizzle_size(ctas_along_m, ctas_along_n, params_.max_swizzle_size_); + auto problem_blocks_m = round_up(ctas_along_m, (1 << current_group_info_.log_swizzle_size) * params_.cluster_shape_.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << current_group_info_.log_swizzle_size) * params_.cluster_shape_.n()); current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n; current_group_info_.problem_blocks_along_raster_order = params_.raster_order_ == RasterOrder::AlongN ? 
problem_blocks_n : problem_blocks_m; @@ -298,15 +323,14 @@ class PersistentTileSchedulerSm90Group { FastDivmodU64Pow2 const& divmod_cluster_shape_minor, FastDivmodU64 const& divmod_cta_shape_m, FastDivmodU64 const& divmod_cta_shape_n, - int32_t log_swizzle_size, + int32_t max_swizzle_size, RasterOrder raster_order) { - int32_t valid_tile = 1; + uint8_t valid_tile = 1; // Use a warp to "speculatively" check if the work tile maps to the next 32 groups int lane_idx = canonical_lane_idx(); int total_problem_groups = problem_shapes.groups(); - if (linear_idx >= group_info.total_tiles + group_info.start_linear_idx) { group_info.group_idx += lane_idx; for ( ; ; group_info.group_idx += NumThreadsPerWarp) { @@ -325,11 +349,13 @@ class PersistentTileSchedulerSm90Group { ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(cached_problem_shapes[0]) + divmod_cta_shape_m.divisor - 1); ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(cached_problem_shapes[0]) + divmod_cta_shape_n.divisor - 1); } - auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.log_swizzle_size = get_log_swizzle_size(ctas_along_m, ctas_along_n, max_swizzle_size); + auto problem_blocks_m = round_up(ctas_along_m, (1 << group_info.log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << group_info.log_swizzle_size) * cluster_shape.n()); group_info.problem_blocks_along_raster_order = raster_order == RasterOrder::AlongN ? 
problem_blocks_n : problem_blocks_m; group_info.total_tiles = problem_blocks_m * problem_blocks_n; - } else { + } + else { group_info.total_tiles = INT_MAX; } @@ -351,6 +377,7 @@ class PersistentTileSchedulerSm90Group { group_info.start_linear_idx = __shfl_sync(0xffffffff, group_info.start_linear_idx, first_succeeding_thread); group_info.total_tiles = __shfl_sync(0xffffffff, group_info.total_tiles, first_succeeding_thread); group_info.problem_blocks_along_raster_order = __shfl_sync(0xffffffff, group_info.problem_blocks_along_raster_order, first_succeeding_thread); + group_info.log_swizzle_size = __shfl_sync(0xffffffff, group_info.log_swizzle_size, first_succeeding_thread); if (group_info.group_idx + lane_idx < total_problem_groups) { cached_problem_shapes[1] = problem_shapes.get_problem_shape(group_info.group_idx + lane_idx); } @@ -385,15 +412,15 @@ class PersistentTileSchedulerSm90Group { uint64_t cluster_idx_minor_div_swizzle, extra, offset; - offset = cluster_id & ((1 << log_swizzle_size) - 1); - extra = cluster_id >> log_swizzle_size; + offset = cluster_id & ((1 << group_info.log_swizzle_size) - 1); + extra = cluster_id >> group_info.log_swizzle_size; uint64_t curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(group_info.problem_blocks_along_raster_order); cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; cluster_idx_major = extra % curr_group_cluster_blk_major; - cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << group_info.log_swizzle_size) + offset; auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + cluster_minor_offset); @@ -425,26 +452,34 @@ class PersistentTileSchedulerSm90Group { scheduler_params.divmod_cluster_shape_minor_, scheduler_params.divmod_cta_shape_m_, scheduler_params.divmod_cta_shape_n_, - scheduler_params.log_swizzle_size_, + scheduler_params.max_swizzle_size_, 
scheduler_params.raster_order_); } - template + + template CUTLASS_DEVICE auto advance_to_next_work( TileSchedulerPipeline& scheduler_pipeline, TileSchedulerPipelineState scheduler_pipe_producer_state, - uint32_t advance_count = 1) { + uint32_t advance_count = 1, + CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); auto work_tile = get_current_work_for_linear_idx(current_work_linear_idx_); + using WorkTileWithCallbackInfo = decltype(callback_before_commit(work_tile)); + WorkTileWithCallbackInfo work_tile_with_callback_info = work_tile; scheduler_pipeline.producer_acquire(scheduler_pipe_producer_state); + if (work_tile_with_callback_info.is_valid()) { + work_tile_with_callback_info = callback_before_commit(work_tile); + } + if (cute::elect_one_sync()) { - response_ptr_[scheduler_pipe_producer_state.index()] = work_tile; + reinterpret_cast(response_ptr_)[scheduler_pipe_producer_state.index()] = work_tile_with_callback_info; cutlass::arch::fence_view_async_shared(); scheduler_pipeline.producer_commit(scheduler_pipe_producer_state); } - return cute::make_tuple(work_tile, true); + return cute::make_tuple(work_tile_with_callback_info, true); } // Returns whether the block assigned this work should compute the epilogue for the corresponding @@ -555,31 +590,37 @@ class PersistentTileSchedulerSm90Group { } // Kernel helper function to get next work tile - template + template CUTLASS_DEVICE auto fetch_next_work( - WorkTileInfo work_tile_info, + WorkTileWithCallbackInfo work_tile_with_callback_info, TileSchedulerPipeline& scheduler_pipeline, TileSchedulerPipelineState scheduler_pipe_consumer_state) { - if (continue_current_work(work_tile_info)) { - return cute::make_tuple(work_tile_info, true); + if (continue_current_work(work_tile_with_callback_info)) { + return cute::make_tuple(work_tile_with_callback_info, true); } 
scheduler_pipeline.consumer_wait(scheduler_pipe_consumer_state); - auto work_tile = response_ptr_[scheduler_pipe_consumer_state.index()]; + work_tile_with_callback_info = reinterpret_cast(response_ptr_)[scheduler_pipe_consumer_state.index()]; cutlass::arch::fence_view_async_shared(); scheduler_pipeline.consumer_release(scheduler_pipe_consumer_state); - return cute::make_tuple(work_tile, true); + return cute::make_tuple(work_tile_with_callback_info, true); } // Returns the initial work tile info that will be computed over - template + template CUTLASS_DEVICE auto - initial_work_tile_info(ClusterShape) { - return get_current_work_for_linear_idx(current_work_linear_idx_); + initial_work_tile_info(ClusterShape, CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo response) { return response;}) { + auto work_tile = get_current_work_for_linear_idx(current_work_linear_idx_); + using WorkTileWithCallbackInfo = decltype(callback_before_commit(work_tile)); + WorkTileWithCallbackInfo work_tile_with_callback_info = work_tile; + if (work_tile_with_callback_info.is_valid()) { + work_tile_with_callback_info = callback_before_commit(work_tile); + } + return work_tile_with_callback_info; } }; diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index a298e06..c874d63 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -246,8 +246,9 @@ class PersistentTileSchedulerSm90StreamK { static bool can_implement(Arguments const& args) { - // Split count > 1 is only valid for heuristic and split-K decomposition modes - return (args.splits == 1 || + // Split count must be positive, and > 1 is only valid for heuristic and split-K decomposition modes + return args.splits >= 1 && + (args.splits == 1 || args.decomposition_mode == DecompositionMode::Heuristic || args.decomposition_mode == DecompositionMode::SplitK); } @@ -328,13 +329,15 @@ class PersistentTileSchedulerSm90StreamK { CUTLASS_DEVICE bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const { - // Never pass this by reference; it needs a copy, + // Never pass this by reference; it needs a copy, // because continue_current_work will modify it. if (continue_current_work(work_tile_info)) { return false; } + // Create a copy to avoid unit_iter_start_ being modified + uint32_t unit_iter_start = unit_iter_start_; return not get_current_work_for_linear_idx( - unit_iter_start_, + unit_iter_start, current_work_linear_idx_ + ( uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count) ), diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h index 84102a6..d6cf5c2 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h index 0574c21..967df73 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h index a8ec1c3..9e7af48 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h @@ -1,6 +1,6 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp index f8319b1..0fb6b52 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/symm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/symm_universal.h index 29cf977..77ad73f 100755 --- a/3rd/cutlass/include/cutlass/gemm/kernel/symm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/symm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp index aa7bd0d..be96227 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -298,6 +298,66 @@ struct TileSchedulerSelector< using Scheduler = StaticPersistentTileScheduler100; }; +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// Ptr-Array kernel may provide a specialized ArrayProblemShape type +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + ProblemShape> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// SM103 Group tile scheduler +template < + class TileShape, + class ClusterShape, + uint32_t SchedulerPipelineStageCount, + class GroupProblemShape +> +struct TileSchedulerSelector< + GroupScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + GroupProblemShape + > { + using Scheduler = PersistentTileSchedulerSm100Group; +}; + +template +struct TileSchedulerSelector< + StreamKScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100StreamK< + TileShape, + ClusterShape, + SchedulerPipelineStageCount>; +}; + // Default (void) for Sm120 maps to PersistentTileSchedulerSm100 template struct TileSchedulerSelector< diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp index b1d192c..603dc79 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp +++ b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h index 9a89bf2..09fda77 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -409,7 +409,7 @@ struct PersistentTileSchedulerSm90StreamKParams { FastDivmodU64 divmod_clusters_mnl_{}; // We divide up the number of stream-K tiles amongst G groups of stream-K units. - // The stream-K units within a group collaborate to comptue over the `sk_tiles / G` + // The stream-K units within a group collaborate to compute over the `sk_tiles / G` // tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of // partial chunks computed by stream-K units -- units 0 in each group will compute identical K extents // of tiles that would be assigned in the same wave according to the rasterization order of the @@ -932,7 +932,7 @@ struct PersistentTileSchedulerSm90StreamKParams { } } - // Given decomposition mode output from heuristic, set all feilds of params. + // Given decomposition mode output from heuristic, set all fields of params. 
void set_params( DecompositionMode heuristic_mode, uint32_t groups, @@ -954,7 +954,7 @@ struct PersistentTileSchedulerSm90StreamKParams { , uint32_t ktile_start_alignment_count ) { // The highest priority when customers set as splitk mode, may set - // with a adpated splits value rather than the original splits + // with an adapted splits value rather than the original splits even it does not make sense if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) { set_params_basic( 
hw_info.max_active_clusters; - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + auto problem_blocks_m = round_up(problem_blocks.x, cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, cluster_shape.n()); int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; @@ -1803,24 +1799,6 @@ struct PersistentTileSchedulerSm90GroupParams { return launch_grid; } - CUTLASS_HOST_DEVICE - static int32_t - get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); - if (max_swizzle_size >= 8 && min_cta_dim >= 6) { - return 3; - } - else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { - return 2; - } - else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { - return 1; - } - else { - return 0; - } - } - CUTLASS_HOST_DEVICE static RasterOrder get_rasterization_order( @@ -2496,10 +2474,8 @@ struct PersistentTileSchedulerSm100GroupParams { int const sm_count = hw_info.sm_count; int const max_active_clusters = hw_info.max_active_clusters; - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + auto problem_blocks_m = round_up(problem_blocks.x, cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, cluster_shape.n()); int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; @@ -2581,12 
+2557,6 @@ struct PersistentTileSchedulerSm100GroupParams { return launch_grid; } - CUTLASS_HOST_DEVICE - static int32_t - get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - return UnderlyingSm90Params::get_log_swizzle_size(problem_ctas_m, problem_ctas_n, max_swizzle_size); - } - CUTLASS_HOST_DEVICE static RasterOrder get_rasterization_order( diff --git a/3rd/cutlass/include/cutlass/gemm/kernel/trmm_universal.h b/3rd/cutlass/include/cutlass/gemm/kernel/trmm_universal.h index 992aa48..d031ebd 100644 --- a/3rd/cutlass/include/cutlass/gemm/kernel/trmm_universal.h +++ b/3rd/cutlass/include/cutlass/gemm/kernel/trmm_universal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/thread/mma.h b/3rd/cutlass/include/cutlass/gemm/thread/mma.h index 018963b..f09902e 100644 --- a/3rd/cutlass/include/cutlass/gemm/thread/mma.h +++ b/3rd/cutlass/include/cutlass/gemm/thread/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm50.h b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm50.h index e05c56e..aad655f 100644 --- a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm50.h +++ b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm60.h b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm60.h index 64c8e03..d9ad4a4 100644 --- a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm60.h +++ b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm61.h b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm61.h index f7127ed..7247a1e 100644 --- a/3rd/cutlass/include/cutlass/gemm/thread/mma_sm61.h +++ b/3rd/cutlass/include/cutlass/gemm/thread/mma_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h index e27c582..e518863 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -94,7 +94,7 @@ template < typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. @@ -364,7 +364,7 @@ template < typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator > struct DefaultEllMma struct DefaultEllMma +struct DefaultMmaBlockwise; + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element Type for scales. 
+ typename ElementScale, + /// Layout type for scales. + typename LayoutScale, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout> +struct DefaultMmaBlockwise< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, ElementScale, LayoutScale, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, false, + SharedMemoryClear, GatherA, GatherB, PermuteALayout, PermuteBLayout> { + static_assert(platform::is_same::value || + platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, + PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, + PermuteBLayout>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistageBlockwise< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, LayoutC, + ElementScale, LayoutScale, + typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h index cab385a..ca3e266 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h +++ 
b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h index 51327c1..41b2bb5 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h index f429b52..e2396db 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -91,7 +91,7 @@ template < /// Whether problem has been transformed. This determines to which operand /// the softmax is applied. bool InternalTranspose, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h index c1e0af7..f40f84e 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -82,7 +82,7 @@ template < typename InstructionShape, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. 
diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h index 62d0c49..66b649e 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h index 8751495..e78c935 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h index f9716f3..3fb559a 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h index 4045dd2..9571138 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h b/3rd/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h index ca98212..fd6e2c9 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -87,7 +87,7 @@ template < typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. 
@@ -123,7 +123,7 @@ template < typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator > struct DefaultSparseMma struct DefaultTrmm struct DefaultTrmm struct DefaultTrmm struct DefaultTrmm + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Element Type for the scales. + typename ElementScale_, + /// Layout for the scales. 
+ typename LayoutScale_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageBlockwise : public MmaMultistage { +public: + ///< Base class + using Base = MmaMultistage; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + /// Data type of scales + using ElementScale = ElementScale_; + /// Layout Type of Scales. + using LayoutScale = LayoutScale_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // Reference to the canonical MmaMultistage specialization with identical + // template arguments. This enables us to reuse its helper structures + // (Detail and PipeState) without redefining them here. 
+ using BaseMma = MmaMultistage; + + /// Internal structure exposed for introspection (aliased from BaseMma). + using Detail = typename BaseMma::Detail; + + // Bring selected base-class helpers into scope so that calls like + // advance_smem_read_stage() resolve correctly in a dependent-name + // context where two-phase lookup would otherwise ignore the base + // class. + using Base::advance_smem_read_stage; + using Base::advance_smem_write_stage; + using Base::copy_tiles_and_advance; + using Base::prologue; + using Base::gmem_wait; + using Base::wind_down; + +private: + // Pipeline state structure reused from the canonical multistage kernel. + using PipeState = typename BaseMma::PublicPipeState; + +private: + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageBlockwise( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + { + // All per-warp iterator adjustments are handled by the base-class + // constructor, so no additional work is required here. 
+ } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + int &gemm_k_iterations, + cutlass::TensorRef scale_A, // blockwise scale tensor for A + cutlass::TensorRef scale_B, // blockwise scale tensor for B + int k_iter_idx, ///< current K-block index processed by this iteration + int block_m_idx, ///< threadblock index along M dimension (row) + int block_n_idx) ///< threadblock index along N dimension (col) + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load the next warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B_; + + // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary + if (warp_mma_k > 0) { + warp_mma_.transform(pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + } + + // Compute scale factor for the current (M-tile, N-tile, K-tile) triple. 
+ // The K-tile index used for scaling must not exceed the allocated range + // of the scale tensors. This situation can arise in the prologue / + // epilogue iterations of the multistage pipeline when the software + // pipeline executes Stages-1 extra iterations with gemm_k_iterations < 0. + + int ldA = int(scale_A.layout().stride(0)); + int k_block_idx = k_iter_idx; + if (k_block_idx >= ldA) { + k_block_idx = ldA - 1; + } + + float scale_factor = scale_A.at({block_m_idx, k_block_idx}) * + scale_B.at({block_n_idx, k_block_idx}); + + // Perform MMA into a temporary fragment (unscaled) + FragmentC delta; + FragmentC zero_frag; + zero_frag.clear(); + + warp_mma_(delta, pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], zero_frag); + + // Apply dequantization scaling + CUTLASS_PRAGMA_UNROLL + for (int el = 0; el < FragmentC::kElements; ++el) { + delta[el] *= scale_factor; + } + + // Accumulate the scaled contribution + plus plus_accum; + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_ = plus_accum(pipe_state.tmp_accum_, delta); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + accum = plus_accum(accum, delta); + } + + // Except for the last warp-tile, all warp-tiles issue their share of + // global->shared fragment copies + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + } + + // The second-to-last warp-tile also: + // - performs the last warp-tile's share of global->shared fragment + // copies + // - moves to the next global fetch stage + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Performs the last warp-tile's 
share of global->shared fragment copies + int group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // The last warp-tile also converts the shared memory fragments used by + // the first warp-tile of the next iteration, if necessary (so we can + // immediately start issuing MMA instructions at the top of the loop ) + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + } + } + } + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. 
+ CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, + cutlass::TensorRef scale_A, // blockwise scale tensor for A + cutlass::TensorRef scale_B, // blockwise scale tensor for B + int block_m_idx, + int block_n_idx) ///< [in|out] iterator over B operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + ++this->warp_tile_iterator_B_; + + // Transform, if necessary, the first warp-tile's shared memory fragments + warp_mma_.transform(pipe_state.warp_transformed_frag_A_[0], + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_A_[0], + pipe_state.warp_loaded_frag_B_[0]); + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + int k_iter_idx = 0; + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1); ++k_iter_idx) { + mac_loop_iter(pipe_state, accum, iterator_A, iterator_B, + gemm_k_iterations, scale_A, scale_B, k_iter_idx, + block_m_idx, block_n_idx); + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + 
__syncthreads(); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum, + cutlass::TensorRef scaleA, + cutlass::TensorRef scaleB) { + // Each scale element corresponds to a 128x128 tile along (M, K) for A and + // (N, K) for B. Grid dimension X enumerates threadblock tiles along M and + // grid dimension Y along N when GemmIdentityThreadblockSwizzle is used with + // the default N = 1 (tile = 1). Therefore, + // blockIdx.x -> tile index along the M dimension + // blockIdx.y -> tile index along the N dimension. + + constexpr int kScaleBlock = 128; + // Row-wise block index for A (and output C/D) – one per 128 rows. + int block_m_idx = (blockIdx.x * Shape::kM) / kScaleBlock; + + // Column-wise block index for B – one per 128 columns. Note that each + // threadblock processes Shape::kN columns, which may be < 128 (64 in this + // kernel). We therefore map two consecutive threadblock tiles onto the + // same 128-wide scale block when Shape::kN < kScaleBlock. 
+ int block_n_idx = (blockIdx.y * Shape::kN) / kScaleBlock; + + // Prologue (start fetching iterations of global fragments into shared + // memory) + prologue(iterator_A, iterator_B, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations with blockwise dequantization + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, scaleA, scaleB, + block_m_idx, block_n_idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h index 87ccc0a..957f3f6 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h index b0ba509..7da7831 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h index 2298981..fc6477f 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -115,7 +115,7 @@ class MmaPlanarComplexMultistage : ///< Policy describing tuning details using Policy = Policy_; - ///< Archtecture tag + ///< Architecture tag using ArchTag = arch::Sm80; using SmemIteratorA = SmemIteratorA_; diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h index 4458596..dadaf4f 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h index d3b84d6..a7ea9a9 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -173,7 +173,7 @@ class MmaSingleStage : public MmaBase { FragmentC &accum, ///< destination accumulator tile IteratorA iterator_A, ///< iterator over A operand in global memory IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum) { ///< source accumualtor tile + FragmentC const &src_accum) { ///< source accumulator tile // // Prologue diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h index 5174be4..ebf4b0e 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h index 9e94b0f..8ebb991 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h index 8bc23c3..9c8a2e5 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h index 2fd49a5..8bcf9c2 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h index 9495d78..fb62615 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 7141a6c..9c59283 100644 --- a/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/3rd/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -513,7 +513,7 @@ struct ThreadblockSwizzleStreamK { // - More than three peers working on an SK tile. (This occurs when the ratio of // SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks, // e.g.:[partial-block | block | block | partial-block] ). With three or - // less peers, the two non-finishing SK-blocks are not expexted to contend. 
+ // less peers, the two non-finishing SK-blocks are not expected to contend. if ((kReductionStrategy == kMixed) && (sk_waves < sm_occupancy) && (sk_blocks > 2 * sk_tiles)) diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h index 067da30..cfe758b 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h index e2cb3f2..0ba4e3a 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h index 44d7fe1..4ac5346 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h index 8c9abb8..a6da8a5 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h index 7bd8c0f..81d4790 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h index 6a90a78..50b7785 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h b/3rd/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h index f032f26..ba15756 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma.h b/3rd/cutlass/include/cutlass/gemm/warp/mma.h index cd67743..b5a83fa 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h index baaced7..6bdf8f7 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -782,7 +782,7 @@ class MmaComplexTensorOp< for (int n = 0; n < MmaIterations::kColumn; ++n) { // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements + // negating OperandB emits less instructions than negating OperandA as OperandB has less elements negate negate_op; // Real-valued accumulator part diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h index e84ae06..71f6bae 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -598,7 +598,7 @@ class MmaComplexTensorOpFastF32< for (int n = 0; n < MmaIterations::kColumn; ++n) { // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements + // negating OperandB emits less instructions than negating OperandA as OperandB has less elements negate negate_op; // Real-valued accumulator part diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h index e14450d..f799586 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h index 6728ac2..0b4a629 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h index ec99c77..809bb72 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h index 4e16ff8..8b0581e 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -427,7 +427,7 @@ class MmaMixedInputTensorOp { using TransformedFragmentA = Array; - /// Underlying arch::Mma instruction operand fragement for matrix A + /// Underlying arch::Mma instruction operand fragment for matrix A using MmaOperandA = typename ArchMmaOperator::FragmentA; /// Iterates over the B operand in Shared Memory @@ -443,7 +443,7 @@ class MmaMixedInputTensorOp { using TransformedFragmentB = Array; - /// Underlying arch::Mma instruction operand fragement for matrix B + /// Underlying arch::Mma instruction operand fragment for matrix B using MmaOperandB = typename ArchMmaOperator::FragmentB; /// Iterates over the C operand in memory @@ -454,7 +454,7 @@ class MmaMixedInputTensorOp { /// Storage for C tile using FragmentC = typename IteratorC::Fragment; - /// Underlying arch::Mma instruction operand fragement for matrix C + /// Underlying arch::Mma instruction operand fragment for matrix C using MmaOperandC = typename ArchMmaOperator::FragmentC; /// Number of mma operations performed diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h index af1031a..1f8416f 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt.h index c4152da..b438b3b 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h index 9bca234..7d0804a 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h index c522eaf..2f0f9bc 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h index 81668b4..24cae2a 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -117,7 +117,7 @@ class SparseMmaTensorOp { /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; - /// Equivalant base dense mma + /// Equivalent base dense mma using Base = MmaTensorOp; diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h index 190e92f..ae06f2c 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h index 570298b..d73accc 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index 1489694..813e5fe 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,7 +33,7 @@ \brief This defines a "fragment" iterator for visiting the fragments of a warp tile that participate in one warp-level mma operation. - Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation. + Typically, this is used to access the accumulator tile/fragment of a warp-level mma operation. 
The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into next warp-level mma operation. diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h index febd0e4..92a1100 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h index e7a4d87..b83f67c 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h index 6446b7b..11dc017 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -62,7 +62,7 @@ namespace warp { /// Tile access iterator -/// Each iteration acess in the tile is +/// Each iteration access in the tile is /// used as multiplicand for one /// warp-level matrix multiplication template < diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index dd15097..d6e4628 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index 0d1da84..e5bfda9 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h index a5370ff..641a897 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h index 97f7e14..7d09235 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h index 92e065f..bf54a7c 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h index ec44544..5a3d1ef 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/3rd/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h index d97c8f4..f21aada 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h b/3rd/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h index 2d79dcf..a578c2d 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h b/3rd/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h index 7e3af9b..48b732b 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/3rd/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h index 0406db0..1252091 100644 --- a/3rd/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h +++ b/3rd/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm_coord.h b/3rd/cutlass/include/cutlass/gemm_coord.h index dd826de..ed8aa9e 100644 --- a/3rd/cutlass/include/cutlass/gemm_coord.h +++ b/3rd/cutlass/include/cutlass/gemm_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/gemm_coord.hpp b/3rd/cutlass/include/cutlass/gemm_coord.hpp index a22b803..fddbbf0 100644 --- a/3rd/cutlass/include/cutlass/gemm_coord.hpp +++ b/3rd/cutlass/include/cutlass/gemm_coord.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/half.h b/3rd/cutlass/include/cutlass/half.h index f5fb90d..436dbaa 100644 --- a/3rd/cutlass/include/cutlass/half.h +++ b/3rd/cutlass/include/cutlass/half.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -68,7 +68,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -// Optionally target F16C extentions to accelerate half-precision conversion. +// Optionally target F16C extensions to accelerate half-precision conversion. #if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) #if defined(_MSC_VER) @@ -918,12 +918,12 @@ half_t operator--(half_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(long double x) { +cutlass::half_t operator""_hf(long double x) { return cutlass::half_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(unsigned long long int x) { +cutlass::half_t operator""_hf(unsigned long long int x) { return cutlass::half_t(int(x)); } diff --git a/3rd/cutlass/include/cutlass/integer_subbyte.h b/3rd/cutlass/include/cutlass/integer_subbyte.h index 097605c..1867071 100644 --- a/3rd/cutlass/include/cutlass/integer_subbyte.h +++ b/3rd/cutlass/include/cutlass/integer_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,14 +35,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/numeric_size.h" #include "cutlass/platform/platform.h" @@ -85,6 +84,14 @@ struct integer_subbyte { integer_subbyte(float value) : integer_subbyte(static_cast(value)) {} + CUTLASS_HOST_DEVICE + integer_subbyte(double value) + : integer_subbyte(static_cast(value)) {} + + CUTLASS_HOST_DEVICE + integer_subbyte(signed char value) + : integer_subbyte(static_cast(value)) {} + // CUTLASS code commonly converts both signed and unsigned integers // into integer_subbyte, so the class provides both explicit // conversions. @@ -115,10 +122,9 @@ struct integer_subbyte { : storage(reinterpret_cast(value) & bits_mask_) { if constexpr (Signed) { - [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); + // no need to check lower bound since input value is unsigned [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; - assert(value >= lower_bound); - assert(value <= upper_bound); + assert(value <= static_cast(upper_bound)); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -194,6 +200,9 @@ struct integer_subbyte { /////////////////////////////////////////////////////////////////////////////////////////////////// +/// 1-bit binary type +using bin1_t = bool; + /// 1-bit Unsigned integer type using uint1b_t = integer_subbyte<1, false>; @@ -209,14 +218,12 @@ using int4b_t = integer_subbyte<4, true>; /// 4-bit Unsigned integer type using uint4b_t = integer_subbyte<4, false>; +/// 6-bit integer type +using int6b_t = integer_subbyte<6, true>; /// 6-bit unsigned integer type using uint6b_t = integer_subbyte<6, false>; - -/// 1-bit binary type -using bin1_t = bool; - 
/////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/3rd/cutlass/include/cutlass/kernel_hardware_info.h b/3rd/cutlass/include/cutlass/kernel_hardware_info.h index c24e2ba..ce09b9d 100644 --- a/3rd/cutlass/include/cutlass/kernel_hardware_info.h +++ b/3rd/cutlass/include/cutlass/kernel_hardware_info.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -90,7 +90,7 @@ struct KernelHardwareInfo { int max_active_clusters = 0; #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config( - cluster_dims /* minumum grid dim */, cluster_dims, {threads_per_block, 1, 1}); + cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}); // Given the kernel function and launch configuration, return the maximum number of clusters that could co-exist on the target device. cudaError_t result = cudaOccupancyMaxActiveClusters(&max_active_clusters, kernel_ptr, &cluster_launch_config.launch_config); if (result != cudaSuccess) { diff --git a/3rd/cutlass/include/cutlass/kernel_hardware_info.hpp b/3rd/cutlass/include/cutlass/kernel_hardware_info.hpp index e1758ea..fe5a505 100644 --- a/3rd/cutlass/include/cutlass/kernel_hardware_info.hpp +++ b/3rd/cutlass/include/cutlass/kernel_hardware_info.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/kernel_launch.h b/3rd/cutlass/include/cutlass/kernel_launch.h index e92e6c1..4a5a486 100644 --- a/3rd/cutlass/include/cutlass/kernel_launch.h +++ b/3rd/cutlass/include/cutlass/kernel_launch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/layout.h b/3rd/cutlass/include/cutlass/layout/layout.h index b2e377c..b2edcb8 100644 --- a/3rd/cutlass/include/cutlass/layout/layout.h +++ b/3rd/cutlass/include/cutlass/layout/layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/matrix.h b/3rd/cutlass/include/cutlass/layout/matrix.h index 281b668..ee3290c 100644 --- a/3rd/cutlass/include/cutlass/layout/matrix.h +++ b/3rd/cutlass/include/cutlass/layout/matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/permute.h b/3rd/cutlass/include/cutlass/layout/permute.h index 32a6ee0..b311010 100644 --- a/3rd/cutlass/include/cutlass/layout/permute.h +++ b/3rd/cutlass/include/cutlass/layout/permute.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,9 +38,10 @@ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. */ #pragma once - -#include #include "cutlass/cutlass.h" +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/3rd/cutlass/include/cutlass/layout/pitch_linear.h b/3rd/cutlass/include/cutlass/layout/pitch_linear.h index 7052de1..eeafeb7 100644 --- a/3rd/cutlass/include/cutlass/layout/pitch_linear.h +++ b/3rd/cutlass/include/cutlass/layout/pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/tensor.h b/3rd/cutlass/include/cutlass/layout/tensor.h index faf6427..b270660 100644 --- a/3rd/cutlass/include/cutlass/layout/tensor.h +++ b/3rd/cutlass/include/cutlass/layout/tensor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,9 +39,8 @@ defined in cutlass/tensor_ref.h. */ #pragma once - -#include #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h index e4d25a5..48a7a41 100644 --- a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h +++ b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h index 6ca6005..3eb5414 100644 --- a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h +++ b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h index e310490..4a893e0 100644 --- a/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h +++ b/3rd/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/layout/vector.h b/3rd/cutlass/include/cutlass/layout/vector.h index 6cb74f3..2825d86 100644 --- a/3rd/cutlass/include/cutlass/layout/vector.h +++ b/3rd/cutlass/include/cutlass/layout/vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/matrix.h b/3rd/cutlass/include/cutlass/matrix.h index b46cbfe..e83e5ea 100644 --- a/3rd/cutlass/include/cutlass/matrix.h +++ b/3rd/cutlass/include/cutlass/matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -101,7 +101,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-2 matrix from scalar elements + /// Constructs a 1-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1 @@ -429,8 +429,8 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; return m; } @@ -599,7 +599,7 @@ template using Matrix1x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( Element _0_0, Element _0_1 @@ -658,7 +658,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-3 matrix from scalar elements + /// Constructs a 1-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2 @@ -1023,9 +1023,9 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; return m; } @@ -1214,7 +1214,7 @@ struct Matrix { Matrix cross(Matrix const 
&rhs) const { return Matrix( data[1] * rhs.data[2] - data[2] * rhs.data[1], - data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[2] * rhs.data[0] - data[0] * rhs.data[2], data[0] * rhs.data[1] - data[1] * rhs.data[0] ); } @@ -1226,7 +1226,7 @@ template using Matrix1x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( Element _0_0, Element _0_1, Element _0_2 @@ -1285,7 +1285,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-4 matrix from scalar elements + /// Constructs a 1-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3 @@ -1689,10 +1689,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -1905,7 +1905,7 @@ template using Matrix1x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3 @@ -1964,7 +1964,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-1 matrix from scalar elements + /// Constructs a 2-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -2306,8 +2306,8 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; return m; } @@ -2471,7 +2471,7 @@ template using Matrix2x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( Element _0_0, @@ 
-2532,7 +2532,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-2 matrix from scalar elements + /// Constructs a 2-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -2543,7 +2543,7 @@ struct Matrix { data[2] = _1_0; data[3] = _1_1; } - /// Constucts a 2-by-2 matrix from row vectors + /// Constructs a 2-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -3040,10 +3040,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -3258,7 +3258,7 @@ template using Matrix2x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( Element _0_0, Element _0_1, @@ -3319,7 +3319,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-3 matrix from scalar elements + /// Constructs a 2-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -3330,7 +3330,7 @@ struct Matrix { data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; } - /// Constucts a 2-by-3 matrix from row vectors + /// Constructs a 2-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -3912,12 +3912,12 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; return m; } @@ -4128,7 +4128,7 @@ template using Matrix2x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to 
infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( Element _0_0, Element _0_1, Element _0_2, @@ -4189,7 +4189,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-4 matrix from scalar elements + /// Constructs a 2-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -4200,7 +4200,7 @@ struct Matrix { data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; } - /// Constucts a 2-by-4 matrix from row vectors + /// Constructs a 2-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -4884,14 +4884,14 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; return m; } @@ -5134,7 +5134,7 @@ template using Matrix2x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -5195,7 +5195,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-1 matrix from scalar elements + /// Constructs a 3-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -5590,9 +5590,9 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; return m; } @@ -5768,7 +5768,7 @@ struct Matrix { Matrix cross(Matrix const &rhs) const { return Matrix( data[1] * rhs.data[2] - data[2] * 
rhs.data[1], - data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[2] * rhs.data[0] - data[0] * rhs.data[2], data[0] * rhs.data[1] - data[1] * rhs.data[0] ); } @@ -5780,7 +5780,7 @@ template using Matrix3x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( Element _0_0, @@ -5843,7 +5843,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-2 matrix from scalar elements + /// Constructs a 3-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -5856,7 +5856,7 @@ struct Matrix { data[4] = _2_0; data[5] = _2_1; } - /// Constucts a 3-by-2 matrix from row vectors + /// Constructs a 3-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -6457,12 +6457,12 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; return m; } @@ -6665,7 +6665,7 @@ template using Matrix3x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( Element _0_0, Element _0_1, @@ -6728,7 +6728,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-3 matrix from scalar elements + /// Constructs a 3-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -6741,7 +6741,7 @@ struct Matrix { data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; } - /// Constucts a 3-by-3 matrix from row vectors + /// Constructs a 3-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -7514,15 
+7514,15 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; return m; } @@ -7896,7 +7896,7 @@ template using Matrix3x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( Element _0_0, Element _0_1, Element _0_2, @@ -7959,7 +7959,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-4 matrix from scalar elements + /// Constructs a 3-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -7972,7 +7972,7 @@ struct Matrix { data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; } - /// Constucts a 3-by-4 matrix from row vectors + /// Constructs a 3-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -8905,18 +8905,18 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; 
return m; } @@ -9208,7 +9208,7 @@ template using Matrix3x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -9271,7 +9271,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-1 matrix from scalar elements + /// Constructs a 4-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -9723,10 +9723,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -9918,7 +9918,7 @@ template using Matrix4x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( Element _0_0, @@ -9983,7 +9983,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-2 matrix from scalar elements + /// Constructs a 4-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -9998,7 +9998,7 @@ struct Matrix { data[6] = _3_0; data[7] = _3_1; } - /// Constucts a 4-by-2 matrix from row vectors + /// Constructs a 4-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -10724,14 +10724,14 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; return m; } @@ 
-10958,7 +10958,7 @@ template using Matrix4x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( Element _0_0, Element _0_1, @@ -11023,7 +11023,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-3 matrix from scalar elements + /// Constructs a 4-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -11038,7 +11038,7 @@ struct Matrix { data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; } - /// Constucts a 4-by-3 matrix from row vectors + /// Constructs a 4-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -11996,18 +11996,18 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; return m; } @@ -12291,7 +12291,7 @@ template using Matrix4x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( Element _0_0, Element _0_1, Element _0_2, @@ -12356,7 +12356,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-4 matrix from scalar elements + /// Constructs a 4-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -12371,7 
+12371,7 @@ struct Matrix { data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; } - /// Constucts a 4-by-4 matrix from row vectors + /// Constructs a 4-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -13594,22 +13594,22 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; - m.data[12] = -m.data[12]; - m.data[13] = -m.data[13]; - m.data[14] = -m.data[14]; - m.data[15] = -m.data[15]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; + m.data[12] = -data[12]; + m.data[13] = -data[13]; + m.data[14] = -data[14]; + m.data[15] = -data[15]; return m; } @@ -14096,7 +14096,7 @@ template using Matrix4x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, diff --git a/3rd/cutlass/include/cutlass/matrix_coord.h b/3rd/cutlass/include/cutlass/matrix_coord.h index 85d447b..f19652c 100644 --- a/3rd/cutlass/include/cutlass/matrix_coord.h +++ b/3rd/cutlass/include/cutlass/matrix_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/matrix_shape.h b/3rd/cutlass/include/cutlass/matrix_shape.h index 20d668b..22f1aeb 100644 --- a/3rd/cutlass/include/cutlass/matrix_shape.h +++ b/3rd/cutlass/include/cutlass/matrix_shape.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/numeric_conversion.h b/3rd/cutlass/include/cutlass/numeric_conversion.h index 886dc9f..e9463b2 100644 --- a/3rd/cutlass/include/cutlass/numeric_conversion.h +++ b/3rd/cutlass/include/cutlass/numeric_conversion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -51,7 +51,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Floating-point rounding style similare to Standard Library's formats but supporting +/// Floating-point rounding style similar to Standard Library's formats but supporting /// additional rounding options. 
enum class FloatRoundStyle { round_indeterminate, ///< rounding mode unknown @@ -3678,30 +3678,537 @@ template < struct NumericArrayConverter : public PackedNumericArrayConverter {}; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Helpers for PTX conversion with inline assembly (used in the following partial specializations) +// Specifically, we use lookup tables for converting e2m1 to bf16, half and float_e4m3 and float_e5m2 +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + +#if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) && __CUDA_ARCH__ >= 1100 + #define USE_PTX_CONVERT 1 +#endif + + + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// E2M1 Conversion Helper Functions to BF16 +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + /**************************************************************************** + E2M1 to BF16 Conversion Table: + +-----------------+-------------------------------+---------------+ + | E2M1 Pattern | BF16 Pattern | Numeric Value | + | (s E E M) | (s EEEEEEEE MMMMMMM) | | + +-----------------+-------------------------------+---------------+ + | 0 00 0 | 0 00000000 0000000 | 0.0 | + | 0 00 1 | 0 01111110 0000000 | 0.5 | + | 0 01 0 | 0 01111111 0000000 | 1.0 | + | 0 01 1 | 0 01111111 1000000 | 1.5 | + | 0 10 0 | 0 10000000 0000000 | 2.0 | + | 0 10 1 | 0 10000000 1000000 | 3.0 | + | 0 11 0 | 0 10000001 0000000 | 4.0 | + | 0 11 1 | 0 10000001 1000000 | 6.0 | + +-----------------+-------------------------------+---------------+ + + bits 2-9 go into a LUT, the top 2 bits are inserted from E2M1 value (s E E M) + + +-----------------+-------------------+-----------+ + | E2M1 Pattern | LUT Value (8 bits)| Hex Value | + | (s E E M) | | | + +-----------------+-------------------+-----------+ + | 0 00 0 | 00000000 
| 0x00 | + | 0 00 1 | 11111100 | 0xFC | + | 0 01 0 | 11111110 | 0xFE | + | 0 01 1 | 11111111 | 0xFF | + | 0 10 0 | 00000000 | 0x00 | + | 0 10 1 | 00000001 | 0x01 | + | 0 11 0 | 00000010 | 0x02 | + | 0 11 1 | 00000011 | 0x03 | + +-----------------+-------------------+-----------+ + + constexpr unsigned long long E2M1_to_BF16_LUT = 0x03020100FFFEFC00ULL; + constexpr unsigned int E2M1_to_BF16_UPPER_LUT = 0xc0804000U; + + ****************************************************************************/ + + // LUT _e2m1_to_bf16_x2: Direct E2M1->BF16 (converts 2 E2M1 to 2 BF16) + CUTLASS_DEVICE + void _e2m1_to_bf16_x2(unsigned int src, unsigned int& out0) { + constexpr unsigned long long lut = 0x03020100FFFEFC00ULL; + constexpr unsigned int upper_lut = 0xc0804000U; + constexpr unsigned int lut_lo = (unsigned int)lut; + constexpr unsigned int lut_hi = (unsigned int)(lut >> 32); + + asm( + "{\n\t" + ".reg .u32 prmt_ctrl01, target01, upper01;\n\t" + ".reg .u32 target0_1_, upper_0_1;\n\t" + ".reg .u32 upper_prmt_ctrl_01;\n\t" + "and.b32 prmt_ctrl01, %1, 0x0077;\n\t" + "shr.b32 upper_prmt_ctrl_01, %1, 2;\n\t" + "and.b32 upper_prmt_ctrl_01, upper_prmt_ctrl_01, 0x0033;\n\t" + "prmt.b32 target01, %2, %3, prmt_ctrl01;\n\t" + "prmt.b32 target0_1_, target01, target01, 0x3120;\n\t" // both 2 and 3 are known to be zero + "prmt.b32 upper01, %4, %4, upper_prmt_ctrl_01;\n\t" + "prmt.b32 upper_0_1, upper01, upper01, 0x1302;\n\t" // both 2 and 3 are known to be zero + "shl.b32 target0_1_, target0_1_, 6;\n\t" + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "}" + : "=r"(out0) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %1, %2, %3, %4 + ); + } + + // LUT x4 _e2m1_to_bf16_x4: Direct E2M1->BF16 (converts 4 E2M1 to 4 BF16 in one call) + CUTLASS_DEVICE + void _e2m1_to_bf16_x4(unsigned int src, unsigned int& out0, unsigned int& out1, unsigned int shift_count=0) { + constexpr unsigned long long lut = 0x03020100FFFEFC00ULL; // bit2 - 9 for e2m1 -> bf16 conversion + constexpr unsigned int 
upper_lut = 0xc0804000U; // bit0, bot1 for e2m1 -> bf16 conversion + constexpr unsigned int lut_lo = (unsigned int)lut; + constexpr unsigned int lut_hi = (unsigned int)(lut >> 32); + + + if (shift_count == 0) { + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 target0_1_, target2_3_, upper_0_1, upper_2_3;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "" + "and.b32 prmt_ctrl0123, %2, 0x7777;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %2, 2;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "prmt.b32 target0123, %3, %4, prmt_ctrl0123;\n\t" + "prmt.b32 target0_1_, target0123, %4, 0x4140;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 target2_3_, target0123, %4, 0x4342;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper0123, %5, %5, upper_prmt_ctrl_0123;\n\t" + "prmt.b32 upper_0_1, upper0123, %4, 0x1404;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper_2_3, upper0123, %4, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "shl.b32 target0_1_, target0_1_, 6;\n\t" + "shl.b32 target2_3_, target2_3_, 6;\n\t" + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "or.b32 %1, target2_3_, upper_2_3;\n\t" + "}" + : "=r"(out0), "=r"(out1) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %2, %3, %4, %5 + ); + } else if (shift_count == 16) { + // protect against future changes to the shift_count + assert(shift_count==16 && "shift_count should be 0 or 16"); + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 target0_1_, target2_3_, upper_0_1, upper_2_3;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "" + "shr.b32 prmt_ctrl0123, %2, 16;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %2, 18;\n\t" + "and.b32 prmt_ctrl0123, prmt_ctrl0123, 0x7777;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "prmt.b32 target0123, %3, %4, prmt_ctrl0123;\n\t" + "prmt.b32 target0_1_, target0123, %4, 
0x4140;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 target2_3_, target0123, %4, 0x4342;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper0123, %5, %5, upper_prmt_ctrl_0123;\n\t" + "prmt.b32 upper_0_1, upper0123, %4, 0x1404;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper_2_3, upper0123, %4, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "shl.b32 target0_1_, target0_1_, 6;\n\t" + "shl.b32 target2_3_, target2_3_, 6;\n\t" + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "or.b32 %1, target2_3_, upper_2_3;\n\t" + "}" + : "=r"(out0), "=r"(out1) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %2, %3, %4, %5 + ); + } else { + assert((shift_count==0 || shift_count==16) && "shift_count should be 0 or 16"); + } + } + + // LUT x8 _e2m1_to_bf16_x8: Direct E2M1->BF16 (converts 8 E2M1 to 8 BF16 in one call) + CUTLASS_DEVICE + void _e2m1_to_bf16_x8(unsigned int src, unsigned int& out0, unsigned int& out1, unsigned int& out2, unsigned int& out3) { + _e2m1_to_bf16_x4(src, out0, out1, 0); + _e2m1_to_bf16_x4(src, out2, out3, 16); + } + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// E2M1 Conversion Helper Functions to FP16 +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + + /**************************************************************************** + E2M1 to FP16 Conversion Table: + +-----------------+-----------------------------------+---------------+---------------+ + | E2M1 Pattern | FP16 Pattern | MS Byte (Hex) | Numeric Value | + | (s E E M) | (s EEEEE MMMMMMMMMM) | (s EEEEE MM) | | + +-----------------+-----------------------------------+---------------+---------------+ + | 0 00 0 | 0 00000 0000000000 | 0x00 | 0.0 | + | 0 00 1 | 0 01110 0000000000 | 0x38 | 0.5 | + | 0 01 0 | 0 01111 0000000000 | 0x3C | 1.0 | + | 0 01 1 | 0 01111 1000000000 | 0x3E | 1.5 
| + | 0 10 0 | 0 10000 0000000000 | 0x40 | 2.0 | + | 0 10 1 | 0 10000 1000000000 | 0x42 | 3.0 | + | 0 11 0 | 0 10001 0000000000 | 0x44 | 4.0 | + | 0 11 1 | 0 10001 1000000000 | 0x46 | 6.0 | + +-----------------+-----------------------------------+---------------+---------------+ + + constexpr unsigned long long E2M1_to_FP16_LUT = 0x464442403E3C3800ULL; + constexpr unsigned int E2M1_to_FP16_UPPER_LUT = 0xc0804000U; + + ****************************************************************************/ + + // LUT _e2m1_to_half_x2: Direct E2M1->FP16 (converts 2 E2M1 to 2 FP16) + CUTLASS_DEVICE + void _e2m1_to_half_x2(unsigned int src, unsigned int& out0) { + constexpr unsigned long long lut = 0x464442403E3C3800ULL; + constexpr unsigned int upper_lut = 0xc0804000U; + constexpr unsigned int lut_lo = (unsigned int)lut; + constexpr unsigned int lut_hi = (unsigned int)(lut >> 32); + + asm( + "{\n\t" + ".reg .u32 prmt_ctrl01, target01, upper01;\n\t" + ".reg .u32 target0_1_, upper_0_1;\n\t" + ".reg .u32 upper_prmt_ctrl_01;\n\t" + "and.b32 prmt_ctrl01, %1, 0x0077;\n\t" + "shr.b32 upper_prmt_ctrl_01, %1, 2;\n\t" + "and.b32 upper_prmt_ctrl_01, upper_prmt_ctrl_01, 0x0033;\n\t" + "prmt.b32 target01, %2, %3, prmt_ctrl01;\n\t" + "prmt.b32 target0_1_, target01, target01, 0x1302;\n\t" // both 2 and 3 are known to be zero + "prmt.b32 upper01, %4, %4, upper_prmt_ctrl_01;\n\t" + "prmt.b32 upper_0_1, upper01, upper01, 0x1302;\n\t" // both 2 and 3 are known to be zero + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "}" + : "=r"(out0) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %2, %3, %4, %5 + ); + } + + // LUT x4 _e2m1_to_half_x4: Direct E2M1->FP16 (converts 4 E2M1 to 4 FP16 in one call) + CUTLASS_DEVICE + void _e2m1_to_half_x4(unsigned int src, unsigned int& out0, unsigned int& out1, unsigned int shift_count=0) { + constexpr unsigned long long lut = 0x464442403E3C3800ULL; // bit0 - 7 for e2m1 -> half conversion + constexpr unsigned int upper_lut = 0xc0804000U; // bit0, bit1 for e2m1 
-> FP16 conversion + constexpr unsigned int lut_lo = (unsigned int)lut; + constexpr unsigned int lut_hi = (unsigned int)(lut >> 32); + + assert((shift_count==0 || shift_count==16) && "shift_count should be 0 or 16"); + + if (shift_count == 0) { + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 target0_1_, target2_3_, upper_0_1, upper_2_3;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "" + "and.b32 prmt_ctrl0123, %2, 0x7777;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %2, 2;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "prmt.b32 target0123, %3, %4, prmt_ctrl0123;\n\t" + "prmt.b32 target0_1_, target0123, %5, 0x1404;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 target2_3_, target0123, %5, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper0123, %5, %5, upper_prmt_ctrl_0123;\n\t" + "prmt.b32 upper_0_1, upper0123, %5, 0x1404;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper_2_3, upper0123, %5, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "or.b32 %1, target2_3_, upper_2_3;\n\t" + "}" + : "=r"(out0), "=r"(out1) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %2, %3, %4, %5 + ); + } else /* shift_count == 16 */ { + // protect against future changes to the shift_count + assert(shift_count==16 && "shift_count should be 0 or 16"); + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 target0_1_, target2_3_, upper_0_1, upper_2_3;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "" + "shr.b32 prmt_ctrl0123, %2, 16;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %2, 18;\n\t" + "and.b32 prmt_ctrl0123, prmt_ctrl0123, 0x7777;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "prmt.b32 target0123, %3, %4, prmt_ctrl0123;\n\t" + "prmt.b32 target0_1_, target0123, %5, 0x1404;\n\t" // 4 is the low order byte of upper_lut 
which is 0x00 + "prmt.b32 target2_3_, target0123, %5, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper0123, %5, %5, upper_prmt_ctrl_0123;\n\t" + "prmt.b32 upper_0_1, upper0123, %5, 0x1404;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "prmt.b32 upper_2_3, upper0123, %5, 0x3424;\n\t" // 4 is the low order byte of upper_lut which is 0x00 + "or.b32 %0, target0_1_, upper_0_1;\n\t" + "or.b32 %1, target2_3_, upper_2_3;\n\t" + "}" + : "=r"(out0), "=r"(out1) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) // %2, %3, %4, %5 + ); + } + } + + // LUT x8 _e2m1_to_half_x8: Direct E2M1->FP16 (converts 8 E2M1 to 8 FP16 in one call) + CUTLASS_DEVICE + void _e2m1_to_half_x8(unsigned int src, unsigned int& out0, unsigned int& out1, unsigned int& out2, unsigned int& out3) { + _e2m1_to_half_x4(src, out0, out1, 0); + _e2m1_to_half_x4(src, out2, out3, 16); + } + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// E2M1 -> FP8 (E4M3/E5M2) Shared Helper Functions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + /**************************************************************************** + E2M1 to E4M3 (FP8) Conversion Table + constexpr unsigned long long E2M1_to_E4M3_LUT_0_7 = 0x4C4844403C383000ULL; + + +-----------------+--------------------------+---------------+-------------+ + | E2M1 Pattern | E4M3 Pattern | Numeric Value | E4M3 Hex | + | (s E E M) | (s EEEE MMM) | | | + +-----------------+--------------------------+---------------+-------------+ + | 0 00 0 | 0 0000 000 | 0.0 | 0x00 | + | 0 00 1 | 0 0110 000 | 0.5 | 0x30 | + | 0 01 0 | 0 0111 000 | 1.0 | 0x38 | + | 0 01 1 | 0 0111 100 | 1.5 | 0x3C | + | 0 10 0 | 0 1000 000 | 2.0 | 0x40 | + | 0 10 1 | 0 1000 100 | 3.0 | 0x44 | + | 0 11 0 | 0 1001 000 | 4.0 | 0x48 | + | 0 11 1 | 0 1001 100 | 6.0 | 0x4C | + 
+-----------------+--------------------------+---------------+-------------+ + + E2M1 to E5M2 (FP8) Conversion Table + constexpr unsigned long long E2M1_to_E5M2_LUT_0_7 = 0x464442403E3C3800ULL; + + +-----------------+--------------------------+---------------+-------------+ + | E2M1 Pattern | E5M2 Pattern | Numeric Value | E5M2 Hex | + | (s E E M) | (s EEEEE MM) | | | + +-----------------+--------------------------+---------------+-------------+ + | 0 00 0 | 0 00000 00 | 0.0 | 0x00 | + | 0 00 1 | 0 01110 00 | 0.5 | 0x38 | + | 0 01 0 | 0 01111 00 | 1.0 | 0x3C | + | 0 01 1 | 0 01111 10 | 1.5 | 0x3E | + | 0 10 0 | 0 10000 00 | 2.0 | 0x40 | + | 0 10 1 | 0 10000 10 | 3.0 | 0x42 | + | 0 11 0 | 0 10001 00 | 4.0 | 0x44 | + | 0 11 1 | 0 10001 10 | 6.0 | 0x46 | + +-----------------+--------------------------+---------------+-------------+ + ****************************************************************************/ + + // LUT _e2m1_to_fp8_x4: Direct E2M1->FP8 (converts 4 E2M1 to 4 FP8) + template + CUTLASS_DEVICE + static void _e2m1_to_fp8_x4(unsigned int src, unsigned int& out0) { + constexpr unsigned int upper_lut = 0xc0804000U; + constexpr unsigned int lut_lo = (unsigned int)fp8_lut; + constexpr unsigned int lut_hi = (unsigned int)(fp8_lut >> 32); + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "and.b32 prmt_ctrl0123, %1, 0x7777;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %1, 2;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "prmt.b32 target0123, %2, %3, prmt_ctrl0123;\n\t" + "prmt.b32 upper0123, %4, %4, upper_prmt_ctrl_0123;\n\t" + "or.b32 %0, target0123, upper0123;\n\t" + "}" + : "=r"(out0) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) + ); + } + + // LUT _e2m1_to_fp8_x2: Direct E2M1->FP8 (converts 2 E2M1 to 2 FP8) + template + CUTLASS_DEVICE + static void _e2m1_to_fp8_x2(unsigned int src, unsigned int& out0) { + constexpr unsigned int upper_lut = 0xc0804000U; + 
constexpr unsigned int lut_lo = (unsigned int)fp8_lut; + constexpr unsigned int lut_hi = (unsigned int)(fp8_lut >> 32); + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + "and.b32 prmt_ctrl0123, %1, 0x0077;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %1, 2;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x0033;\n\t" + "prmt.b32 target0123, %2, %3, prmt_ctrl0123;\n\t" + "prmt.b32 upper0123, %4, %4, upper_prmt_ctrl_0123;\n\t" + "or.b32 %0, target0123, upper0123;\n\t" // we wiped out the upper half when computing controls + "}" + : "=r"(out0) + : "r"(src), "r"(lut_lo), "r"(lut_hi), "r"(upper_lut) + ); + } + + // LUT _e2m1_to_fp8_x8: Direct E2M1->FP8 (converts 8 E2M1 to 8 FP8) + template + CUTLASS_DEVICE + static void _e2m1_to_fp8_x8(unsigned int src, unsigned int& out0, unsigned int& out1) { + constexpr unsigned int upper_lut = 0xc0804000U; + constexpr unsigned int lut_lo = (unsigned int)fp8_lut; + constexpr unsigned int lut_hi = (unsigned int)(fp8_lut >> 32); + asm( + "{\n\t" + ".reg .u32 prmt_ctrl0123, target0123, upper0123;\n\t" + ".reg .u32 upper_prmt_ctrl_0123;\n\t" + ".reg .u32 prmt_ctrl4567, target4567, upper4567;\n\t" + ".reg .u32 upper_prmt_ctrl_4567;\n\t" + "and.b32 prmt_ctrl0123, %2, 0x7777;\n\t" + "shr.b32 upper_prmt_ctrl_0123, %2, 2;\n\t" + "shr.b32 prmt_ctrl4567, %2, 16;\n\t" + "shr.b32 upper_prmt_ctrl_4567, %2, 18;\n\t" + "and.b32 prmt_ctrl4567, prmt_ctrl4567, 0x7777;\n\t" + "and.b32 upper_prmt_ctrl_0123, upper_prmt_ctrl_0123, 0x3333;\n\t" + "and.b32 upper_prmt_ctrl_4567, upper_prmt_ctrl_4567, 0x3333;\n\t" + "prmt.b32 target0123, %3, %4, prmt_ctrl0123;\n\t" + "prmt.b32 target4567, %3, %4, prmt_ctrl4567;\n\t" + "prmt.b32 upper0123, %5, %5, upper_prmt_ctrl_0123;\n\t" + "prmt.b32 upper4567, %5, %5, upper_prmt_ctrl_4567;\n\t" + "or.b32 %0, target0123, upper0123;\n\t" + "or.b32 %1, target4567, upper4567;\n\t" + "}" + : "=r"(out0), "=r"(out1) + : "r"(src), "r"(lut_lo), "r"(lut_hi), 
"r"(upper_lut) + ); + } + +} // namespace detail + +namespace detail { + + /* + A helper class that can vectorize a numeric converter with implementation for several vector widths. + + The vector widths must be giving in decreasing order or width, and must be a power of 2. + + The vector converters must produce identical results to the scalar converters for consistency. + */ + class VectorizedConverter { + private: + // Base case to handle remainder elements as scalars. + template + CUTLASS_DEVICE + static void convert_helper( + typename ArrayConverter::result_type& result, + typename ArrayConverter::source_type const& source) { + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + // If no more converters, handle the remaining elements as scalars. + constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int remainder = total_elements - Offset; + static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); + + typename ArrayConverter::ScalarConverter scalar_converter; + CUTLASS_PRAGMA_UNROLL + for (int i = Offset; i < ArrayConverter::result_type::kElements; ++i) { + result[i] = scalar_converter(ElementSrc(source[i])); + } + } + + template + CUTLASS_DEVICE + static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); + static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); + static_assert(cutlass::platform::is_same::value, + "ResultVectorArray must have the same type ArrayConverter::result_type"); + static_assert(cutlass::platform::is_same::value, + "SourceVectorArray must have the same type ArrayConverter::result_type"); + static_assert(Offset >= 0 && Offset <= 
ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); + + static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); + + constexpr int vector_width = ResultVectorArray::kElements; + static_assert(ispow2(vector_width), "Vector width must be a power of 2"); + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + + constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; + constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; + + static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); + static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); + + constexpr int vector_offset = Offset / vector_width; + ResultVectorArray* packed_result_vec = reinterpret_cast(&result) + vector_offset; + SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; + + // Convert the remaining elements as vectors. + constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int groups_of_vec = (total_elements - Offset) / vector_width; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < groups_of_vec; ++i) { + packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); + } + + constexpr int new_offset = Offset + vector_width * groups_of_vec; + // Recurse to handle other vector converters, or the scalar base case. + convert_helper(result, source); + } + + public: + /* + A method to convert vectors of elements using the packed_convert method of the converter. + + Converters using this class must implement packed convert and support 1 or more vector conversions. 
+ */ + template + CUTLASS_DEVICE + static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); + } + }; +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for Array <=> Array // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < + int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverter { using result_element = float; using source_element = cutlass::float_e2m1_t; - - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + // Convert 8 elements: E2M1 -> BF16 -> FP32 CUTLASS_DEVICE - static result_type convert(source_type const & source) { + static result_type_packed_8 lut_convert(source_type_packed_8 const &source) { + result_type_packed_8 result; + uint32_t src_packed = *reinterpret_cast(source.data()); - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out_fp16[4]; - uint32_t const& src_packed = reinterpret_cast(source); + #if defined(USE_PTX_CONVERT) + + float* out_float = reinterpret_cast(result.data()); + uint32_t halfx2_01, halfx2_23, halfx2_45, halfx2_67; asm volatile( \ "{\n" \ @@ -3711,29 +4218,157 @@ struct NumericArrayConverter { "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ "cvt.rn.f16x2.e2m1x2 
%3, byte3;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); + "}\n" + : "=r"(halfx2_01), "=r"(halfx2_23) , "=r"(halfx2_45), "=r"(halfx2_67): "r"(src_packed) + ); + + float2 floatx2_01 = __half22float2(reinterpret_cast<__half2 &>(halfx2_01)); + float2 floatx2_23 = __half22float2(reinterpret_cast<__half2 &>(halfx2_23)); + float2 floatx2_45 = __half22float2(reinterpret_cast<__half2 &>(halfx2_45)); + float2 floatx2_67 = __half22float2(reinterpret_cast<__half2 &>(halfx2_67)); + + out_float[0] = floatx2_01.x; + out_float[1] = floatx2_01.y; + out_float[2] = floatx2_23.x; + out_float[3] = floatx2_23.y; + out_float[4] = floatx2_45.x; + out_float[5] = floatx2_45.y; + out_float[6] = floatx2_67.x; + out_float[7] = floatx2_67.y; - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - float2 res2 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[2])); - float2 res3 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[3])); + #else + + // Convert all 8 E2M1 values to BF16, then extend to FP32 + uint4* out_floatx4 = reinterpret_cast(result.data()); + unsigned int bfloatx2_01, bfloatx2_23, bfloatx2_45, bfloatx2_67; + detail::_e2m1_to_bf16_x8(src_packed, bfloatx2_01, bfloatx2_23, bfloatx2_45, bfloatx2_67); + + out_floatx4[0] = make_uint4(bfloatx2_01 << 16, bfloatx2_01 & 0xFFFF0000u, + bfloatx2_23 << 16, bfloatx2_23 & 0xFFFF0000u); + out_floatx4[1] = make_uint4(bfloatx2_45 << 16, bfloatx2_45 & 0xFFFF0000u, + bfloatx2_67 << 16, bfloatx2_67 & 0xFFFF0000u); + + #endif + + return result; + } + + // Convert 4 elements: E2M1 -> BF16 -> FP32 + CUTLASS_DEVICE + static result_type_packed_4 lut_convert(source_type_packed_4 const &source) { + result_type_packed_4 result; + uint16_t src_packed = *reinterpret_cast(source.data()); + + #if defined(USE_PTX_CONVERT) + + float* out_float = reinterpret_cast(result.data()); + uint32_t halfx2_01, 
halfx2_23; + + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %2;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "}\n" : "=r"(halfx2_01), "=r"(halfx2_23): "h"(src_packed) + ); + + float2 floatx2_01 = __half22float2(reinterpret_cast<__half2 &>(halfx2_01)); + float2 floatx2_23 = __half22float2(reinterpret_cast<__half2 &>(halfx2_23)); + + out_float[0] = floatx2_01.x; + out_float[1] = floatx2_01.y; + out_float[2] = floatx2_23.x; + out_float[3] = floatx2_23.y; - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - out[4] = res2.x; - out[5] = res2.y; - out[6] = res3.x; - out[7] = res3.y; - return out; #else + + uint4* out_floatx4 = reinterpret_cast(result.data()); + unsigned int bfloatx2_01, bfloatx2_23; + detail::_e2m1_to_bf16_x4(src_packed, bfloatx2_01, bfloatx2_23); + + out_floatx4[0] = make_uint4(bfloatx2_01 << 16, bfloatx2_01 & 0xFFFF0000u, + bfloatx2_23 << 16, bfloatx2_23 & 0xFFFF0000u); + + #endif + + return result; + } + + // Convert 2 elements: E2M1 -> BF16 -> FP32 + CUTLASS_DEVICE + static result_type_packed_2 lut_convert(source_type_packed_2 const &source) { + result_type_packed_2 result; + uint8_t src_packed = *reinterpret_cast(source.data()); + + #if defined(USE_PTX_CONVERT) + + float* out_float = reinterpret_cast(result.data()); + uint32_t halfx2_01; + + // extend input to 16 bit since there is no 8 bit register asm constraint available for ptx inline assembly + uint16_t src_packed_16 = src_packed; + + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %1;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "}\n" : "=r"(halfx2_01): "h"(src_packed_16) + ); + + float2 floatx2_01 = __half22float2(reinterpret_cast<__half2 &>(halfx2_01)); + + out_float[0] = floatx2_01.x; + out_float[1] = floatx2_01.y; + + #else + + uint2* out_floatx2 = reinterpret_cast(result.data()); + unsigned int bfloatx2_01; + 
detail::_e2m1_to_bf16_x2(src_packed, bfloatx2_01); + + out_floatx2[0] = make_uint2(bfloatx2_01 << 16, bfloatx2_01 & 0xFFFF0000u); + + #endif + + return result; + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8."); + + return lut_convert(source); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_HOST_DEVICE + static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) result_type result; - NumericConverter converter; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + #else + result_type result; + ScalarConverter converter; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < N; ++i) { result[i] = converter(source[i]); } @@ -3747,34 +4382,28 @@ struct NumericArrayConverter { } }; -/// Partial specialization for Array <= Array +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array template < int N, FloatRoundStyle Round > -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = 
reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; + // Delegate to float converter, then reinterpret bits + NumericArrayConverter e2m1_to_fp32; + Array fp32_result = e2m1_to_fp32(source); + return reinterpret_cast(fp32_result); } CUTLASS_HOST_DEVICE @@ -3901,16 +4530,16 @@ struct NumericArrayConverter { static result_type convert(source_type const & source) { #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - unsigned out; + uint16_t out; asm volatile( \ "{\n" \ ".reg .b8 byte0;\n" \ ".reg .b8 byte1;\n" \ "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ - "mov.b32 %0, {byte0, byte1, 0, 0};\n" \ + "mov.b16 %0, {byte0, byte1};\n" \ "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + : "=h"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); return reinterpret_cast(out); #else @@ -4287,164 +4916,246 @@ struct NumericArrayConverter { #endif // Conditional guards to enable partial specialization for packed integers -namespace detail { - /* - A helper class that can vectorize a numeric converter with implementation for several vector widths. - The vector widths must be giving in decreasing order or width, and must be a power of 2. +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round, + int N +> +struct NumericArrayConverter { + using result_element = cutlass::half_t; + using source_element = cutlass::float_e2m1_t; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - The vector converters must produce identical results to the scalar converters for consistency. 
- */ - class VectorizedConverter { - private: - // Base case to handle remainder elements as scalars. - template - CUTLASS_DEVICE - static void convert_helper( - typename ArrayConverter::result_type& result, - typename ArrayConverter::source_type const& source) { +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; - // If no more converters, handle the remaining elements as scalars. - constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int remainder = total_elements - Offset; - static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); + using ScalarConverter = NumericConverter; - typename ArrayConverter::ScalarConverter scalar_converter; - CUTLASS_PRAGMA_UNROLL - for (int i = Offset; i < ArrayConverter::result_type::kElements; ++i) { - result[i] = scalar_converter(ElementSrc(source[i])); - } - } - template - CUTLASS_DEVICE - static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); - static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); - static_assert(cutlass::platform::is_same::value, - "ResultVectorArray must have the same type ArrayConverter::result_type"); - static_assert(cutlass::platform::is_same::value, - "SourceVectorArray must have the same type ArrayConverter::result_type"); - static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); + #if defined(USE_PTX_CONVERT) // 
defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + CUTLASS_DEVICE + static result_type_packed_8 ptx_convert(source_type_packed_8 const &source) { + result_type_packed_8 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint32_t const& src_packed = reinterpret_cast(source); - static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1, byte2, byte3;\n" \ + "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ + "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ + "}\n" : "=r"(out_halfx2[0]), "=r"(out_halfx2[1]) , "=r"(out_halfx2[2]), "=r"(out_halfx2[3]): "r"(src_packed)); - constexpr int vector_width = ResultVectorArray::kElements; - static_assert(ispow2(vector_width), "Vector width must be a power of 2"); + return out; + } - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; + CUTLASS_DEVICE + static result_type_packed_4 ptx_convert(source_type_packed_4 const &source) { + result_type_packed_4 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint16_t const& src_packed = reinterpret_cast(source); - constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; - constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %2;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "}\n" : "=r"(out_halfx2[0]), "=r"(out_halfx2[1]) : "h"(src_packed)); - static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); - static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); + return out; + } - constexpr int vector_offset = Offset / vector_width; - ResultVectorArray* 
packed_result_vec = reinterpret_cast(&result) + vector_offset; - SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; + CUTLASS_DEVICE + static result_type_packed_2 ptx_convert(source_type_packed_2 const &source) { + result_type_packed_2 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint16_t const& src_packed = static_cast(reinterpret_cast(source)); - // Convert the remaining elements as vectors. - constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int groups_of_vec = (total_elements - Offset) / vector_width; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < groups_of_vec; ++i) { - packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); - } + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %1;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "}\n" : "=r"(out_halfx2[0]) : "h"(src_packed)); - constexpr int new_offset = Offset + vector_width * groups_of_vec; - // Recurse to handle other vector converters, or the scalar base case. - convert_helper(result, source); - } + return out; + } - public: - /* - A method to convert vectors of elements using the packed_convert method of the converter. +#endif - Converters using this class must implement packed convert and support 1 or more vector conversions. 
- */ - template - CUTLASS_DEVICE - static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); + CUTLASS_DEVICE + static result_type_packed_8 lut_convert(source_type_packed_8 const &source) { + result_type_packed_8 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint32_t const& src_packed = reinterpret_cast(source); + + // LUT x4 e2m1 to fp16 (SM90) + // CUTLASS_PRAGMA_UNROLL + // for (int i = 0; i < 2; i++) { + // unsigned int lane0123 = src_packed >> (16 * i); + // detail::_e2m1_to_half_x4(lane0123, out_halfx2[2*i], out_halfx2[2*i+1]); + //} + + // LUT x8 e2m1 to fp16 (SM90) + detail::_e2m1_to_half_x8(src_packed, out_halfx2[0], out_halfx2[1], out_halfx2[2], out_halfx2[3]); + + return out; + } + + CUTLASS_DEVICE + static result_type_packed_4 lut_convert(source_type_packed_4 const &source) { + result_type_packed_4 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint16_t const& src_packed = reinterpret_cast(source); + + // LUT x4 e2m1 to half + detail::_e2m1_to_half_x4(src_packed, out_halfx2[0], out_halfx2[1]); + return out; + } + + CUTLASS_DEVICE + static result_type_packed_2 lut_convert(source_type_packed_2 const &source) { + result_type_packed_2 out; + uint32_t* out_halfx2 = reinterpret_cast(&out); + uint16_t const& src_packed = static_cast(reinterpret_cast(source)); + + // LUT x2 e2m1 to half + detail::_e2m1_to_half_x2(src_packed, out_halfx2[0]); + return out; + } + + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + // either call 
lookup table or PTX instruction implementation + #if defined(USE_PTX_CONVERT) // defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + return ptx_convert(source); + #else + return lut_convert(source); + #endif + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_HOST_DEVICE + static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); } - }; -} + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round, int N > -struct NumericArrayConverter { - using result_element = cutlass::half_t; +struct NumericArrayConverter { + using result_element = cutlass::bfloat16_t; using source_element = cutlass::float_e2m1_t; using result_type = Array; using source_type = Array; static FloatRoundStyle const round_style = Round; private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; using source_type_packed_8 = Array; using source_type_packed_4 = Array; using source_type_packed_2 = Array; - using ScalarConverter = NumericConverter; + using ScalarConverter = NumericConverter; - #if defined(CUDA_PTX_FP8_CVT_ENABLED) CUTLASS_DEVICE - static result_type_packed_8 ptx_convert(source_type_packed_8 const &source) { + static result_type_packed_8 lut_convert(source_type_packed_8 const &source) { + 
result_type_packed_8 out; - uint32_t* out_fp16 = reinterpret_cast(&out); + uint32_t* out_bfloatx2 = reinterpret_cast(&out); uint32_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1, byte2, byte3;\n" \ - "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ - "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ - "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); + + // LUT x4 e2m1 to bf16 (SM90) + // CUTLASS_PRAGMA_UNROLL + // for (int i = 0; i < 2; i++) { + // unsigned int lane0123 = src_packed >> (16 * i); + // detail::_e2m1_to_bf16_x4(lane0123, out_bfloatx2[2*i], out_bfloatx2[2*i+1]); + // } + + // LUT x8 e2m1 to bf16 (SM90) + detail::_e2m1_to_bf16_x8(src_packed, out_bfloatx2[0], out_bfloatx2[1], out_bfloatx2[2], out_bfloatx2[3]); + return out; } CUTLASS_DEVICE - static result_type_packed_4 ptx_convert(source_type_packed_4 const &source) { + static result_type_packed_4 lut_convert(source_type_packed_4 const &source) { result_type_packed_4 out; - uint32_t* out_fp16 = reinterpret_cast(&out); + uint32_t* out_bfloatx2 = reinterpret_cast(&out); uint16_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1;\n" \ - "mov.b16 {byte0, byte1}, %2;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "h"(src_packed)); + + // LUT x4 e2m1 to bf16 + detail::_e2m1_to_bf16_x4(src_packed, out_bfloatx2[0], out_bfloatx2[1]); return out; } CUTLASS_DEVICE - static result_type_packed_2 ptx_convert(source_type_packed_2 const &source) { + static result_type_packed_2 lut_convert(source_type_packed_2 const &source) { result_type_packed_2 out; - uint32_t* out_fp16 = reinterpret_cast(&out); + uint32_t* out_bfloatx2 = reinterpret_cast(&out); uint16_t const& src_packed = 
static_cast(reinterpret_cast(source)); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1;\n" \ - "mov.b16 {byte0, byte1}, %1;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "}\n" : "=r"(out_fp16[0]) : "h"(src_packed)); + + // LUT x2 e2m1 to bf16 (SM90) + detail::_e2m1_to_bf16_x2(src_packed, out_bfloatx2[0]); return out; } - #endif template CUTLASS_DEVICE @@ -4457,27 +5168,16 @@ struct NumericArrayConverter { platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - return ptx_convert(source); - #else - PackedResultType result; - NumericConverter converter; - - const int k_packed = PackedResultType::kElements; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < k_packed; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif + // Option to add an optimized PTX path + return lut_convert(source); } friend class detail::VectorizedConverter; public: - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) result_type result; using ConverterType = NumericArrayConverter; detail::VectorizedConverter::convert { result_type_packed_2, source_type_packed_2>(result, source); return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -4545,7 +5256,7 @@ struct NumericArrayConverter // [0, 1, -2, -1] encoded as FP8 static constexpr uint32_t E4M3_LUT = 0xB8C03800; - const int iters = PackedSrcType::kElements / 4; + constexpr int iters = PackedSrcType::kElements / 4; #pragma unroll for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { // This uses a look up table to convert packed int2s to packed fp8s, using the int4 value @@ -4949,6 +5660,202 @@ struct NumericArrayConverter { } }; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses a lookup table to converts e2m1 -> e4m3. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); + + // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. 
+ cutlass::AlignedArray out_fp8; + + // View the input as reg + uint32_t reg = to_reg(source); + + constexpr int iters_x8 = PackedSrcType::kElements / 8; + constexpr int iters_x4 = (PackedSrcType::kElements / 4) & 1; + constexpr unsigned long long E2M1_to_E4M3_LUT_0_7 = 0x4C4844403C383000ULL; + + // we really only get called with 4 or 8 elements, but allow for more arguments for future use + // practically it's one of either x8 or x4, since we unroll everything else will be optimized away + #pragma unroll + for (int ii = 0; ii < iters_x8; ++ii) { + detail::_e2m1_to_fp8_x8(reg, out_fp8[ii*2], out_fp8[ii*2+1]); + } + + if (iters_x4) { + detail::_e2m1_to_fp8_x4(reg, out_fp8[iters_x8*2]); + } + + return reinterpret_cast(out_fp8); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_HOST_DEVICE + static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return 
reinterpret_cast(source); + } + + // The core converter uses a lookup table to convert e2m1 -> e5m2. + // Uses the existing templated _e2m1_to_fp8_x* helper functions defined above. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); + + // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. + cutlass::AlignedArray out_fp8; + + // View the input as reg + uint32_t reg = to_reg(source); + + constexpr int iters_x8 = PackedSrcType::kElements / 8; + constexpr int iters_x4 = (PackedSrcType::kElements / 4) & 1; + // E2M1→E5M2 LUT from table at lines 6109-6124 + constexpr unsigned long long E2M1_to_E5M2_LUT_0_7 = 0x464442403E3C3800ULL; + + // we really only get called with 4 or 8 elements, but allow for more arguments for future use + #pragma unroll + for (int ii = 0; ii < iters_x8; ++ii) { + detail::_e2m1_to_fp8_x8(reg, out_fp8[ii*2], out_fp8[ii*2+1]); + } + + if (iters_x4) { + detail::_e2m1_to_fp8_x4(reg, out_fp8[iters_x8*2]); + } + + return reinterpret_cast(out_fp8); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_HOST_DEVICE + static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + /// Partial specialization for Array <= Array template struct NumericArrayConverter { @@ -5012,7 +5919,7 @@ struct 
NumericArrayConverter static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; - const int iters = PackedSrcType::kElements / 4; + constexpr int iters = PackedSrcType::kElements / 4; #pragma unroll for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { uint32_t final_prmt_idx = final_prmt_base | sign; @@ -5038,8 +5945,9 @@ struct NumericArrayConverter friend class detail::VectorizedConverter; public: - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) result_type result; using ConverterType = NumericArrayConverter; detail::VectorizedConverter::convert result_type_packed_4, source_type_packed_4>(result, source); return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -5119,7 +6038,7 @@ struct NumericArrayConverter static constexpr uint32_t NEG_E5M2s_REG2 = 0xBCC0C2C4; - const int iters = PackedSrcType::kElements / 4; + constexpr int iters = PackedSrcType::kElements / 4; #pragma unroll for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { uint32_t final_prmt_idx = final_prmt_base | sign; @@ -5145,8 +6064,9 @@ struct NumericArrayConverter friend class detail::VectorizedConverter; public: - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const &source) { + #if defined(__CUDA_ARCH__) result_type result; using ConverterType = NumericArrayConverter; detail::VectorizedConverter::convert result_type_packed_4, source_type_packed_4>(result, source); return result; + #else + result_type result; + ScalarConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type 
operator()(source_type const &s) const { return convert(s); } @@ -5226,7 +6157,7 @@ struct NumericArrayConverter static constexpr uint32_t E4M3s_REG4 = 0x57565554; - const int iters = PackedSrcType::kElements / 4; + constexpr int iters = PackedSrcType::kElements / 4; #pragma unroll for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { uint32_t final_prmt_idx = final_prmt_base | sign; @@ -6175,7 +7106,7 @@ struct NumericArrayConverter { asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(r[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); } - // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve + // In the absence of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve // the same result as add.s16x2 instruction. // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to diff --git a/3rd/cutlass/include/cutlass/numeric_size.h b/3rd/cutlass/include/cutlass/numeric_size.h index 4f267e5..0230c00 100644 --- a/3rd/cutlass/include/cutlass/numeric_size.h +++ b/3rd/cutlass/include/cutlass/numeric_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -50,7 +50,13 @@ struct sizeof_bits { }; template -struct sizeof_bits: sizeof_bits {}; +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; template <> struct sizeof_bits { diff --git a/3rd/cutlass/include/cutlass/numeric_types.h b/3rd/cutlass/include/cutlass/numeric_types.h index b79a3d2..7899020 100644 --- a/3rd/cutlass/include/cutlass/numeric_types.h +++ b/3rd/cutlass/include/cutlass/numeric_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,6 +43,7 @@ #include "cutlass/tfloat32.h" #include "cutlass/float8.h" #include "cutlass/uint128.h" +#include "cutlass/uint256.h" #include "cutlass/exmy_base.h" #include "cutlass/float_subbyte.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/pipeline/pipeline.hpp b/3rd/cutlass/include/cutlass/pipeline/pipeline.hpp index e9cf66a..b9c04df 100644 --- a/3rd/cutlass/include/cutlass/pipeline/pipeline.hpp +++ b/3rd/cutlass/include/cutlass/pipeline/pipeline.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp b/3rd/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp index 53bc919..f02ae23 100644 --- a/3rd/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/3rd/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -140,6 +140,8 @@ class PipelineUmmaAsync { int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); } @@ -320,6 +322,24 @@ class PipelineTmaTransformAsync { } } + template + CUTLASS_DEVICE + PipelineTmaTransformAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr 
(cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + // Helper function to initialize barriers template static @@ -334,9 +354,11 @@ class PipelineTmaTransformAsync { static constexpr bool IsDynamicCluster = not cute::is_static_v; static_assert(IsDynamicCluster or ((cute::size<0>(cluster_shape) % cute::size<0>(atom_thr_shape) == 0) && (cute::size<1>(cluster_shape) % cute::size<1>(atom_thr_shape) == 0))); - uint32_t const num_consumer_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); uint32_t const multicast_consumer_arrival_count = ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) * num_consumer_per_cluster; + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -344,8 +366,31 @@ class PipelineTmaTransformAsync { } template + static CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? 
+ (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) * num_consumer_per_cluster : // Mcast with row ctas + (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) * num_consumer_per_cluster; // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster(), McastDirection mcast_dir = McastDirection::kRowCol) { // Calculate consumer mask if (params_.role == ThreadCategory::Consumer) { // Logic to optimally schedule Empty Arrives @@ -374,10 +419,25 @@ class PipelineTmaTransformAsync { // STEP 2: Find if this dst block-id needs an arrival for this problem is_signaling_thread_ &= dst_blockid_ < cluster_size; - is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); + if(mcast_dir == McastDirection::kRowCol){ + is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); + } + if(mcast_dir == McastDirection::kRow){ + is_signaling_thread_ &= is_same_row(dst_blockid_, block_id_in_cluster, cluster_shape); + } } } + template + CUTLASS_DEVICE + bool is_same_row(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { + return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) + // If we are in the same cluster column and using 2CTA MMA, only odd or only even CTAs sync with each other + && ((dst_block_id % cute::size<0>(cluster_shape)) % cute::size<0>(AtomThrShape_MNK{}) == + block_id.x % cute::size<0>(AtomThrShape_MNK{})) + ); + } + template CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { @@ -504,7 +564,8 @@ class PipelineTmaUmmaAsync { auto atom_thr_shape = AtomThrShape_MNK{}; uint32_t const 
multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; - + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -525,6 +586,8 @@ class PipelineTmaUmmaAsync { cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -893,6 +956,8 @@ class PipelineCLCFetchAsync { int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); } @@ -910,6 +975,8 @@ class PipelineCLCFetchAsync { int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); 
cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); } @@ -1102,7 +1169,7 @@ class PipelineEmpty { /////////////////////////////////////////////////////////////////////////////////////////////////// // // TMA (producer - consumer) Async Pipeline classes for Blackwell Sparse UMMA -// This is designed for the parttern that kernel has two different staged tensors. (AB and metadata) +// This is designed for the pattern that kernel has two different staged tensors. (AB and metadata) // /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp b/3rd/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp index 0828c1e..bcf6c74 100644 --- a/3rd/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/3rd/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -309,13 +309,14 @@ class PipelineTmaAsync { if (is_initializing_warp) { // Barrier FULL and EMPTY init uint32_t const producer_arv_cnt = params.num_producers; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 if (cute::size(cluster_shape) > 1) { multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * num_consumer_warpgroups_per_cluster; } - + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -804,6 +805,8 @@ class PipelineTransactionAsync { if (is_initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr, empty_barrier_ptr, params.producer_arv_count, params.consumer_arv_count); } @@ -1043,6 +1046,8 @@ class PipelineAsync { is_initializing_warp = (warp_idx == params.initializing_warp); if (is_initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); 
cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); } @@ -1299,6 +1304,7 @@ class OrderedSequenceBarrier { // Barrier FULL, EMPTY init if (warp_idx == params.initializing_warp) { int arv_cnt = params.group_size; + CUTLASS_ASSERT(arv_cnt > 0 && "Arrive count must be non-zero"); constexpr int Stages = Depth * Length; cutlass::arch::detail::initialize_barrier_array_aligned( barrier_ptr_, arv_cnt); @@ -1307,6 +1313,7 @@ class OrderedSequenceBarrier { int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); + CUTLASS_ASSERT(params.group_size > 0 && "Group size must be non-zero"); // Barrier FULL, EMPTY init // Init is done only by the one elected thread of the block diff --git a/3rd/cutlass/include/cutlass/pitch_linear_coord.h b/3rd/cutlass/include/cutlass/pitch_linear_coord.h index 1b782ec..47896fd 100644 --- a/3rd/cutlass/include/cutlass/pitch_linear_coord.h +++ b/3rd/cutlass/include/cutlass/pitch_linear_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/platform/platform.h b/3rd/cutlass/include/cutlass/platform/platform.h index 939451a..46e1adf 100644 --- a/3rd/cutlass/include/cutlass/platform/platform.h +++ b/3rd/cutlass/include/cutlass/platform/platform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,6 +31,8 @@ #pragma once +#include "cutlass/tfloat32.h" + /** * \file * \brief C++ features that may be otherwise unimplemented for CUDA device functions. @@ -98,13 +100,13 @@ //----------------------------------------------------------------------------- // Dependencies //----------------------------------------------------------------------------- - +#include #if defined(__CUDACC_RTC__) -#include -#include -#include -#include -#include +#include CUDA_STD_HEADER(type_traits) +#include CUDA_STD_HEADER(utility) +#include CUDA_STD_HEADER(cstddef) +#include CUDA_STD_HEADER(cstdint) +#include CUDA_STD_HEADER(limits) #else #include #include @@ -128,7 +130,6 @@ #endif #include -#include #endif @@ -523,7 +524,7 @@ using std::is_trivially_copyable; #endif -#if (201703L <=__cplusplus) +#if (CUTLASS_CXX17_OR_LATER) /// std::is_unsigned_v using CUTLASS_STL_NAMESPACE::is_integral_v; @@ -582,6 +583,7 @@ struct alignment_of : std::alignment_of {}; #endif +#if CUDA_VERSION >= 11080 /* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */ template <> struct alignment_of { @@ -596,23 +598,68 @@ struct alignment_of { enum { value = 16 }; }; template <> -struct alignment_of { +struct alignment_of { enum { value = 16 }; }; template <> -struct alignment_of { +struct alignment_of { enum { value = 16 }; }; template <> -struct alignment_of { +struct alignment_of { enum { value = 16 }; }; + +#if CUDA_VERSION >= 13000 template <> -struct alignment_of { +struct alignment_of { enum { value = 16 }; }; template <> -struct alignment_of { +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> 
+struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; + +#else + +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { enum { value = 16 }; }; template <> @@ -628,6 +675,9 @@ struct alignment_of { enum { value = 16 }; }; +#endif // CUDA_VERSION >= 13000 +#endif // CUDA_VERSION >= 11080 + // Specializations for volatile/const qualified types template struct alignment_of : alignment_of {}; @@ -865,6 +915,18 @@ struct numeric_limits { static constexpr bool has_infinity = true; }; +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static tfloat32_t infinity() noexcept { return tfloat32_t::bitcast(0x7f800000);} + CUTLASS_HOST_DEVICE + static tfloat32_t max() noexcept { return tfloat32_t::bitcast(0x7f7fffff);} + CUTLASS_HOST_DEVICE + static tfloat32_t lowest() noexcept { return tfloat32_t::bitcast(0xff7fffff);} + static constexpr bool is_integer = false; + static constexpr bool has_infinity = true; +}; + /// Returns a value that curries the `std::maximum()` function into the identity /// function. No value will compare < than this value. template diff --git a/3rd/cutlass/include/cutlass/predicate_vector.h b/3rd/cutlass/include/cutlass/predicate_vector.h index 0241a6f..76c02e2 100644 --- a/3rd/cutlass/include/cutlass/predicate_vector.h +++ b/3rd/cutlass/include/cutlass/predicate_vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,16 +33,17 @@ of boolean predicates. 
*/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include +#ifndef __QNX__ +#include CUDA_STD_HEADER(cassert) +#endif -#include "cutlass/cutlass.h" #include "cutlass/platform/platform.h" namespace cutlass { diff --git a/3rd/cutlass/include/cutlass/quaternion.h b/3rd/cutlass/include/cutlass/quaternion.h index 48ca362..90c9a52 100644 --- a/3rd/cutlass/include/cutlass/quaternion.h +++ b/3rd/cutlass/include/cutlass/quaternion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/real.h b/3rd/cutlass/include/cutlass/real.h index cfca386..b2afe9a 100644 --- a/3rd/cutlass/include/cutlass/real.h +++ b/3rd/cutlass/include/cutlass/real.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/device/reduce_split_k.h b/3rd/cutlass/include/cutlass/reduction/device/reduce_split_k.h index 92b57aa..41eff8d 100644 --- a/3rd/cutlass/include/cutlass/reduction/device/reduce_split_k.h +++ b/3rd/cutlass/include/cutlass/reduction/device/reduce_split_k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce.h b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce.h index 26a0249..d4def07 100644 --- a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce.h +++ b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h index c00c368..09e476a 100644 --- a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h +++ b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h index c85d6dc..7bf125b 100644 --- a/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h +++ b/3rd/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h b/3rd/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h index 3d39dc7..b462543 100644 --- a/3rd/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h +++ b/3rd/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h b/3rd/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h index f6d2666..94338a5 100644 --- a/3rd/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h +++ b/3rd/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h index 914bbdd..4fbe7ac 100644 --- a/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h +++ b/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h index 0538184..a3a336c 100644 --- a/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h +++ b/3rd/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/thread/reduce.h b/3rd/cutlass/include/cutlass/reduction/thread/reduce.h index cc354df..77e6743 100644 --- a/3rd/cutlass/include/cutlass/reduction/thread/reduce.h +++ b/3rd/cutlass/include/cutlass/reduction/thread/reduce.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/3rd/cutlass/include/cutlass/reduction/thread/reduction_operators.h index 3792d33..544e2c2 100644 --- a/3rd/cutlass/include/cutlass/reduction/thread/reduction_operators.h +++ b/3rd/cutlass/include/cutlass/reduction/thread/reduction_operators.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/3rd/cutlass/include/cutlass/reduction/threadblock_swizzle.h index bbabaed..b111f45 100644 --- a/3rd/cutlass/include/cutlass/reduction/threadblock_swizzle.h +++ b/3rd/cutlass/include/cutlass/reduction/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/relatively_equal.h b/3rd/cutlass/include/cutlass/relatively_equal.h index 68bdb26..53ef622 100644 --- a/3rd/cutlass/include/cutlass/relatively_equal.h +++ b/3rd/cutlass/include/cutlass/relatively_equal.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -300,6 +300,7 @@ bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4 return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); } + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/3rd/cutlass/include/cutlass/semaphore.h b/3rd/cutlass/include/cutlass/semaphore.h index 09a0a1a..4b45b5e 100644 --- a/3rd/cutlass/include/cutlass/semaphore.h +++ b/3rd/cutlass/include/cutlass/semaphore.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/subbyte_reference.h b/3rd/cutlass/include/cutlass/subbyte_reference.h index 6e98cdc..543089d 100644 --- a/3rd/cutlass/include/cutlass/subbyte_reference.h +++ b/3rd/cutlass/include/cutlass/subbyte_reference.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tensor_coord.h b/3rd/cutlass/include/cutlass/tensor_coord.h index a124d39..66c233b 100644 --- a/3rd/cutlass/include/cutlass/tensor_coord.h +++ b/3rd/cutlass/include/cutlass/tensor_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tensor_ref.h b/3rd/cutlass/include/cutlass/tensor_ref.h index fc46749..030f890 100644 --- a/3rd/cutlass/include/cutlass/tensor_ref.h +++ b/3rd/cutlass/include/cutlass/tensor_ref.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tensor_ref_planar_complex.h b/3rd/cutlass/include/cutlass/tensor_ref_planar_complex.h index 9ba3a23..fdef5b1 100644 --- a/3rd/cutlass/include/cutlass/tensor_ref_planar_complex.h +++ b/3rd/cutlass/include/cutlass/tensor_ref_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tensor_view.h b/3rd/cutlass/include/cutlass/tensor_view.h index d669443..25d6833 100644 --- a/3rd/cutlass/include/cutlass/tensor_view.h +++ b/3rd/cutlass/include/cutlass/tensor_view.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tensor_view_planar_complex.h b/3rd/cutlass/include/cutlass/tensor_view_planar_complex.h index 6b8f7b4..2081072 100644 --- a/3rd/cutlass/include/cutlass/tensor_view_planar_complex.h +++ b/3rd/cutlass/include/cutlass/tensor_view_planar_complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/tfloat32.h b/3rd/cutlass/include/cutlass/tfloat32.h index 7bc13e1..46cb8df 100644 --- a/3rd/cutlass/include/cutlass/tfloat32.h +++ b/3rd/cutlass/include/cutlass/tfloat32.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -467,12 +467,12 @@ tfloat32_t operator--(tfloat32_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(long double x) { +cutlass::tfloat32_t operator""_tf32(long double x) { return cutlass::tfloat32_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { +cutlass::tfloat32_t operator""_tf32(unsigned long long int x) { return cutlass::tfloat32_t(int(x)); } diff --git a/3rd/cutlass/include/cutlass/thread/matrix.h b/3rd/cutlass/include/cutlass/thread/matrix.h index c338306..0e68476 100644 --- a/3rd/cutlass/include/cutlass/thread/matrix.h +++ b/3rd/cutlass/include/cutlass/thread/matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/trace.h b/3rd/cutlass/include/cutlass/trace.h index 803c72e..63f6575 100644 --- a/3rd/cutlass/include/cutlass/trace.h +++ b/3rd/cutlass/include/cutlass/trace.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/3rd/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp index 99c5bf7..4078d98 100644 --- a/3rd/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +++ b/3rd/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -289,7 +289,7 @@ class AsyncTranspositionOperandB { static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; @@ -510,7 +510,7 @@ class AsyncTranspositionOperandB_1BElementB { static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." 
); + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; diff --git a/3rd/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp b/3rd/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp index 265d2fe..58a264b 100644 --- a/3rd/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp +++ b/3rd/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp b/3rd/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp index 9c9d758..f2978db 100644 --- a/3rd/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp +++ b/3rd/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/3rd/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp index 577c68c..e1b485c 100644 --- a/3rd/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ b/3rd/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -378,13 +378,40 @@ class SM90StructuredSparseCompressor { } // Construct a sign bit mask for handling negative zeros - ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 }; - if constexpr (has_negative_zero_v) { - ElementAMmaRawUnit one_sign_mask = static_cast(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v - 1))); - for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) { - sign_mask = static_cast((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v)); + // Compute the mask value at compile time, then construct ElementAMmaRawUnit from it + // Case 1: float_e2m1_t (4-bit), uint4_t container, ElemsARawPerElementAMmaRaw 1 element + // - Result: negzero_mask_value = 0b0111 (stored in uint8_t as 0b0000_0111) + // + // Case 2: float_e2m1_t (4-bit), uint8_t container, ElemsARawPerElementAMmaRaw 2 elements + // - Result: negzero_mask_value = 0b0111_0111 + // + // Case 3: float_e4m3_t (8-bit), uint8_t container, ElemsARawPerElementAMmaRaw 1 element + // - Result: negzero_mask_value = 0b0111_1111 = 0x7F + // Note: Lambda returns uint32_t (constexpr-compatible) instead of 
ElementAMmaRawUnit (non-literal type) + constexpr uint32_t negzero_mask_value = []() constexpr -> uint32_t { + constexpr int ElementAMmaRawNumBits = cute::sizeof_bits_v; + constexpr int ElementANumBits = cute::sizeof_bits_v; + + if constexpr (has_negative_zero_v) { + // Create mask for one ElementA: all bits set except the sign bit (MSB) + // ElementANumBits = 4: (1 << 3) - 1 = 0b0111 + // ElementANumBits = 8: (1 << 7) - 1 = 0b0111_1111 + constexpr uint32_t ElementASignMask = (1u << (ElementANumBits - 1)) - 1; + + // Replicate the single-element mask across all packed elements + if constexpr (ElemsARawPerElementAMmaRaw == 1) { + return ElementASignMask; + } + else if constexpr (ElemsARawPerElementAMmaRaw == 2) { + return (ElementASignMask << ElementANumBits) | ElementASignMask; + } } - } + // No negative zero: return all bits set to 1 (no masking needed) + return (1u << ElementAMmaRawNumBits) - 1; + }(); + + // Construct ElementAMmaRawUnit from the compile-time computed mask value + const ElementAMmaRawUnit negzero_mask_out_sign_mask = ElementAMmaRawUnit{negzero_mask_value}; // * Compress // cACsAC is always row major order @@ -410,17 +437,21 @@ class SM90StructuredSparseCompressor { // * Find None Zero Element Idx within Chunk CUTE_UNROLL for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) { - ElementAMmaRawUnit elem_A = tAsA[elt_log_idx]; + // Iterate through all ElementAMma within one logical chunk + ElementAMmaRawUnit tAsA_i = tAsA[elt_log_idx]; - // Handle negative 0 - ElementAMmaRawUnit masked_elem_A = elem_A; + // Mask off the signed bit s.t. negative zero is same as positive zero + ElementAMmaRawUnit tAsA_i_negzero_masked_out = tAsA_i; if constexpr (has_negative_zero_v) { - masked_elem_A = elem_A & sign_mask; + // For sub-bytes, LSB will contain valid bits. + // e.g. for float_e2m1_t, tAsA_i is stored in uint8_t with 0x0000yyyy where yyyy denote valid bits. 
+ tAsA_i_negzero_masked_out = tAsA_i & negzero_mask_out_sign_mask; } - if (masked_elem_A != ElementAMmaRawUnit{0}) { + // Record this ElmentAMma if it's none zero + if (tAsA_i_negzero_masked_out != ElementAMmaRawUnit{0}) { non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx; - tACsAC[non_zero_cnt] = elem_A; + tACsAC[non_zero_cnt] = tAsA_i; non_zero_cnt++; } } @@ -463,8 +494,7 @@ class SM90StructuredSparseCompressor { // * Output Cta Tensor S to G if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) { - constexpr int MaxVecBits = 128; // STG.128 - cute::cooperative_copy(threadIdx_X, cEsE, cEgE); + cute::cooperative_copy(threadIdx_X, cEsE, cEgE); } if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) { @@ -546,9 +576,9 @@ class SM90StructuredSparseCompressor { // Row Major if constexpr (IsRowMajor) { CUTE_UNROLL - for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(valid_rows, ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { CUTE_UNROLL - for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(valid_cols, ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { CUTE_UNROLL for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { CUTE_UNROLL @@ -556,7 +586,9 @@ class SM90StructuredSparseCompressor { const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; if constexpr ( (not pred) and (not IsQmmaF6) ) { - dDst(row_i, col_i) = dSrc(row_i, col_i); + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } } else { if (row_i < valid_rows && col_i < valid_cols) { @@ -571,9 
+603,9 @@ class SM90StructuredSparseCompressor { // Col Major else { CUTE_UNROLL - for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(valid_cols, ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { CUTE_UNROLL - for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(valid_rows, ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { CUTE_UNROLL for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { CUTE_UNROLL @@ -581,7 +613,9 @@ class SM90StructuredSparseCompressor { const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; if constexpr ( (not pred) and (not IsQmmaF6) ) { - dDst(row_i, col_i) = dSrc(row_i, col_i); + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } } else { if (row_i < valid_rows && col_i < valid_cols) { diff --git a/3rd/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp b/3rd/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp index 9f23535..85c0f7f 100644 --- a/3rd/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp +++ b/3rd/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/3rd/cutlass/include/cutlass/transform/pitch_linear_thread_map.h index 6a8970e..aa05b62 100644 --- a/3rd/cutlass/include/cutlass/transform/pitch_linear_thread_map.h +++ b/3rd/cutlass/include/cutlass/transform/pitch_linear_thread_map.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -298,7 +298,7 @@ struct PitchLinearWarpRakedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape< Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, Detail::WarpThreadArrangement::kStrided @@ -427,7 +427,7 @@ struct PitchLinearStridedWarpRakedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = typename BaseThreadMap::Delta; /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space @@ -531,7 +531,7 @@ struct TransposePitchLinearThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = 
layout::PitchLinearShape; @@ -716,7 +716,7 @@ struct PitchLinearWarpStripedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape< Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided @@ -897,7 +897,7 @@ struct TransposePitchLinearThreadMap2DThreadTile { /// Shape of access by each thread using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape; diff --git a/3rd/cutlass/include/cutlass/transform/thread/transpose.h b/3rd/cutlass/include/cutlass/transform/thread/transpose.h index 508cad8..7f4d788 100644 --- a/3rd/cutlass/include/cutlass/transform/thread/transpose.h +++ b/3rd/cutlass/include/cutlass/transform/thread/transpose.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/thread/unary_op.h b/3rd/cutlass/include/cutlass/transform/thread/unary_op.h index 3977af5..a117839 100644 --- a/3rd/cutlass/include/cutlass/transform/thread/unary_op.h +++ b/3rd/cutlass/include/cutlass/transform/thread/unary_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/ell_iterator.h index bd717d6..c8cc50f 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/ell_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/ell_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h index 3676c23..c7b5749 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h index 48fb983..27b7dc8 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -76,7 +76,7 @@ namespace threadblock { /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once /// outside any looping structure to minimize integer arithmetic. 
/// -/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing /// the iterator. /// /// diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h index dab597c..96e9809 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h index e5d9e70..39953b6 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index e5c2a5f..422d9ef 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -419,7 +419,7 @@ class PredicatedTileAccessIterator, Element>; static bool const transpose = Transpose_; diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h index 3acc31f..8411cd7 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -76,10 +76,10 @@ namespace threadblock { /// accesses may be performed without updating internal predicates and are efficient in terms of /// live register state and pointer arithmetic instructions. /// -/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once /// outside any looping structure to minimize integer arithmetic. /// -/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing /// the iterator. /// /// diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h index df551c1..8133720 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h index 1aae469..b66f66d 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h index cfb491b..3b42f31 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h index adda933..61f2bda 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h index 71c8968..6b24ba0 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h index e172447..4d16b01 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h index b55f841..a343ba5 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h index be07e43..43cf8e7 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index 6c186ce..1485b15 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h index 5ed2e7f..907ed62 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h index 723f328..9351f9c 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h index 53121c6..2695f62 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/3rd/cutlass/include/cutlass/transform/threadblock/vector_iterator.h index 8e5d181..fe2e6d4 100644 --- a/3rd/cutlass/include/cutlass/transform/threadblock/vector_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/threadblock/vector_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/3rd/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h index 707cbcc..61a763b 100644 --- a/3rd/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h +++ b/3rd/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,7 +34,7 @@ \brief This defines a "fragment" iterator for visiting the fragments of a warp vector that participate in one warp-level mma operation. - Typically, this is used to access the scale/bias fragement of a warp-level mma operation. + Typically, this is used to access the scale/bias fragment of a warp-level mma operation. The scale/bias vector is then partitioned into smaller fragments that can be fed into next warp-level mma operation. diff --git a/3rd/cutlass/include/cutlass/uint128.h b/3rd/cutlass/include/cutlass/uint128.h index 295eaa6..6afe103 100644 --- a/3rd/cutlass/include/cutlass/uint128.h +++ b/3rd/cutlass/include/cutlass/uint128.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,9 +33,9 @@ \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #include @@ -44,7 +44,6 @@ #include #endif -#include "cutlass/cutlass.h" /// Optionally enable GCC's built-in type #if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) @@ -61,6 +60,64 @@ #endif #endif +CUTLASS_HOST_DEVICE +uint64_t umul128( + uint64_t multiplier, + uint64_t multiplicand, + uint64_t *high_product +) { + +#if defined(CUTLASS_INT128_ARITHMETIC) + return _umul128(multiplier, multiplicand, high_product); +#else + const uint64_t mask = 0xFFFFFFFF; + + uint64_t a_lo = multiplier & mask; + uint64_t a_hi = multiplier >> 32; + uint64_t b_lo = multiplicand & mask; + uint64_t b_hi = multiplicand >> 32; + + uint64_t p_ll = a_lo * b_lo; + uint64_t p_lh = a_lo * b_hi; + uint64_t p_hl = a_hi * b_lo; + uint64_t p_hh = a_hi * b_hi; + + uint64_t p_mid = (p_ll >> 32) + (p_lh & mask) + (p_hl & mask); + uint64_t r_lo = (p_ll & mask) + (p_mid << 32); + uint64_t r_hi = (p_lh & mask) + (p_hl & mask) + p_hh; + + *high_product = r_hi; + return r_lo; +#endif +} + + +CUTLASS_HOST_DEVICE +uint64_t udiv128(uint64_t high, uint64_t low, uint64_t divisor, uint64_t *remainder_ptr) { +#if defined(CUTLASS_INT128_ARITHMETIC_DIV) + return _udiv128(high, low, divisor, remainder_ptr); +#else + uint64_t quotient = 0, remainder = 0; + uint64_t const bit = 1; + for (int32_t i=127; i>=0; --i) { + uint64_t r = 0; + if (i >= 64) { + r = ((high >> (i - 64)) & bit); + } + else { + r = ((low >> i) & bit); + } + remainder = (remainder << 1) | r; + if (remainder >= divisor) { + remainder -= divisor; 
+ quotient |= (bit << i); + } + } + *remainder_ptr = remainder; + return quotient; +#endif +} + namespace cutlass { ///! Unsigned 128b integer type @@ -158,16 +215,13 @@ struct alignas(16) uint128_t uint128_t y{}; #if defined(CUTLASS_UINT128_NATIVE) y.native = native * rhs; -#elif defined(CUTLASS_INT128_ARITHMETIC) +#else // Multiply by the low part - y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); + y.hilo_.lo = umul128(hilo_.lo, rhs, &y.hilo_.hi); // Add the high part and ignore the overflow uint64_t overflow{0}; - y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -#else - CUTLASS_UNUSED(rhs); - exception(); + y.hilo_.hi += umul128(hilo_.hi, rhs, &overflow); #endif return y; } @@ -179,13 +233,10 @@ struct alignas(16) uint128_t uint64_t quotient{0}; #if defined(CUTLASS_UINT128_NATIVE) quotient = uint64_t(native / divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) +#else // implemented using MSVC's arithmetic intrinsics uint64_t remainder{0}; - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -#else - CUTLASS_UNUSED(divisor); - exception(); + quotient = udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return quotient; } @@ -197,12 +248,9 @@ struct alignas(16) uint128_t uint64_t remainder{0}; #if defined(CUTLASS_UINT128_NATIVE) remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - CUTLASS_UNUSED(divisor); - exception(); + // implemented using MSVC's arithmetic intrinsics + (void)udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return remainder; } @@ -215,13 +263,9 @@ struct alignas(16) uint128_t #if defined(CUTLASS_UINT128_NATIVE) quotient = uint64_t(native / divisor); remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - 
CUTLASS_UNUSED(remainder); - CUTLASS_UNUSED(divisor); - exception(); + // implemented using MSVC's arithmetic intrinsics + quotient = udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return quotient; } diff --git a/3rd/cutlass/include/cutlass/uint256.h b/3rd/cutlass/include/cutlass/uint256.h new file mode 100644 index 0000000..dadb636 --- /dev/null +++ b/3rd/cutlass/include/cutlass/uint256.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 256b integer. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif +#include "cutlass/uint128.h" + +namespace cutlass { + +///! Unsigned 256b integer type +struct alignas(32) uint256_t { + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 128; + + struct hilo { + uint128_t lo; + uint128_t hi; + }; + + // Use a union to store either low and high parts. 
+ union { + struct hilo hilo_; + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint256_t() : hilo_{uint128_t{}, uint128_t{}} {} + + /// Constructor from uint128 + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_) : hilo_{lo_, uint128_t{}} {} + + /// Constructor from two 128b unsigned integers + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_, uint128_t hi_) : hilo_{lo_, hi_} {} + + /// Lossily cast to uint128_t + CUTLASS_HOST_DEVICE + explicit operator uint128_t() const { + return hilo_.lo; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/3rd/cutlass/include/cutlass/version.h b/3rd/cutlass/include/cutlass/version.h index 41d7832..f388aa7 100644 --- a/3rd/cutlass/include/cutlass/version.h +++ b/3rd/cutlass/include/cutlass/version.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,8 +35,8 @@ #include #define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 0 -#define CUTLASS_PATCH 0 +#define CUTLASS_MINOR 4 +#define CUTLASS_PATCH 2 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/3rd/cutlass/include/cutlass/wmma_array.h b/3rd/cutlass/include/cutlass/wmma_array.h index 77929f6..9547224 100644 --- a/3rd/cutlass/include/cutlass/wmma_array.h +++ b/3rd/cutlass/include/cutlass/wmma_array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/3rd/cutlass/include/cutlass/workspace.h b/3rd/cutlass/include/cutlass/workspace.h index 485ebbe..610ca8c 100644 --- a/3rd/cutlass/include/cutlass/workspace.h +++ b/3rd/cutlass/include/cutlass/workspace.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -121,6 +121,8 @@ fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t st #else CUdeviceptr d_workspace = reinterpret_cast(workspace); CUresult result = CUDA_SUCCESS; + +#ifndef __QNX__ if (sizeof(T) == 4) { result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); } @@ -130,6 +132,7 @@ fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t st else if (sizeof(T) == 1) { result = cuMemsetD8Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); } +#endif if (CUDA_SUCCESS != result) { const char** error_string_ptr = nullptr; diff --git a/3rd/update-cutlass.sh b/3rd/update-cutlass.sh index adba038..6919429 100755 --- a/3rd/update-cutlass.sh +++ b/3rd/update-cutlass.sh @@ -4,7 +4,7 @@ rm -rf cutlass rm -rf cutlass.git git clone https://github.com/NVIDIA/cutlass.git cutlass.git -git -C ./cutlass.git checkout dc481792 +git -C ./cutlass.git checkout da5e086 mkdir -p cutlass mv cutlass.git/include cutlass/