From e312159ca0b9295cd3d49bb4a37229c7fbb371cf Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 17 Jun 2026 22:55:44 +0800 Subject: [PATCH 01/31] feat: first stage of vmi --- docs/designs/vmi-dialect-design.md | 2078 ++++++ docs/designs/vmi-implementation-manual.md | 4233 +++++++++++ include/PTO/IR/PTOAttrs.td | 2 + include/PTO/IR/PTOOps.td | 1 + include/PTO/IR/PTOTypeDefs.td | 1 + include/PTO/IR/VMIAttrs.td | 34 + include/PTO/IR/VMIOps.td | 562 ++ include/PTO/IR/VMITypeDefs.td | 67 + include/PTO/IR/VMIUtils.h | 53 + include/PTO/Transforms/Passes.h | 8 + include/PTO/Transforms/Passes.td | 71 + .../PTO/Transforms/VMITargetCapabilities.h | 318 + lib/PTO/IR/CMakeLists.txt | 1 + lib/PTO/IR/VMI.cpp | 1407 ++++ lib/PTO/Transforms/CMakeLists.txt | 3 + lib/PTO/Transforms/PTOValidateVMIIR.cpp | 445 ++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 1330 ++++ lib/PTO/Transforms/VMIToVPTO.cpp | 6269 +++++++++++++++++ test/lit/CMakeLists.txt | 1 + test/lit/lit.cfg.py | 6 +- test/lit/vmi/vmi_absf_integer_invalid.pto | 19 + test/lit/vmi/vmi_absi_float_invalid.pto | 19 + ...ctive_prefix_index_result_type_invalid.pto | 21 + .../vmi/vmi_addf_lane_mismatch_invalid.pto | 21 + .../vmi/vmi_bitcast_total_bits_invalid.pto | 19 + test/lit/vmi/vmi_bitwise_float_invalid.pto | 64 + .../vmi_broadcast_type_mismatch_invalid.pto | 18 + ...i_channel_merge_input_mismatch_invalid.pto | 21 + ..._channel_merge_result_mismatch_invalid.pto | 21 + .../vmi_channel_split_lane_count_invalid.pto | 20 + ...vmi_channel_split_result_count_invalid.pto | 20 + .../vmi_compress_result_mismatch_invalid.pto | 23 + .../vmi/vmi_constant_attr_kind_invalid.pto | 20 + .../vmi_constant_element_count_invalid.pto | 20 + .../vmi/vmi_constant_element_type_invalid.pto | 20 + .../vmi_constant_mask_attr_kind_invalid.pto | 20 + ...mi_constant_mask_element_count_invalid.pto | 20 + ...vmi_constant_mask_element_type_invalid.pto | 20 + test/lit/vmi/vmi_divf_integer_invalid.pto | 22 + 
test/lit/vmi/vmi_elementwise_kind_invalid.pto | 63 + .../vmi/vmi_ensure_layout_surface_invalid.pto | 47 + test/lit/vmi/vmi_extf_direction_invalid.pto | 19 + .../vmi/vmi_extf_lane_mismatch_invalid.pto | 19 + test/lit/vmi/vmi_fma_integer_invalid.pto | 23 + test/lit/vmi/vmi_gather_indices_invalid.pto | 25 + .../lit/vmi/vmi_iota_element_type_invalid.pto | 19 + test/lit/vmi/vmi_iota_order_invalid.pto | 19 + ..._layout_assignment_active_prefix_index.pto | 26 + .../vmi_layout_assignment_broadcast_remat.pto | 52 + .../vmi_layout_assignment_call_boundary.pto | 45 + .../vmi/vmi_layout_assignment_cf_branch.pto | 56 + .../vmi/vmi_layout_assignment_cf_switch.pto | 51 + ...hannel_merge_count_unsupported_invalid.pto | 23 + ...hannel_split_count_unsupported_invalid.pto | 21 + .../vmi/vmi_layout_assignment_compress.pto | 30 + .../vmi_layout_assignment_compress_store.pto | 31 + .../vmi_layout_assignment_constant_remat.pto | 54 + .../vmi/vmi_layout_assignment_expand_load.pto | 35 + ...ayout_assignment_external_call_invalid.pto | 25 + ...ayout_assignment_external_decl_invalid.pto | 15 + ...yout_assignment_external_decl_preserve.pto | 23 + test/lit/vmi/vmi_layout_assignment_fma.pto | 29 + test/lit/vmi/vmi_layout_assignment_gather.pto | 35 + ...ayout_assignment_indirect_call_invalid.pto | 24 + .../vmi/vmi_layout_assignment_iota_remat.pto | 54 + .../vmi/vmi_layout_assignment_load_truncf.pto | 133 + ...ment_mask_granularity_conflict_invalid.pto | 33 + .../vmi/vmi_layout_assignment_mask_remat.pto | 73 + .../vmi_layout_assignment_mask_use_ensure.pto | 36 + .../vmi/vmi_layout_assignment_masked_load.pto | 32 + .../vmi_layout_assignment_multi_return.pto | 39 + ...signment_multi_return_conflict_invalid.pto | 30 + ...assignment_post_gate_type_attr_invalid.pto | 17 + .../vmi/vmi_layout_assignment_reduce_addf.pto | 30 + .../vmi/vmi_layout_assignment_reduce_addi.pto | 33 + .../vmi_layout_assignment_reduce_minmaxf.pto | 49 + .../lit/vmi/vmi_layout_assignment_scatter.pto | 32 + 
...i_layout_assignment_scf_execute_region.pto | 38 + .../lit/vmi/vmi_layout_assignment_scf_for.pto | 43 + test/lit/vmi/vmi_layout_assignment_scf_if.pto | 50 + ...vmi_layout_assignment_scf_index_switch.pto | 48 + .../vmi/vmi_layout_assignment_scf_while.pto | 47 + .../vmi_layout_assignment_store_ensure.pto | 48 + .../vmi_layout_assignment_truncf_ensure.pto | 39 + test/lit/vmi/vmi_layout_assignment_widen.pto | 39 + test/lit/vmi/vmi_layout_factor_invalid.pto | 18 + .../vmi/vmi_layout_gate_surface_invalid.pto | 18 + .../vmi_layout_gate_surface_mask_invalid.pto | 20 + ...gate_type_attr_nested_physical_invalid.pto | 17 + ..._layout_gate_type_attr_surface_invalid.pto | 17 + test/lit/vmi/vmi_layout_gate_valid.pto | 23 + ...i_mask_concrete_without_layout_invalid.pto | 18 + test/lit/vmi/vmi_mask_granularity_invalid.pto | 18 + test/lit/vmi/vmi_mask_logic_invalid.pto | 67 + .../vmi/vmi_mask_pred_with_layout_invalid.pto | 18 + ..._masked_store_mask_granularity_invalid.pto | 25 + .../vmi/vmi_memory_element_type_invalid.pto | 57 + test/lit/vmi/vmi_min_max_integer_invalid.pto | 37 + test/lit/vmi/vmi_negf_integer_invalid.pto | 19 + test/lit/vmi/vmi_op_verifier_basic.pto | 106 + test/lit/vmi/vmi_pack_arity_invalid.pto | 20 + .../vmi_producer_boundary_helper_invalid.pto | 22 + .../vmi_producer_boundary_layout_invalid.pto | 19 + ..._producer_boundary_mask_layout_invalid.pto | 19 + ...i_producer_boundary_non_vmi_op_invalid.pto | 21 + ...vmi_producer_boundary_physical_invalid.pto | 30 + ...ucer_boundary_type_attr_layout_invalid.pto | 17 + ...ucer_boundary_type_attr_nested_invalid.pto | 17 + ...cer_boundary_type_attr_surface_invalid.pto | 17 + test/lit/vmi/vmi_producer_boundary_valid.pto | 27 + .../vmi_ptoas_backend_required_invalid.pto | 17 + test/lit/vmi/vmi_ptoas_cli_control_flow.pto | 43 + test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 45 + test/lit/vmi/vmi_ptoas_public_abi_invalid.pto | 20 + .../vmi_ptoas_public_result_abi_invalid.pto | 22 + ...mi_reduce_addf_missing_reassoc_invalid.pto | 
23 + test/lit/vmi/vmi_scatter_indices_invalid.pto | 24 + .../vmi_select_mask_granularity_invalid.pto | 25 + test/lit/vmi/vmi_shli_float_invalid.pto | 22 + test/lit/vmi/vmi_shrui_float_invalid.pto | 22 + test/lit/vmi/vmi_shrui_signed_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_abs.pto | 44 + .../vmi/vmi_to_vpto_active_prefix_index.pto | 33 + ...active_prefix_index_multichunk_invalid.pto | 26 + ..._vpto_active_prefix_index_tail_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_add.pto | 57 + test/lit/vmi/vmi_to_vpto_bf16_arith.pto | 50 + test/lit/vmi/vmi_to_vpto_bitcast.pto | 29 + test/lit/vmi/vmi_to_vpto_bitcast_partial.pto | 29 + test/lit/vmi/vmi_to_vpto_bitwise.pto | 53 + test/lit/vmi/vmi_to_vpto_broadcast.pto | 69 + test/lit/vmi/vmi_to_vpto_call_boundary.pto | 52 + test/lit/vmi/vmi_to_vpto_cf_branch.pto | 78 + .../vmi_to_vpto_channel_merge4_contiguous.pto | 40 + ...hannel_merge_count_unsupported_invalid.pto | 25 + ...i_to_vpto_channel_merge_layout_invalid.pto | 23 + ...to_channel_merge_partial_group_invalid.pto | 25 + ...hannel_split_count_unsupported_invalid.pto | 23 + ...i_to_vpto_channel_split_layout_invalid.pto | 22 + .../vmi/vmi_to_vpto_channel_split_merge.pto | 95 + .../vmi_to_vpto_channel_split_merge_tail.pto | 35 + ...to_channel_split_partial_group_invalid.pto | 24 + .../vmi_to_vpto_cmp_element_type_invalid.pto | 24 + ...vpto_cmp_predicate_unsupported_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_cmp_select.pto | 140 + ...unsigned_predicate_unsupported_invalid.pto | 28 + .../vmi_to_vpto_compaction_deint_invalid.pto | 58 + test/lit/vmi/vmi_to_vpto_compress.pto | 32 + ...mi_to_vpto_compress_multichunk_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_compress_store.pto | 33 + ...vpto_compress_store_multichunk_invalid.pto | 26 + .../vmi/vmi_to_vpto_compress_tail_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_constant.pto | 33 + test/lit/vmi/vmi_to_vpto_constant_mask.pto | 128 + .../vmi_to_vpto_constant_mask_nonprefix.pto | 34 + ...mi_to_vpto_constant_mask_rematerialize.pto | 
42 + .../vmi_to_vpto_constant_nonsplat_invalid.pto | 23 + ...vmi_to_vpto_construction_width_invalid.pto | 34 + test/lit/vmi/vmi_to_vpto_create_mask.pto | 87 + .../vmi/vmi_to_vpto_create_mask_dynamic.pto | 132 + .../vmi_to_vpto_create_mask_plt_fallback.pto | 30 + .../vmi_to_vpto_create_mask_rematerialize.pto | 47 + test/lit/vmi/vmi_to_vpto_divf.pto | 33 + .../vmi/vmi_to_vpto_e2e_widen_add_store.pto | 74 + .../vmi_to_vpto_elementwise_width_invalid.pto | 41 + test/lit/vmi/vmi_to_vpto_ensure_identity.pto | 80 + .../vmi/vmi_to_vpto_ensure_layout_deint4.pto | 57 + ..._to_vpto_ensure_layout_partial_invalid.pto | 23 + .../vmi/vmi_to_vpto_ensure_layout_vdintlv.pto | 30 + .../vmi/vmi_to_vpto_ensure_layout_vintlv.pto | 49 + .../vmi_to_vpto_ensure_mask_granularity.pto | 40 + ...to_vpto_ensure_mask_granularity_direct.pto | 31 + ...vpto_ensure_mask_granularity_multistep.pto | 34 + .../vmi/vmi_to_vpto_ensure_mask_layout.pto | 114 + ...pto_ensure_mask_layout_partial_invalid.pto | 23 + .../vmi_to_vpto_ensure_mask_layout_widths.pto | 78 + .../vmi_to_vpto_expand_load_all_active.pto | 66 + ...oad_all_active_negative_offset_invalid.pto | 35 + ..._vpto_expand_load_partial_mask_invalid.pto | 33 + .../vmi_to_vpto_expand_load_runtime_mask.pto | 41 + test/lit/vmi/vmi_to_vpto_extf.pto | 74 + test/lit/vmi/vmi_to_vpto_extf_f8.pto | 59 + test/lit/vmi/vmi_to_vpto_extf_multichunk.pto | 35 + test/lit/vmi/vmi_to_vpto_fma.pto | 83 + .../vmi_to_vpto_fma_element_type_invalid.pto | 26 + ...vpto_function_type_layout_free_invalid.pto | 16 + test/lit/vmi/vmi_to_vpto_gather.pto | 37 + .../vmi/vmi_to_vpto_gather_f16_invalid.pto | 28 + ...i_to_vpto_gather_scatter_shape_invalid.pto | 91 + test/lit/vmi/vmi_to_vpto_iota.pto | 120 + test/lit/vmi/vmi_to_vpto_iota_tail.pto | 57 + test/lit/vmi/vmi_to_vpto_load_deint.pto | 53 + .../vmi/vmi_to_vpto_load_deint_multichunk.pto | 31 + .../vmi/vmi_to_vpto_load_nonfull_invalid.pto | 27 + .../vmi/vmi_to_vpto_load_safe_tail_memref.pto | 73 + 
..._to_vpto_load_safe_tail_memref_invalid.pto | 25 + ...fe_tail_memref_negative_offset_invalid.pto | 25 + .../vmi/vmi_to_vpto_load_store_contiguous.pto | 33 + test/lit/vmi/vmi_to_vpto_mask_logic.pto | 126 + test/lit/vmi/vmi_to_vpto_masked_load.pto | 36 + ...mi_to_vpto_masked_load_nonfull_invalid.pto | 33 + ...i_to_vpto_masked_load_safe_tail_memref.pto | 69 + ...fe_tail_memref_negative_offset_invalid.pto | 31 + test/lit/vmi/vmi_to_vpto_masked_store.pto | 38 + .../vmi_to_vpto_masked_store_deint_tail.pto | 42 + ...i_to_vpto_masked_store_nonfull_invalid.pto | 26 + .../lit/vmi/vmi_to_vpto_masked_store_tail.pto | 40 + .../vmi_to_vpto_math_element_type_invalid.pto | 131 + .../vmi/vmi_to_vpto_memory_space_invalid.pto | 130 + test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto | 44 + .../vmi/vmi_to_vpto_memref_layout_invalid.pto | 177 + test/lit/vmi/vmi_to_vpto_min_max.pto | 39 + test/lit/vmi/vmi_to_vpto_negf.pto | 29 + test/lit/vmi/vmi_to_vpto_pack_unpack.pto | 46 + test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 310 + test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 51 + test/lit/vmi/vmi_to_vpto_reduce_addf.pto | 36 + .../vmi_to_vpto_reduce_addf_f16_invalid.pto | 26 + .../vmi_to_vpto_reduce_addf_multichunk.pto | 38 + test/lit/vmi/vmi_to_vpto_reduce_addi.pto | 36 + .../vmi_to_vpto_reduce_addi_i16_invalid.pto | 26 + .../vmi_to_vpto_reduce_addi_multichunk.pto | 38 + .../vmi_to_vpto_reduce_maxf_multichunk.pto | 65 + .../vmi_to_vpto_reduce_maxf_tail_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_reduce_minf.pto | 36 + .../vmi/vmi_to_vpto_reduce_shape_invalid.pto | 85 + .../vmi_to_vpto_relu_element_type_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_scatter.pto | 31 + ...to_vpto_scatter_missing_unique_invalid.pto | 27 + test/lit/vmi/vmi_to_vpto_scf_for.pto | 44 + test/lit/vmi/vmi_to_vpto_scf_if.pto | 57 + test/lit/vmi/vmi_to_vpto_shli.pto | 33 + test/lit/vmi/vmi_to_vpto_shrui.pto | 33 + .../vmi/vmi_to_vpto_shuffle_forwarding.pto | 159 + .../vmi/vmi_to_vpto_shuffle_lane0_splat.pto | 44 + 
...stable_gather_masked_load_todo_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_store_deint.pto | 64 + .../vmi/vmi_to_vpto_store_deint_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_store_deint_tail.pto | 35 + test/lit/vmi/vmi_to_vpto_store_tail.pto | 29 + .../vmi/vmi_to_vpto_store_width_invalid.pto | 38 + test/lit/vmi/vmi_to_vpto_sub_mul.pto | 60 + test/lit/vmi/vmi_to_vpto_tile_read_write.pto | 64 + .../vmi/vmi_to_vpto_tile_write_deint_tail.pto | 34 + test/lit/vmi/vmi_to_vpto_tile_write_tail.pto | 33 + ..._to_vpto_tile_write_tail_deint_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_truncf.pto | 56 + ...vpto_truncf_fp8_128_contiguous_invalid.pto | 25 + ..._vpto_truncf_unsupported_shape_invalid.pto | 23 + test/lit/vmi/vmi_to_vpto_type_arity.pto | 63 + ...vpto_type_attr_nested_residual_invalid.pto | 16 + ...vmi_to_vpto_type_attr_residual_invalid.pto | 16 + test/lit/vmi/vmi_to_vpto_type_only.pto | 27 + test/lit/vmi/vmi_to_vpto_unary_math.pto | 89 + ..._vpto_unrealized_cast_residual_invalid.pto | 20 + .../vmi_to_vpto_unsupported_op_invalid.pto | 25 + test/lit/vmi/vmi_truncf_direction_invalid.pto | 19 + .../vmi/vmi_truncf_lane_mismatch_invalid.pto | 19 + test/lit/vmi/vmi_type_attr_parse.pto | 40 + .../vmi/vmi_type_element_count_invalid.pto | 18 + .../vmi/vmi_unary_math_integer_invalid.pto | 55 + test/lit/vmi/vmi_unpack_arity_invalid.pto | 20 + .../vmi/dequant-f16-to-f32-tail/compare.py | 27 + .../vmi/dequant-f16-to-f32-tail/golden.py | 44 + .../vmi/dequant-f16-to-f32-tail/kernel.pto | 60 + .../vmi/dequant-f16-to-f32-tail/launch.cpp | 40 + .../vmi/dequant-f16-to-f32-tail/main.cpp | 78 + .../vmi/dequant-f16-to-f32-tail/ptoas.flags | 1 + .../vmi/dequant-f8-to-f32-tail/compare.py | 27 + .../vmi/dequant-f8-to-f32-tail/golden.py | 45 + .../vmi/dequant-f8-to-f32-tail/kernel.pto | 59 + .../vmi/dequant-f8-to-f32-tail/launch.cpp | 40 + .../cases/vmi/dequant-f8-to-f32-tail/main.cpp | 78 + .../vmi/dequant-f8-to-f32-tail/ptoas.flags | 1 + .../vmi/quant-f32-to-f16-tail/compare.py | 27 + 
.../cases/vmi/quant-f32-to-f16-tail/golden.py | 44 + .../vmi/quant-f32-to-f16-tail/kernel.pto | 60 + .../vmi/quant-f32-to-f16-tail/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f16-tail/main.cpp | 78 + .../vmi/quant-f32-to-f16-tail/ptoas.flags | 1 + .../cases/vmi/quant-f32-to-f8-full/compare.py | 27 + .../cases/vmi/quant-f32-to-f8-full/golden.py | 40 + .../cases/vmi/quant-f32-to-f8-full/kernel.pto | 47 + .../cases/vmi/quant-f32-to-f8-full/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f8-full/main.cpp | 79 + .../vmi/quant-f32-to-f8-full/ptoas.flags | 1 + .../cases/vmi/quant-f32-to-f8-tail/compare.py | 27 + .../cases/vmi/quant-f32-to-f8-tail/golden.py | 44 + .../cases/vmi/quant-f32-to-f8-tail/kernel.pto | 56 + .../cases/vmi/quant-f32-to-f8-tail/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f8-tail/main.cpp | 78 + .../vmi/quant-f32-to-f8-tail/ptoas.flags | 1 + .../vmi/reduce-f16-f8-mul-store/compare.py | 27 + .../vmi/reduce-f16-f8-mul-store/golden.py | 46 + .../vmi/reduce-f16-f8-mul-store/kernel.pto | 66 + .../vmi/reduce-f16-f8-mul-store/launch.cpp | 43 + .../vmi/reduce-f16-f8-mul-store/main.cpp | 88 + .../vmi/reduce-f16-f8-mul-store/ptoas.flags | 1 + tools/CMakeLists.txt | 1 + tools/pto-test-opt/CMakeLists.txt | 35 + tools/pto-test-opt/pto-test-opt.cpp | 35 + tools/ptoas/ptoas.cpp | 62 + 302 files changed, 28543 insertions(+), 1 deletion(-) create mode 100644 docs/designs/vmi-dialect-design.md create mode 100644 docs/designs/vmi-implementation-manual.md create mode 100644 include/PTO/IR/VMIAttrs.td create mode 100644 include/PTO/IR/VMIOps.td create mode 100644 include/PTO/IR/VMITypeDefs.td create mode 100644 include/PTO/IR/VMIUtils.h create mode 100644 include/PTO/Transforms/VMITargetCapabilities.h create mode 100644 lib/PTO/IR/VMI.cpp create mode 100644 lib/PTO/Transforms/PTOValidateVMIIR.cpp create mode 100644 lib/PTO/Transforms/VMILayoutAssignment.cpp create mode 100644 lib/PTO/Transforms/VMIToVPTO.cpp create mode 100644 test/lit/vmi/vmi_absf_integer_invalid.pto 
create mode 100644 test/lit/vmi/vmi_absi_float_invalid.pto create mode 100644 test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto create mode 100644 test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_bitcast_total_bits_invalid.pto create mode 100644 test/lit/vmi/vmi_bitwise_float_invalid.pto create mode 100644 test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_split_lane_count_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_split_result_count_invalid.pto create mode 100644 test/lit/vmi/vmi_compress_result_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_attr_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_divf_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_elementwise_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_ensure_layout_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_extf_direction_invalid.pto create mode 100644 test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_fma_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_gather_indices_invalid.pto create mode 100644 test/lit/vmi/vmi_iota_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_iota_order_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto create mode 100644 
test/lit/vmi/vmi_layout_assignment_call_boundary.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_cf_branch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_cf_switch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_compress.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_compress_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_constant_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_expand_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_fma.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_gather.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_iota_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_load_truncf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_multi_return.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_reduce_addf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_reduce_addi.pto create mode 100644 
test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scatter.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_for.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_if.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_while.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_store_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_widen.pto create mode 100644 test/lit/vmi/vmi_layout_factor_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_valid.pto create mode 100644 test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_logic_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_memory_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_min_max_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_negf_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_op_verifier_basic.pto create mode 100644 test/lit/vmi/vmi_pack_arity_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_helper_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto create mode 
100644 test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_physical_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_valid.pto create mode 100644 test/lit/vmi/vmi_ptoas_backend_required_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_cli_control_flow.pto create mode 100644 test/lit/vmi/vmi_ptoas_cli_pipeline.pto create mode 100644 test/lit/vmi/vmi_ptoas_public_abi_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto create mode 100644 test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto create mode 100644 test/lit/vmi/vmi_scatter_indices_invalid.pto create mode 100644 test/lit/vmi/vmi_select_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_shli_float_invalid.pto create mode 100644 test/lit/vmi/vmi_shrui_float_invalid.pto create mode 100644 test/lit/vmi/vmi_shrui_signed_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_abs.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_add.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bf16_arith.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_partial.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitwise.pto create mode 100644 test/lit/vmi/vmi_to_vpto_broadcast.pto create mode 100644 test/lit/vmi/vmi_to_vpto_call_boundary.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cf_branch.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_merge.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_select.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto create mode 100644 test/lit/vmi/vmi_to_vpto_divf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_identity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf_f8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_fma.pto create mode 100644 test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_gather.pto create mode 100644 test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_iota.pto create mode 100644 test/lit/vmi/vmi_to_vpto_iota_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto create mode 100644 test/lit/vmi/vmi_to_vpto_mask_logic.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_min_max.pto create mode 100644 test/lit/vmi/vmi_to_vpto_negf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_pack_unpack.pto create mode 100644 test/lit/vmi/vmi_to_vpto_quant_dequant.pto create mode 100644 test/lit/vmi/vmi_to_vpto_quant_fp8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addf.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_minf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scatter.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scf_for.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scf_if.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shli.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shrui.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto create mode 100644 test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_sub_mul.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_read_write.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_arity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_only.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unary_math.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto create mode 100644 test/lit/vmi/vmi_truncf_direction_invalid.pto create mode 100644 test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_type_attr_parse.pto create mode 100644 test/lit/vmi/vmi_type_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_unary_math_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_unpack_arity_invalid.pto create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py create mode 100644 
test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags create mode 100644 tools/pto-test-opt/CMakeLists.txt create mode 100644 tools/pto-test-opt/pto-test-opt.cpp diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md new file mode 100644 index 0000000000..5578ca93d1 --- /dev/null +++ b/docs/designs/vmi-dialect-design.md @@ -0,0 +1,2078 @@ +# VMI dialect 设计 + +## 背景 + +VPTO 的 `!pto.vreg` 是 
256 bytes 物理向量寄存器抽象。很多 VPTO op 暴露的是 +physical placement:`vcvt` part、pack/unpack、interleave/deinterleave、load/store dist、 +predicate granularity 等。TileLang `T.parallel` 或其它前端想表达的是逻辑向量语义,不应该 +手写这些 physical placement。 + +VMI dialect 的目标是提供一层 PTO-friendly 的 semantic vector IR。它不是任何外部向量 dialect +的语法克隆,也不是 VPTO physical dialect。VMI 的设计来源是 PTO virtual vector ISA 需要承接的 +逻辑向量语义、layout、mask granularity、memory safety 和控制流 layout join;后续 lowering 只从 +VMI 决定 physical layout 和 VPTO op。 + +本设计采用 `vmi.vreg` 作为 layout carrier,不再引入单独的 `vbundle` type: + +```text +semantic VMI + -> layout-assigned VMI + -> physical VPTO +``` + +VMI 的 producer 在核心设计之外。TileLang/PTO lowering、手写 VMI 测试或其它 import 工具都可以 +产生 VMI,但它们不能定义 VMI 的 semantic surface。核心设计只要求 producer 在进入 VMI boundary +时生成合法 VMI IR。 + +## 和旧 VMI layout 设计的关系 + +旧文档中的核心形式是: + +```mlir +!pto.vmi.vreg +!pto.vmi.mask +``` + +这个方向是对的:`vmi.vreg` 本身是 virtual aggregate type,可以承载完整 logical vector, +layout 放在它上面比放在 physical `!pto.vreg` 上更合理。 + +旧设计需要补强的地方主要是 layout descriptor 和 lowering contract,而不是推翻 +`vmi.vreg`: + +1. 旧 layout descriptor 把 `logical_shape`、`phys_dtype`、`phys_lanes` 放进 attr,和 + `vreg` / target registry 存在重复信息。重复字段会产生 verifier 漂移。 +2. `axes=[#axis<...>]` 太开放,缺少每个 layout 的精确定义、part ordering 和 lane map。 +3. 旧设计要求 `N * bitwidth(T)` 是 256B 整数倍,无法覆盖 tail / 非整 tile。 +4. mask 只写成 `mask`,但没有定义 data layout、mask layout、mask granularity + conversion 在宽度转换中的同步规则。 +5. 控制流 join 没有定义:`scf.if` 两边 layout 不同、`scf.for` loop-carried layout 如何稳定。 +6. memory access map 和 register layout 没有切开,容易把 strided memory view 误当成 vreg + layout。 +7. 
hard vector semantics 缺失,例如 padding read、active prefix index、dynamic permute、 + compress/expand、scan/reduction/contract 的 VMI 表达和 lowering contract。 + +因此本设计保留 `vmi.vreg` 这个 carrier,但不沿用旧 layout descriptor 的 +开放式语义。旧文档没有定义 “logical behavior -> hardware mismatch -> physical +decomposition -> lane map -> propagation/sink” 这条 source contract;这是本文新增的核心约束。 + +换句话说,本文不是复述旧 `vmi.layout`,而是把旧的开放式 axis descriptor 收紧成一个很小的 +public layout 集合。本设计只接受 `contiguous`、`deinterleaved = 2`、`deinterleaved = 4`。 +source contract 是新增 layout kind 的准入规则,不是要求实现 generic axes 或任意 lane-map +descriptor。 + +## 目标 + +1. VMI surface 表达逻辑向量语义,不暴露 VPTO part/dist/interleave 细节。 +2. `vmi.vreg` 是 virtual aggregate type,可以表示大于 256B 的 logical vector。 +3. layout 放在 layout-assigned VMI type 上,不再另设 `vbundle`。 +4. VMI mask 是一等类型;surface mask 表达 logical predicate,layout-assigned mask 才携带 + concrete predicate granularity `b8/b16/b32`。 +5. VMI 支持 tail / 非整 tile;padding physical lane 不可观察。 +6. VMI lowering 支持控制流中的 layout join。 +7. VMI producer boundary 后的 IR 必须只依赖 VMI semantic op/type 表达逻辑向量语义。 + +## 非目标 + +1. 不改变 physical `!pto.vreg` 的含义。它仍然是 256 bytes physical register。 +2. 不把 VMI 做成任何外部向量 dialect 的逐 op 复制品;VMI 只表达 PTO lowering 需要的 logical + vector semantics。 +3. 不把 scalar lane extract 当作 VMI vector op。scalar lane extract 是 vector-to-scalar + boundary,必须在进入 VMI 前被 producer 消除,或以明确 diagnostic 退出 PTO 路线。 +4. 不把 VPTO load/store dist 暴露成 VMI surface op。dist 是 lowering 选择。 + +## VMI Producer Boundary Contract + +VMI 是 PTO 路线上的 virtual vector ISA。任何 producer 在进入 VMI boundary 后,必须满足下面之一: + +1. 逻辑向量语义已经表达为 native VMI semantic op。 +2. 逻辑向量语义已经表达为一组 VMI semantic op 的组合,并保持 producer 的 observable semantics。 +3. 
该行为不是 VMI 负责的向量计算,而是 vector-to-scalar / tensor / debug / transform boundary, + 已经在进入 VMI 前由 producer 消除,或以明确 diagnostic 退出 PTO 路线。 + +不能把“当前阶段不支持”作为 VMI 设计结果。一个 PTO virtual vector semantic 如果属于 VMI 负责的 +逻辑向量语义,文档必须给出 VMI op、组合 lowering、layout contract、memory fallback 或 target +capability diagnostic。diagnostic 只允许表示语义边界或目标能力缺失,不能表示“VMI 没有设计这个能力”。 + +`pto.vmi -> pto` 的完成条件是: + +```text +at VMI producer boundary: + logical vector semantics are represented by VMI op/type + no physical VPTO op is introduced by the producer + no hidden layout/mask/type side table is required to interpret a VMI value + +after vmi-layout-assignment: + every vmi.vreg/vmi.mask has an explicit #pto.vmi.layout + every mask granularity matches its consumer + every control-flow yield/iter_arg/result has one stable layout + +after vmi-to-vpto: + no pto.vmi op/type remains + every logical VMI value has been lowered to ordered physical VPTO values +``` + +### Capability And Fallback Policy + +所有 direct lowering 和 fallback 选择必须来自显式配置,不能依赖 pass 内隐藏全局状态: + +```text +TargetCapabilityRegistry: + element-type storage/compute/convert support + layout source/sink/conversion support + memory access capability: OOB, masked, gather/scatter, block-strided + predicate capability: granularity conversion, prefix-popcount, rearrangement + reduction/scan/contract capability + scratch memory spaces, alignment, and lifetime rules + +VMIToPTOOptions: + enableScratchFallback + enableGuardedScalarFallback + enableIndexBufferFallback + allowDebugStrip + targetVScaleSpecialization + diagnosticVerbosity +``` + +fallback 被 option 禁用时,diagnostic 必须报告 `disabled_by_option`。target registry 缺能力时, +diagnostic 必须报告 `missing_capability`。debug-only op 只能由 debug pipeline 消费,或在 +`allowDebugStrip` 明确开启时剥离;否则报 `VMI-DEBUG-BOUNDARY`。 + +fallback resource 也必须显式建模: + +```text +scratch fallback: + memory space, alignment, element type, shape, lifetime, and deallocation point + must be explicit in the lowering plan + scratch initialization, 
such as padding fill, must dominate later scratch load + +guarded scalar/vector fallback: + guard must dominate every memory effect it protects + invalid lane must not compute a memory effect through an OOB memref address + +index-buffer fallback: + index element width, signedness, and address unit must match the consumer + buffer lifetime must dominate gather/scatter or compaction use +``` + +如果无法分配 scratch、无法放置 guard、或 index buffer 宽度不满足目标要求,diagnostic 使用 +`VMI-FALLBACK-RESOURCE`,并说明是 resource 缺失而不是语义不可表达。 + +## 类型模型 + +### Surface Type + +VMI surface type 不显式写 layout: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.vreg<256xf8> +!pto.vmi.vreg<1xf32> + +!pto.vmi.mask<128xpred> +!pto.vmi.mask<256xpred> +``` + +`N` 是 logical lane count,`T` 是 logical element type。surface `mask` 表示 N 个 +logical predicate lane,不预先绑定 VPTO predicate granularity。layout assignment 根据 consumer +选择 concrete granularity: + +```text +f32/i32 consumer -> b32 +f16/bf16/i16 consumer -> b16 +f8/i8 consumer -> b8 +``` + +如果一个 logical mask 被不同 width consumer 使用,VMI lowering 必须按 use 插入 +`vmi.ensure_mask_granularity` 或重物化 mask producer,不能假设某个 concrete granularity 可直接 +给所有 consumer 使用。 + +VMI type 以 1-D logical vector 为核心。来自 multi-rank producer value 的语义在进入 VMI boundary 前按 row-major flatten 成: + +```mlir +!pto.vmi.vreg<64xf32> +!pto.vmi.mask<64xpred> +``` + +VMI value 本身只承载 flattened lane sequence,不携带隐式 rank side table。需要 rank 信息的 op +必须在自身 attr 中保存 logical shape / indexing map,例如 `logical_shape = [8, 8]`。这样保持 +与既有 `vmi.vreg` 设计一致,同时不丢失 transfer、transpose、reshape 等 op 的语义。 + +shape-sensitive op 的规则是: + +```text +elementwise / select: + operate on flattened lanes and preserve any surrounding op-provided shape context + +tile_read / tile_write: + carry logical_shape and permutation_map attrs + +shape_cast / reshape / transpose / contract: + carry source/result shapes, maps, and iterator metadata as op attrs + +block argument / function argument: + carries only flat vreg type; any later shaped use must 
provide its own shape attrs +``` + +因此 logical shape 信息不能保存在 C++ side table,也不能要求 consumer 从 defining op 反查。 + +Rank-0 logical vector 仍然是 VMI vector value,不是 scalar SSA value: + +```mlir +rank-0 logical vector -> !pto.vmi.vreg<1xT> +rank-0 logical predicate -> !pto.vmi.mask<1xpred> +``` + +只有产生 scalar result 的 extract 才是 vector-to-scalar boundary。rank-0 logical vector load、 +bitcast、mask 和 arithmetic 仍然走 VMI,不能因为只有一个 lane 就绕开 VMI verifier。 + +Scalable logical vector 不能直接进入 VMI type,因为 `vmi.vreg<NxT>` 的 `N` 是 concrete logical lane +count。producer 必须先根据 target profile 和 tiling decision 把 scalable semantics specialize 成固定 +`N`;否则在 VMI boundary 报 `VMI-SCALABLE-VECTOR`。这不是 VMI 的临时缺口,而是 +固定 256B physical vreg lowering 的前置约束。 + +### Layout-Assigned Type + +`vmi-layout-assignment` 后,所有 VMI data/mask value 都必须带 layout: + +```mlir +!pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +!pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +!pto.vmi.vreg<256xf32, #pto.vmi.layout<deinterleaved = 4>> + +!pto.vmi.mask<128xb32, #pto.vmi.layout<contiguous>> +!pto.vmi.mask<128xb32, #pto.vmi.layout<deinterleaved = 2>> +``` + +这里的 `#pto.vmi.layout` 是唯一的 VMI register layout carrier。它不是 `#pto.vlayout` +的直接复用,也不是 `vbundle` 的 type 参数;但它必须采用同一套精确 lane-map 语义,保证后续 +lower 到 physical VPTO 时可验证。 + +### 非整 Tile + +VMI type 不要求 `N * bitwidth(T)` 是 256B 整数倍: + +```mlir +!pto.vmi.vreg<100xf32> +!pto.vmi.mask<100xpred> +``` + +physical lowering 时按 256B part 向上取整。超出 `N` 的 physical lane 是 padding lane: + +```text +padding lane: + may be poison/undef internally + must not be stored + must not affect compare/reduction/scan + must not become visible through layout conversion +``` + +任何 store、reduction、compress、mask-producing op 都必须用 logical lane count 或 explicit +mask 保护 padding lane。 + +## Layout 设计来源 + +VMI layout 的价值必须从逻辑 vector 行为推导,而不是从 layout 名字推导。判断流程是: + +```text +1. 前端想表达一个完整的 logical vector 行为。 +2. VPTO 底层指令不能把这个 logical vector 天然放进一个 contiguous physical sequence。 +3. 但 VPTO 可以把这个 logical vector 拆成一组有固定 lane-map 的 physical parts。 +4. 
后续常见 op 可以在这些 parts 上逐 part 保持 logical semantics。 +5. 边界 consumer 能直接消费这种 parts,或存在可验证的 materialize path。 +6. 因此值得把这个 parts relation 提升为 VMI layout。 +``` + +layout 不是“某条指令的名字”,而是一个 representation relation: + +```text +Layout L defines: + logical vector value V[NxT] + <-> ordered physical parts P0, P1, ... + with exact map logical lane i -> (part, lane) +``` + +只有当这个 relation 能让 VMI 保持“用户看到的是一个连续 logical vector”,同时避免前端手写 +parts,layout 才有设计价值。 + +### Register Layout 集合 + +VMI register layout 不采用复杂通用 descriptor,而是定义为封闭集合: + +```text +#pto.vmi.layout<contiguous> +#pto.vmi.layout<deinterleaved = 2> +#pto.vmi.layout<deinterleaved = 4> +``` + +`deinterleaved = K` 表示一个 logical vector 被拆成 K 个 physical part,第 `p` 个 part 保存 +logical lane `p, p + K, p + 2K, ...`。这个名字直接描述元素摆放,不绑定到某条 VPTO op,也不 +引入旧 `axes` 的通用维度系统。 + +不加入 `channel`、`packed_bits`、`blocked`、`stride`、`permutation` 等 layout kind。 +这些能力先由 VMI semantic op、memory access plan 或 explicit layout conversion 表达。只有当 +一个新 representation 同时满足下面的 source contract,才允许扩展 layout 目录。 + +### Layout Source Contract + +每个 VMI layout kind 必须来自一条明确的 source contract: + +```text +logical behavior: + VMI 想表达的用户级 vector 行为 + +hardware mismatch: + 为什么 VPTO 不能用一个 contiguous physical sequence 天然承载 + +physical decomposition: + VPTO 实际能产生或消费的 physical parts + +lane map: + logical lane -> physical part/lane 的精确定义 + +propagation rule: + 哪些 VMI op 可以逐 part 保持语义 + +boundary rule: + 哪些 load/store/pack/convert consumer 可以直接消费,哪些必须 materialize + +mask rule: + 对应 mask 如何生成、转换和消费 +``` + +没有这份 source contract 的 lane movement 不能进入 `#pto.vmi.layout`。 + +### Source 1: Widen Cast To Larger Logical Vector + +逻辑行为: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +``` + +用户语义是“128 个 f16 lane 加宽成 128 个连续 f32 lane”。但 128 个 f32 是 512B,超过单个 +256B physical vreg。VPTO 的可行 lowering 不是一个 contiguous 512B register,而是两条 part +conversion: + +```text +even part: + physical even[i] = extf(logical[2*i]) + +odd part: + physical odd[i] = extf(logical[2*i+1]) +``` + +因此需要一个 layout 表达“这个 VMI 
value 仍然是 logical `128xf32`,但 physical representation +是 even/odd 两个 parts”: + +```mlir +#pto.vmi.layout<deinterleaved = 2> +``` + +lane map: + +```text +part = i % 2 +lane = floor(i / 2) +physical[part][lane] = logical[i] +``` + +这个 layout 的价值在于后续 elementwise op 不需要 materialize contiguous representation: + +```mlir +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +``` + +lowering 可以变成两路 add: + +```text +add even parts +add odd parts +``` + +最后如果 store consumer 能把 even/odd parts 交织写回 contiguous memory,就不需要中途 +`ensure_layout contiguous`。 + +同理: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +``` + +需要: + +```mlir +#pto.vmi.layout<deinterleaved = 4> +``` + +这里不再使用抽象 stride 命名。`deinterleaved = 4` 的来源是 `f8 -> f32` 的 VPTO part +conversion contract,不是任意 stride 语义。 + +### Source 2: Narrow / Pack Consumer + +逻辑行为: + +```mlir +%n = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +``` + +如果 `%x` 已经是 `#pto.vmi.layout<deinterleaved = 2>`,VPTO 可以用 pack/narrow 类 +consumer 把 even/odd f32 parts 合成 contiguous f16 result。这里 layout 的来源不是 producer,而是 +consumer 能直接接受这种 decomposition: + +```text +source layout: + logical f32 value represented as even/odd f32 parts + +consumer: + narrowing pack consumes those parts + +result: + contiguous f16 logical vector +``` + +因此 `deinterleaved` 必须同时登记 producer contract 和 inverse/sink contract。否则 layout 只能 +产生,不能被合法消耗。 + +### Source 3: Same-Width Layout Materialization + +逻辑行为: + +```mlir +%x = pto.vmi.ensure_layout %v + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +``` + +这里不新增 surface view op。目标不是产生两个独立 semantic vectors,而是让同一个 logical +vector 继续作为一个 VMI value 存活,只是 physical representation 变成 even/odd parts。IR 中由 +`vmi-layout-assignment` 插入 +`pto.vmi.ensure_layout`,并由 target registry 证明存在 preserving materialization path。VPTO 的 +`vdintlv/vintlv` 类 register rearrangement 可以产生或消费这种 representation。 + +这和 `vcvt` 产生的 even/odd representation 使用同一个 layout: + +```mlir +#pto.vmi.layout<deinterleaved = 2> +``` + 
+区别只在 source contract: + +```text +logical behavior: + 同宽 logical vector 保持一个 VMI value,但 physical parts 分别保存 even/odd lanes + +hardware mismatch: + VPTO interleave/deinterleave 指令以两个 physical vreg parts 表达 + +layout: + deinterleaved=2 +``` + +如果 VMI op 的语义本来就是“返回两个独立 vectors”,例如 AoS -> SoA 后用户分别使用 `%x` +和 `%y`,那不需要 layout,直接产生两个 `vmi.vreg`。只有当“一个 logical vector value” +需要以 even/odd parts 长期存活时,才使用 `deinterleaved=2`。 + +### Channel Split / Merge 不是 Register Layout + +channel split/merge 的用户代码通常有两种形态。 + +第一种是把 interleaved data 当作普通 flat vector: + +```text +logical = [r0, g0, b0, a0, r1, g1, b1, a1, ...] +对每个 lane 做同一种逐元素操作 +``` + +这种情况下 `contiguous` representation 就能表达用户语义,不需要 channel layout。 + +第二种是用户按 channel 编程: + +```mlir +%r, %g, %b, %a = pto.vmi.channel_split %rgba + : !pto.vmi.vreg<128xi8> + -> !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, + !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> + +%r2 = pto.vmi.addi %r, %bias_r : !pto.vmi.vreg<32xi8> +%g2 = pto.vmi.addi %g, %bias_g : !pto.vmi.vreg<32xi8> +%b2 = pto.vmi.addi %b, %bias_b : !pto.vmi.vreg<32xi8> +%a2 = pto.vmi.addi %a, %bias_a : !pto.vmi.vreg<32xi8> +%out = pto.vmi.channel_merge %r2, %g2, %b2, %a2 + : !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, + !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> + -> !pto.vmi.vreg<128xi8> +``` + +这里自然的 IR 是多个 semantic VMI values,而不是“一个 VMI value 带 channel layout”。 +目标专用 split/merge 能力是 `channel_split/channel_merge` 的 lowering contract;load/store +memory boundary 的 dist/sink contract 也可以作为等价 lowering path。 + +`channel_split` / `channel_merge` 的语义必须能完全退化成 static shuffle,不能引入额外 +layout 规则。`C` 不需要单独 attr:`channel_split` 的 `C` 来自 result 个数, +`channel_merge` 的 `C` 来自 operand 个数。设 input 有 `N = C * M` 个 logical lanes: + +```text +channel_split(input, C): + out[c][i] = input[i * C + c] + for 0 <= c < C + for 0 <= i < M + +channel_merge(out[0], ..., out[C-1]): + result[i * C + c] = out[c][i] + for 0 <= i < M + for 0 <= c < C +``` + +如果 `N` 不能被 `C` 整除,或者 merge operands 的 logical lane count 
不一致,op verifier +必须拒绝。需要 tail 的场景通过外层 mask / valid lane 语义表达,不能让 channel op 自己发明 +padding lane。 + +因此这两个 op 的价值只是 canonical interface:producer 可以直接表达 channel 语义, +外部 import 工具也可以把识别出的 static shuffle pattern canonicalize 成它们;如果没有 +识别或目标没有专用 lowering,保持或退回 `pto.vmi.shuffle` 仍然是等价路径。 +当前 direct VPTO lowering 只接受能形成完整 physical channel groups 的形状:flat contiguous +source/result 与 virtual deinterleaved=C channel layout 必须有相同 physical arity,或已经是 matching +deinterleaved=C layout 的 identity forwarding。arity-changing partial group 需要额外 packing/drop +padding plan,不能直接 lowering。 + +所以 VMI register layout 目录不为 channel-specific representation 引入 layout kind,也不预留 +半成品 layout 语义。本文覆盖的用户形态要么是 flat contiguous vector,要么是多个 channel +semantic value;都不需要“一个 VMI value 带 channel layout”。 + +### Pack / Unpack 不作为长期 Layout + +pack/unpack 的逻辑行为通常是 width conversion 或 memory encoding: + +```text +wide logical vector -> narrow logical vector +narrow memory payload -> wide logical vector +``` + +它们的结果可以是 `contiguous` logical vector;pack/unpack 是 producer/sink/conversion +contract,不是必须长期传播的 register layout。只有当目标 ISA 提供 packed-format arithmetic, +并且 VMI 真的要让 packed representation 跨 compute 存活时,才需要另立 +`packed_bits` layout。本设计没有 packed-format arithmetic source contract,因此 pack/unpack 不进入 +长期 register layout。 + +### 不应成为 Register Layout 的东西 + +以下能力虽然来自 VPTO/VISA,但不是 VMI register layout: + +| 能力 | 原因 | +|---|---| +| `vsldb/vsstb` block stride | 描述 memory address map;result register 可仍是 contiguous representation | +| gather/scatter index | runtime address map,不是 static logical lane 到 physical part 的关系 | +| dynamic `vselr` | runtime permutation,应是 `pto.vmi.permute` op | +| `vsqz/vusqz` compaction | runtime mask 决定 lane destination,应是 `compress/active_prefix_index` op | +| one-shot `vintlv/vdintlv` | 如果只是 boundary conversion,不应提升成长期 layout;若表示一个 VMI value 的 even/odd parts,则归入 `deinterleaved=2` | + +VMI layout 只解决“一个 logical vector value 在寄存器中长期以什么 parts representation 存活” +的问题。memory address、runtime 
permutation、dynamic compaction 都是其它语义。 + +### Lane Map + +设: + +```text +N = logical lane count +lanesPerDataPart(T) = 256B / sizeof(T) +lanesPerMaskPart(b8) = 256 +lanesPerMaskPart(b16) = 128 +lanesPerMaskPart(b32) = 64 +``` + +`contiguous`: + +```text +chunk = floor(i / lanesPerPart) +lane = i % lanesPerPart +physical[chunk][lane] = logical[i] +``` + +`deinterleaved = K`,其中 `K` 只能是 2 或 4: + +```text +p = i % K +q = floor(i / K) +chunk = floor(q / lanesPerPart) +lane = q % lanesPerPart +physical[p][chunk][lane] = logical[i] +``` + +`deinterleaved=2` 和 `deinterleaved=4` 的 physical value ordering 固定为 part-major: + +```text +p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ..., p(K-1)_chunk0, ... +``` + +所有 verifier、type converter、physical lowering 和 control-flow conversion 必须使用同一套 +ordering。 + +### Physical Arity + +`vmi-to-vpto` 不能按示例猜 physical value 个数,必须由 type + layout 统一推导。 + +对 data vreg: + +```text +lanesPerPart = 256B / sizeof(T) + +contiguous: + chunks = ceil(N / lanesPerPart) + physical values = chunks + +deinterleaved = K: + lanesPerLogicalPart = ceil(N / K) + chunksPerPart = ceil(lanesPerLogicalPart / lanesPerPart) + physical values = K * chunksPerPart +``` + +对 mask: + +```text +lanesPerPart = lanesPerMaskPart(G) +same formula as data, replacing T with mask granularity G +``` + +每个 physical value 的有效 lane 由 lane map 反推: + +```text +contiguous valid: + logical = chunk * lanesPerPart + lane + valid = logical < N + +deinterleaved valid: + logical = K * (chunk * lanesPerPart + lane) + p + valid = logical < N +``` + +padding lane 可以是 poison/undef,但 store、mask-producing op、reduction、scan、compress 和 +layout conversion 都必须显式带着 `valid` 信息,不能只依赖 physical register 宽度。 + +### Broadcast 不作为 Register Layout + +VMI surface 使用 `broadcast` 表达前端语义: + +```mlir +%v = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> +``` + +也就是: + +```text +for i in 0 .. 
N: + v[i] = x +``` + +这不是 logical lane 到 physical part/lane 的 placement relation,而是一个 value producer +可以延迟 materialize 的事实。`vmi.broadcast` 应保持为 semantic op 或 layout-polymorphic +producer: + +```text +consumer wants contiguous: + materialize scalar into contiguous physical parts + +consumer wants deinterleaved=2: + materialize same scalar into even/odd parts + +consumer wants deinterleaved=4: + materialize same scalar into p0/p1/p2/p3 parts +``` + +因此 broadcast 不进入 `#pto.vmi.layout` 目录。它由 `vmi-layout-assignment` 按 consumer +layout 重物化或下沉到 consumer lowering,而不是作为 `vreg` 的 layout kind。 + +#### Broadcast Materialization + +MLIR SSA value 不能对不同 use 拥有不同 result type。因此 scalar broadcast 的多 layout +适配不是“一个 VMI value 同时带多个 layout”,而是在 layout assignment 中按 use 重物化。 + +semantic VMI: + +```mlir +%b = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> +%u = pto.vmi.addf %a_contiguous, %b + : !pto.vmi.vreg<128xf32> +%v = pto.vmi.addf %a_split, %b + : !pto.vmi.vreg<128xf32> +``` + +如果 `%u` 需要 `contiguous`,`%v` 需要 `deinterleaved=2`,layout assignment 重写为: + +```mlir +%b0 = pto.vmi.broadcast %x + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%u = pto.vmi.addf %a_contiguous, %b0 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b1 = pto.vmi.broadcast %x + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%v = pto.vmi.addf %a_split, %b1 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +physical materialization: + +```text +contiguous: + each physical chunk is filled with scalar x + +deinterleaved=2: + even part is filled with scalar x + odd part is filled with scalar x + +deinterleaved=4: + p0/p1/p2/p3 parts are all filled with scalar x +``` + +这要求 `pto.vmi.broadcast` 标记为 rematerializable,并满足 dominance:clone 位置必须被 scalar +operand `%x` dominate。跨控制流时,如果 scalar operand 可在各 predecessor/body 内使用, +优先在 consumer 所在 block 重物化;否则必须在控制流 join 处选择一个具体 layout 并 materialize。 + +这个规则只对 scalar-to-vector broadcast 是零语义风险的。低 rank vector 到高 rank vector 的 +broadcast 可能需要真实 lane 
replication/shuffle,不能默认按任意 consumer layout 免费重物化; +这类 broadcast 必须携带 broadcast map,并按普通 VMI op 做 layout assignment。 + +VMI register layout 目录因此是: + +```text +contiguous +deinterleaved=2 +deinterleaved=4 +``` + +channel split/merge、pack/unpack、memory stride、dynamic permutation、dynamic compaction +不在目录内。它们分别由 VMI semantic op、conversion、memory access plan、`vmi.permute`、 +`vmi.compress/active_prefix_index` 承接。 + +## Pipeline + +### 1. VMI Producer Boundary + +VMI core pipeline 从合法 VMI semantic IR 开始。Producer 可以是 TileLang/PTO lowering、手写 VMI +测试或其它外部 import 工具,但 producer 不属于 VMI core pipeline。 + +进入 VMI boundary 时必须满足: + +```text +all logical vector semantics are represented by pto.vmi semantic ops +all VMI data/mask values use surface VMI type without layout +no physical VPTO op is introduced +no hidden layout/mask/type side table is required +scalar/tensor/debug/transform boundary has already been handled by producer +``` + +该 boundary 需要 verifier gate。它验证 VMI IR 自身完整,不验证某个外部 source dialect 的 +coverage。 + +### 2. `vmi-layout-assignment` + +该阶段把无 layout VMI type 转换成 layout-assigned VMI type,推荐实现为独立 pass: + +```mlir +!pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +!pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +layout assignment 做三件事: + +1. 为每个 producer 选择 natural layout。 +2. 为每个 consumer 协调 operand/result layout。 +3. 
在必要处插入: + +```mlir +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +``` + +layout assignment 不是局部 pattern 贪心插 conversion,而是约束求解: + +```text +nodes: + every VMI SSA value + block arguments and region/function results + rematerializable producers such as scalar broadcast/iota/constant + +allowed layouts: + contiguous + deinterleaved=2 + deinterleaved=4 + filtered by element type, mask granularity, op capability, and target registry + +hard constraints: + op verifier constraints, such as same-layout elementwise operands + data/mask layout alignment for predicated ops + control-flow block argument/yield/call signature equality + external ABI layout boundary + source/sink contracts for width conversion, load/store, pack/narrow + +soft costs: + natural producer layout preference + ensure_layout materialization cost from target registry + store/load sink cost + rematerialization cost for broadcast/iota/constant + scratch/guarded fallback resource cost +``` + +求解顺序: + +```text +1. Build constraints for the whole region/SCC, including control-flow and call edges. +2. Propagate impossible layouts and required mask granularities. +3. Choose a minimum-cost layout for each node. +4. Use deterministic tie-break: prefer existing natural layout, then contiguous. +5. Insert ensure_layout/ensure_mask_layout or rematerialize producers at chosen use sites. +6. Re-run verifier gates; no hidden side table may be needed to interpret the result. 
+``` + +如果 hard constraints 冲突,或所有 legal paths 都缺 target capability/resource,报 +`VMI-LAYOUT-CONTRACT` 或更具体 diagnostic。diagnostic payload 必须列出 conflict value、producer +natural layout、consumer required layouts、available conversion paths 和被禁用的 fallback。 + +#### Consumer Layout Demand + +“consumer 需要某个 layout”不是前端语义要求,而是 layout assignment 为了让 operands/results +的 lane-map 对齐并减少 layout conversion 选择的共同 representation。 + +典型例子: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + +%b = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<128xf32> +``` + +`%w` 的 logical 语义是 `128xf32`,但 VPTO `f16 -> f32` 的自然 lowering 产生 even/odd +两路 parts: + +```text +w_even[i] = extf(a[2*i]) +w_odd[i] = extf(a[2*i+1]) +``` + +因此 `%w` 的 natural layout 是: + +```mlir +#pto.vmi.layout +``` + +`addf` 是 layout-polymorphic elementwise op。它有两个合法选择: + +```text +choice A: + materialize %w to contiguous + materialize broadcast to contiguous + do one contiguous add sequence + +choice B: + materialize broadcast directly as deinterleaved=2 + do add on even parts and odd parts separately + keep result as deinterleaved=2 +``` + +choice B 通常更便宜,因为不需要把 `%w_even/%w_odd` 先 interleave 成 contiguous。broadcast +能直接适配 `deinterleaved=2`,是因为它的 logical lanes 全部等于同一个 scalar: + +```text +b_even = [scalar, scalar, ...] +b_odd = [scalar, scalar, ...] +``` + +所以这里说 `addf` consumer “需要” `deinterleaved=2`,准确含义是: + +```text +layout assignment 选择 deinterleaved=2 作为 addf 的共同 operand/result representation, +因为其中一个 operand 的 natural layout 已经是 deinterleaved=2,并且 broadcast 可零语义风险地重物化到该 layout。 +``` + +### 3. 
`vmi-to-vpto` + +该阶段把 layout-assigned VMI type 做 1:N physical type conversion,推荐实现为独立 pass: + +```text +!pto.vmi.vreg<128xf32, contiguous> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<128xf32, deinterleaved=2> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<256xf32, deinterleaved=4> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.mask<128xb32, deinterleaved=2> + -> !pto.mask, !pto.mask +``` + +需要 internal projection/materialization op: + +```mlir +%p0, %p1 = pto.vmi.unpack %v + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%v = pto.vmi.pack %p0, %p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +`pack/unpack` 不是新的 layout carrier,只是 layouted `vmi.vreg` 到 physical VPTO parts 的 +projection/materialization。 + +`unpack` 必须能作用在任意 SSA value 上,不能依赖 defining op。VMI value 可以来自 block +argument、`scf.if` result、loop iter_arg、function argument 或 call result;这些 value 没有 +可 look-through 的 layout materialization defining op。 + +`pack/unpack` 的 operand/result 个数必须使用 Physical Arity 公式推导。非整 tile 时,最后一个 +chunk 的 padding lane 仍属于 physical value,但不属于 logical value。 + +### Layout Conversion Materialization + +`pto.vmi.ensure_layout` / `pto.vmi.ensure_mask_layout` 是 logical-value-preserving conversion: + +```text +for every logical lane i: + dst.logical[i] = src.logical[i] +for every padding lane: + dst padding remains unobservable +``` + +source/result layout 完全相同时,`ensure_layout` / `ensure_mask_layout` 是 identity forwarding; +即使存在 partial/tail physical chunk,也不需要 target materialization path。source/result layout +不同时才需要 registry 证明 preserving conversion 及其 full-chunk/tail 处理策略。当前 direct path +允许 equal-arity partial/tail conversion:source/result 的 physical arity 必须相同,且两边都能组成完整 +contiguous/deinterleaved=2/4 `intlv` materialization group;arity-changing partial conversion 和 uneven +deinterleaved groups 继续报 unsupported。 + +合法 
materialization path 必须来自 target registry: + +```text +same layout: + no-op + +contiguous <-> deinterleaved=2: + direct interleave/deinterleave register op, load/store dist sink/source, + or scratch/ordered fallback + +contiguous <-> deinterleaved=4: + direct 4-way layout sink/source, proven staged 2-way sequence, + or scratch/ordered fallback + +deinterleaved=2 <-> deinterleaved=4: + convert through contiguous only if both legs have preserving paths, + otherwise use scratch/ordered fallback or report VMI-LAYOUT-CONTRACT +``` + +`deinterleaved=4` 不能默认假设“两次二路 interleave”就是正确 materialization。只有当 staged +sequence 的 lane map 被 registry 证明等价于: + +```text +logical = 4 * lane + p +``` + +才允许使用。否则必须选择 store sink、scratch buffer 或 diagnostic。 + +### Verifier Gates + +每个 pipeline 边界都必须有 hard verifier,不能把残缺 IR 留给后续 pass 猜测: + +```text +at VMI producer boundary: + every logical vector value is represented by !pto.vmi.vreg / !pto.vmi.mask + every logical vector operation is represented by pto.vmi semantic op + no physical VPTO op has been introduced + no hidden layout/mask/type side table is required to interpret a value + +after vmi-layout-assignment: + every !pto.vmi.vreg / !pto.vmi.mask has #pto.vmi.layout + layout kind is one of contiguous/deinterleaved=2/deinterleaved=4 + mask granularity matches each consumer + branch operands, block arguments, function arguments/results, and yields agree on layout + no hidden layout/mask/type side table is required to interpret a value + +before vmi-to-vpto: + every pto.vmi.ensure_layout / ensure_mask_layout has a registered preserving materialization path + every fallback path has resource decision and dominance/lifetime proof + +after vmi-to-vpto: + no pto.vmi op or type remains + no UnrealizedConversionCastOp remains + no pto.vmi.pack/unpack/ensure_* helper remains + every physical value arity matches the Physical Arity helper +``` + +layout、mask、valid-lane 和 physical arity 信息必须存在于 IR type/attr/op operand 中,或可由它们 +纯函数推导;不能依赖 C++ side 
table。违反这些 gate 时使用 `VMI-PASS-INVARIANT` 或更具体的 +diagnostic,例如 `VMI-LAYOUT-CONTRACT`、`VMI-MEMORY-ACCESS`、`VMI-RESIDUAL-OP`。 + +## Layout Assignment 规则 + +### Elementwise + +same-layout operands: + +```text +vmi.addf/vmi.mulf/vmi.cmpi/vmi.select + fan out per physical part + result keeps operand layout +``` + +different-layout operands: + +```text +choose consumer-demanded layout +insert ensure_layout for other operands +vmi.broadcast can rematerialize in consumer-demanded layout +``` + +### Width Conversion + +典型 natural layout: + +```text +vmi.extf 128xf16 -> 128xf32: + source contiguous f16 + result deinterleaved=2 f32 + +vmi.extf 256xf8 -> 256xf32: + source contiguous f8 + result deinterleaved=4 f32 + +vmi.truncf 128xf32 -> 128xf16: + source may be deinterleaved=2 f32 + result contiguous f16 if pack/store sink requires contiguous + +vmi.truncf 256xf32 -> 256xf8: + source may be deinterleaved=4 f32 + result contiguous f8 if pack/store sink requires contiguous +``` + +Direct `vcvt` lowering 可以覆盖同一 contract 下的 partial/tail case:`extf` 的 logical lanes +必须仍然装进一个 contiguous narrow source physical chunk,并自然产生 deinterleaved=2/4 result; +`truncf` 的 deinterleaved=2/4 source parts 必须能 pack 成一个 contiguous narrow result chunk。 +这些路径允许 VPTO 对 padding lanes 执行 conversion,但 padding 只能流向 result padding lanes, +不能变成 logical result。 + +Mask granularity assignment 把 surface `mask<Nxpred>` 转成 concrete +`mask<NxG>`。consumer 决定所需 granularity: + +```text +f16 op consumes mask<Nxb16> +f32 op consumes mask<Nxb32> +f8 op consumes mask<Nxb8> +``` + +如果 data 从 f16 扩到 f32,后续 f32 consumer 需要: + +```mlir +!pto.vmi.mask<Nxb32> +``` + +不能继续复用 `mask<Nxb16>`。 + +mask-producing op 的 granularity 不是 producer 固有属性: + +```text +vmi.create_mask / constant_mask: + logical predicate producer; granularity chosen by users + create_mask 的 logical prefix 语义不受目标 PAT_VL token 集合限制; + unsupported PAT_VL count 可以用 pto.plt_b* materialize + constant_mask 的 non-prefix chunk 用 prefix 差分和 predicate boolean ops materialize + +vmi.cmpf/cmpi: + result logical lane 
count follows compared data + concrete granularity chosen by mask consumers, not by compare element type alone + +multi-use mask: + choose one concrete granularity for the original SSA value + insert ensure_mask_granularity or rematerialize cheap mask producers per use +``` + +`ensure_mask_granularity` 必须 preserve logical predicate lane `mask[i]`。当前 direct lowering 对 +concrete `b8/b16/b32` granularity 使用 `pto.punpack` 做 widening,使用 `pto.ppack` 加 `pto.por` +做 narrowing,并按需要串联相邻级别完成 `b8 <-> b32`。如果目标缺少 predicate rearrangement 或 +granularity conversion,报 `VMI-LAYOUT-CONTRACT`,不能把 b16/b32 mask 当成同一 physical bit +pattern 直接复用。 + +### Predication + +Region-style mask 不作为长期 region op 保留到 VPTO lowering。producer 必须把 mask thread 到 +具体 VMI op: + +```text +masked load/store: + use pto.vmi.masked_load / pto.vmi.masked_store + +masked arithmetic with passthru: + compute candidate result + merge with passthru by pto.vmi.select(mask, candidate, passthru) + +masked reduction/scan: + inactive and padding lanes are excluded from the logical iteration +``` + +如果一个 masked op 的 inactive lane 语义要求“不读内存”或“不执行有副作用操作”,不能用 +full op + select 伪装;必须使用对应 masked VMI op、ordered fallback,或报 target capability +diagnostic。 + +### Memory Ops + +VMI memory op 表达 memory semantics,不表达 register layout。lowering 先构造 access plan: + +```text +base +logical lane count +logical_shape attr, if any +lane-to-address map +contiguity +block-strided row classification +read/write validity mask +padding plan +footprint safety proof +target OOB capability +``` + +memory access map 不是 register layout。比如 `tile_read` 的 memref stride 可以识别 +block-strided rows,并选择 `vsldb`,但 result `vmi.vreg` 的 register layout 仍由 +layout assignment 决定。 + +Producer-specific packed element view 不进入 VMI type。它们必须在 VMI memory op 之前规范化为 +element memref + access map: + +```text +memref<?xvector<KxT>> + -> base element type T + -> logical address = original index * K + vector_lane +``` + +normalization 必须保留 offset、stride、alignment、memory space 和 alias 信息。无法证明等价 
element view 时,报 `VMI-MEMORY-ACCESS`,不能把 packed element memref 伪装成 contiguous VMI +load/store。 + +direct path examples: + +```text +contiguous full-safe: + vlds/vsts + !pto.ptr source/destination must be UB-backed; memref source/destination + must either have unknown memory space at this stage or explicitly use + the UB #pto.address_space attribute + +32B block-strided rows with block-uniform mask: + vsldb/vsstb + +interleave/deinterleave boundary: + vldsx2/vstsx2 dist or explicit rearrangement + +indexed memory: + gather/scatter if inactive and duplicate-index semantics match +``` + +GM-backed VMI memory is semantic input, not a direct vector load/store target. +Current `vmi-to-vpto` direct memory lowering emits `pto.vlds`, `pto.vldsx2`, +`pto.vsts`, or `pto.vstsx2`; those VPTO ops operate on UB-backed vector memory. +If a `pto.vmi.load/store/tile_read/tile_write` still names GM at this stage, +the missing step is an explicit memory movement/materialization plan, scratch +plan, or UB view normalization. Without such a step, the pass must report +`VMI-UNSUPPORTED` instead of silently producing illegal VPTO. 
+ +### Control Flow + +VMI layouted type 可以跨 internal control flow,但 public ABI 不允许 layout leak。 + +MLIR conversion framework 可以做 region/block/signature 的 structural type conversion,但它不会 +自动决定 layout。`vmi-layout-assignment` 必须先把每个 block argument、region yield、branch +operand 和 call boundary 的 layout 固定下来,再交给 `vmi-to-vpto` 做 1:N type conversion。 + +`scf.if` join: + +```text +if all incoming layouts equal: + keep that layout +else: + choose consumer-demanded layout, otherwise contiguous + insert ensure_layout / ensure_mask_layout before yield +``` + +`scf.for` loop-carried value: + +```text +init layout == iter_arg layout == yield layout == loop result layout +``` + +如果 loop body repeatedly consumes deinterleaved=2/deinterleaved=4,优先保持该 natural layout;如果只有 loop +exit 需要 contiguous,则在 exit 后转换,不在 backedge 每轮转换。 + +`cf.br` / `cf.cond_br` block arguments: + +```text +target block argument has one chosen layout +each predecessor operand is converted to that layout before branch +``` + +function boundary: + +```text +internal VMI functions: + function argument/result layout is part of layout assignment + all callsites and returns must agree with the specialized signature layout + +external/public ABI: + must not expose #pto.vmi.layout + materialize to memory, scalar ABI, or final physical PTO ABI before crossing boundary +``` + +recursive or mutually recursive VMI functions require SCC fixed-point layout assignment. If a stable signature +layout cannot be found without inserting conversion on every cycle edge, choose `contiguous` at the function +boundary and keep deinterleaved layouts inside the function body. 
+ +## VMI Op Families + +本节列出 VMI 必须拥有的 semantic op。assembly form 可在 ODS 中微调,但语义边界应保持。 +表中用 `/` 写在一起的名字表示多个独立 op,不表示一个 variadic opcode。去重后,正式 +semantic op 数量是 75 个。 +`ensure_layout`、`ensure_mask_layout`、`ensure_mask_granularity`、`pack`、`unpack` 是内部 +layout/materialization helper,不计入 semantic op;如果把 helper 也算作 VMI op,总数是 80 个。 + +该总表描述目标 semantic surface,不等价于当前第一批实现清单。当前 implementation slice +以 `docs/designs/vmi-implementation-manual.md` 的 Slice 1 为准;例如 `pto.vmi.from_elements` +虽然属于目标 construction family,但没有 scalar lane insert、vreg immediate 或 scratch +materialization plan 前不能宣称 direct lowering 已支持。 + +```text +construction: 6 +memory: 10 +arithmetic/conversion: 36 +permutation/mask/reduction/channel: 23 +semantic total: 75 +internal helpers: 5 +total including helpers: 80 +``` + +### Construction + +| Op | 语义 | +|---|---| +| `pto.vmi.constant` | logical constant vector,layout assignment 决定 materialization | +| `pto.vmi.broadcast` | scalar 或低 rank value broadcast 到 `vreg` | +| `pto.vmi.iota` | 从 scalar base 生成 logical lane index/value vector | +| `pto.vmi.from_elements` | 按 logical lane order 构造 | +| `pto.vmi.create_mask` | prefix 或 logical-shape mask | +| `pto.vmi.constant_mask` | static logical predicate mask, including non-prefix masks | +| `pto.vmi.mask_and/or/xor/not` | logical predicate elementwise operation | + +### Memory + +```mlir +%v = pto.vmi.load %base[%idx] + : memref -> !pto.vmi.vreg<128xf32> + +pto.vmi.store %v, %base[%idx] + : !pto.vmi.vreg<128xf32>, memref + +%v = pto.vmi.masked_load %base[%idx], %mask, %passthru + : memref, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.masked_store %v, %base[%idx], %mask + : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> + +%g = pto.vmi.gather %base[%indices], %mask, %passthru + : memref, !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.scatter %v, %base[%indices], %mask + : !pto.vmi.vreg<128xf32>, 
memref, + !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred> + +%e = pto.vmi.expand_load %base[%idx], %mask, %passthru + : memref, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.compress_store %v, %base[%idx], %mask + : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> +``` + +`masked_load` 的 inactive lane 不能产生 memory read。full load + select 只有在 inactive +lane 地址 safe-readable 时才合法。 +当前直接 lowering 只覆盖 contiguous result/passthru/mask:full physical chunks 直接 `vlds + vsel`; +partial/tail chunks 必须先证明完整 physical read footprint safe-readable,否则报 `VMI-UNSUPPORTED`。 +在第一阶段的矩阵 quant/dequant lowering 中,默认假设 UB 中的行数据按元素连续,tail load 可以安全读满 +当前物理 vreg;tail 的对外写入效果仍由 `pto.vmi.create_mask` + `pto.vmi.masked_store` +约束。严格 no-read tail 不是这个默认路径的语义,后续通过 stable gather 模式承接:该模式应把 +contiguous tail masked load 转为 `VGATHER2 + Pg` 风格的 per-lane non-faulting load。当前 +`vmi-to-vpto` 只预留 `enable-stable-gather-masked-load` 开关;开关打开且遇到 +`pto.vmi.masked_load` 时必须给 TODO diagnostic,不能退化成普通 `vlds + vsel`。 + +普通 `vmi.store` 和 `vmi.masked_store` 的 contiguous tail 可以用 true predicate store 承接: +full physical chunk 使用 all-true mask 或用户 mask,最后一个 partial chunk 使用 prefix valid-lane +mask;因此普通 `vmi.store` direct lowering 要求 value element width 能对应 +`pto.mask`。`masked_store` 先把用户 mask 与 valid-lane mask 做 logical AND。 +deinterleaved=2/4 tail store/masked_store 只有在每个 deinterleaved part 的 physical chunk 数相同、可先组成完整 +`vintlv/pintlv` group 并 materialize 成 contiguous chunks 时才直接支持;materialized 后 active +lane 为 0 的 padding-only chunk 不发 store。load padding 仍需要独立的 access plan,不能通过未受保护的 +full-footprint memory op 偷跑。 + +`gather/scatter` 使用 logical lane order 解释 `%indices`,index 单位和 memref element type +一致。`gather` inactive lane 返回 `%passthru[i]` 且不能读内存。`scatter` inactive lane 不能写 +内存;如果 active lanes 可能写同一地址,direct VPTO lowering 必须证明目标语义与 logical +lane order 等价,否则使用 ordered fallback 或报 `VMI-MEMORY-ACCESS`。 + +当前 `gather` direct lowering 覆盖一个保守子集: + +```text +source: + !pto.ptr 
+ +layout: + result / indices / mask / passthru all contiguous + all physical chunks are full, so padding lanes cannot trigger memory reads + +type: + T is 32-bit element type + indices are signless or unsigned i32 + mask granularity is b32 + +lowering: + gathered = pto.vgather2_bc source, indices, mask + result = pto.vsel gathered, passthru, mask +``` + +`VGATHER2_BC` false predicate lanes do not read memory but produce zero result lanes. VMI `gather` requires false +lanes to preserve passthru, so the `vsel` is semantically required, not an optimization artifact. `f16/b16/f8/i8` +gather, tail gather, non-contiguous layout, memref/gm source, and fallback through guarded scalar load or scratch are +future target-capability paths. + +当前 `scatter` direct lowering 只在 VMI IR 携带显式 no-conflict proof 时启用: + +```mlir +pto.vmi.scatter %v, %base[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> +``` + +`indices_unique` 的含义是:所有 active logical lanes 的 `%indices` 两两不同。这个 proof 可以来自 +producer 的静态分析、前端语义或上游 canonicalization;VMI lowering 不从 runtime 值猜测它。direct +path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、32-bit value +element、i32 indices 和 b32 mask。没有 `indices_unique` 时,`vmi-to-vpto` 必须诊断,而不能直接发 +`VSCATTER`,因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于 VMI +logical lane order。 + +`expand_load/compress_store` 表达 masked contiguous stream,不是 arbitrary indexed access: + +```text +expand_load: + k = 0 + for i in 0 .. N: + if mask[i]: + result[i] = base[idx + k] + k += 1 + else: + result[i] = passthru[i] + +compress_store: + k = 0 + for i in 0 .. N: + if mask[i]: + base[idx + k] = value[i] + k += 1 +``` + +Current direct `expand_load` lowering supports two paths. 
The first is the +degenerate all-active case: + +```text +mask == all_true => expand_load(base[idx], mask, passthru) == load(base[idx]) +``` + +The accepted mask must be statically proven all active through +`pto.vmi.create_mask` with constant `active_lanes >= N`, or a dense all-true +`pto.vmi.constant_mask`. The result, passthru, and mask layouts must be +contiguous. Partial/tail chunks still need the same safe full-read proof as +ordinary `vmi.load`; otherwise the direct path reports `VMI-UNSUPPORTED`. + +The second direct path covers one full 32-bit UB physical chunk with a runtime +mask: + +```text +base' = pto.addptr base, idx +indices = pto.vusqz(zero_i32_carrier, mask) +gathered = pto.vgather2_bc base', indices, mask +result = pto.vsel gathered, passthru, mask +``` + +It requires contiguous result/passthru/mask layout, 32-bit element type, b32 +mask granularity and one full physical chunk. Multi-chunk runtime masks need a +cross-chunk prefix-count carry; f16/b16/f8/i8 need a gather packing contract. +Unsupported cases still require guarded load, scratch fallback, or diagnostic, +and must not be lowered as a plain full load. + +Current direct `compress_store` lowering is intentionally narrower than the +surface semantics. It requires contiguous value/mask layout, exactly one full +physical chunk, and a UB `!pto.ptr` destination. The direct sequence is: + +```text +store_base = pto.addptr base, idx +sqz = pto.vsqz value, mask +align0 = pto.init_align +align1 = pto.vstur align0, sqz, store_base, "POST_UPDATE" +pto.vstar align1, store_base +``` + +The paired `vstur` consumer is what makes the later VPTO LLVM emitter select +`VSQZ #st=1`; emitting `vsqz` without that store consumer is only register +compress. Full physical chunk is required in this first path because padding +mask lanes must not be squeezed into memory. 
Multi-chunk `compress_store` +needs cross-chunk compaction and SQZN/store-state planning; deinterleaved +layouts need logical lane order reconstruction before the store chain. + +### Index And Address Contract + +`!pto.vmi.vreg` 是 logical index vector,不是 physical address vector。进入 VPTO 前, +index 必须按 target registry legalize 成目标支持的整数宽度: + +```text +index legalization: + choose target index bitwidth + prove every lane value fits, or insert preserving extend/trunc/check sequence + preserve signedness required by the consuming op +``` + +memory op 的 index 单位是 memref element,不是 byte。byte address 由 memref layout、element +size、base offset 和 lane index 共同计算: + +```text +logical element offset -> memref affine/strided map -> byte address +``` + +`gather/scatter` 的 `%indices`、`expand_load/compress_store` 的 active-prefix offset、`iota` 生成 +的 lane index 都必须在同一套 address unit 下解释。不能把 element index 直接当 byte offset,也 +不能在没有 range proof 时把 `index` 静默截断成较窄整数。 + +`active_prefix_index(mask)` 返回当前 lane 之前的 active lane 数: + +```text +idx[i] = popcount(mask[0 .. 
i)) +``` + +因此 `expand_load/compress_store` active lane 使用 `base + idx[i]`。如果目标缺少 prefix-popcount +或 index-vector lowering,必须选择 index-buffer/guarded fallback,或报 `VMI-FALLBACK-RESOURCE` +/ `VMI-LAYOUT-CONTRACT`。 + +`tile_read/tile_write` 承接 transfer-style padding 和 multi-dimensional access semantics: + +```mlir +%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> + +pto.vmi.tile_write %tile, %view[%c0, %c0], %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : !pto.vmi.vreg<64xf32>, memref<8x8xf32, strided<[?, 1], offset: ?>>, + !pto.vmi.mask<64xpred> +``` + +`tile_read/tile_write` 只承接 memref memory semantics。producer 的 transfer-style read/write 如果作用在 +tensor source/destination 上,必须在进入 VMI 前 bufferize 成 memref access plan,或退出 PTO +路线。tensor write-back style 语义是产生新 tensor,不是对 memref 的 memory effect;不能把它 +伪装成 `pto.vmi.tile_write`。未处理的 tensor transfer 报 `VMI-TENSOR-BOUNDARY`。 + +`tile_read` invalid lane 的 result 必须等于 padding,不是后继 op 的 inactive lane。 + +`tile_read` lowering 必须先构造三个对象: + +```text +validMask(result lane): + logical lane is inside result shape + and explicit transfer mask maps to true + and source address is in bounds + +paddingValue(result lane): + scalar padding: same value for every invalid lane + vector-element padding: select element by suffix coordinate + broadcast/permuted padding: apply the same result-lane map as data + +safeReadProof: + proves the actual physical load footprint is safe-readable + independent from validMask +``` + +`validMask=false` 只说明 result lane 应等于 padding,不说明该 lane 的 source address 可以被读。 +因此 `tile_read` 的 preserving lowering 决策是: + +```text +safeReadProof == full and validMask all-true: + direct load + +safeReadProof == full and validMask not all-true: + loaded = full load + pad = materialize 
paddingValue in result layout + result = select(validMask, loaded, pad) + +target has true masked/non-faulting load: + loaded = masked load with inactive lanes not read + pad = materialize paddingValue in result layout + result = select(validMask, loaded, pad) unless inactive result is already padding + +safeReadProof != full: + split full-safe and partial paths, or + fill scratch with paddingValue, guarded-copy only valid lanes, then load scratch, or + use guarded scalar/vector fallback +``` + +First implementation stage note: + +```text +The padding-preserving branches above are semantic requirements for the full +design, but they are not part of the first-stage VMI implementation. The first +stage may lower only all-valid direct reads, or physical-tail reads whose extra +lanes are outside the logical VMI value and remain unobservable. If invalid +logical lanes require transfer_read paddingValue materialization, true +masked/non-faulting load, scratch, or guarded fallback, lowering must stop with +the implementation diagnostic code VMI-UNSUPPORTED instead of emitting an +approximate full load. 
+``` + +如果所有 preserving paths 都因 target capability 或 option 被禁用,报 `VMI-MEMORY-ACCESS`, +payload 必须指出缺的是 unsafe partial `tile_read` padding-preserving path。 + +`tile_write` 没有 padding value,但有 write-valid mask: + +```text +writeMask(source lane): + logical lane is inside source shape + and explicit transfer mask maps to true + and destination address is in bounds +``` + +`writeMask=false` 的 lane 不能产生 memory effect。只有 full physical footprint safe-writable 且 +writeMask all-true 时,才能使用 predicate-ignored store。partial write 必须使用 true masked +store、split/guarded fallback、scatter-like fallback,或报 `VMI-MEMORY-ACCESS`。 +当前 direct `vmi.tile_write` 只覆盖 flat contiguous tail:最后一个 partial chunk 使用 prefix +valid-lane predicate 发 `vsts`,同样要求 value element width 能对应 `pto.mask`。 +deinterleaved=2/4 tail 只有在能先完整 materialize 到 contiguous +chunks 时直接支持,padding-only materialized chunk 不发 store;带 transfer mask coordinate remap 的 +tile write 仍必须走独立 access plan。 + +explicit transfer mask 的坐标属于 transfer access space,不一定等于 flattened result/source lane +坐标。non-minor-identity transfer 必须先做 predicate coordinate remap;缺少 remap capability 时, +diagnostic 必须点名 transfer mask coordinate remap,而不是泛化成普通 memory failure。 + +### Arithmetic And Conversion + +VMI 不复用外部 elementwise arithmetic op。需要定义对应 VMI op: + +| Semantic | VMI op | +|---|---| +| float binary | `pto.vmi.addf/subf/mulf/divf/minf/maxf` | +| float unary | `pto.vmi.negf/sqrt/exp/ln/relu` | +| integer binary | `pto.vmi.addi/subi/muli` | +| bitwise/shift | `pto.vmi.andi/ori/xori/not/shli/shrui` | +| fused multiply-add | `pto.vmi.fma` | +| float casts | `pto.vmi.extf/truncf` | +| bitcast | `pto.vmi.bitcast` | +| compare/select | `pto.vmi.cmpf/cmpi/select` | + +Integer div/rem, arithmetic right shift, integer casts, int-float casts, and +index casts are intentionally not in the current VMI surface. They need +explicit signedness, rounding, saturation, overflow/remainder, and VPTO target +contracts before ODS ops are introduced. 
+ +producer constant 转成 `pto.vmi.constant`,包括 dense、splat 和 rank-0 logical vector。 +constant 的 element type、shape、splatness 和 poison/undef 属性如果存在,必须保留到 VMI +constant attr;padding physical lane 仍按 VMI padding rule 处理,不能把 padding lane 当成用户 +constant lane。 + +当前 VPTO direct lowering 只把 scalar broadcast 和 splat constant materialize 成 +`pto.vdup`。这条路径与逐元素 op 一样要求 physical element width 能对应 +`pto.mask`;其它 element type 或非 splat constant 必须先有明确的 materialization +contract,否则报 `VMI-UNSUPPORTED`。 + +VMI arithmetic op 必须保留原 `arith` op 的 numeric contract: + +```text +floating point: + fastmath flags + rounding mode, if present + NaN / signed-zero / inf behavior implied by flags + +integer: + signedness of div/rem/compare/extend + overflow flags such as nsw/nuw when present + truncation and extension width rules + +compare/select: + cmpf/cmpi predicate + select condition mask granularity and layout +``` + +lowering 不能因为 VPTO 有更快指令就加强或放松这些属性。比如没有 fastmath 允许时,`fma` +不能拆成 `mulf + addf`,也不能把 `mulf + addf` 合成 `fma`;带 `nsw/nuw` 的 integer op +可以利用 flag 做优化,不带 flag 的 op 必须保持 wraparound/defined overflow 语义。 + +`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 layout 下 bit grouping +physically adjacent、且每个对应 physical chunk 的 logical bit footprint 相同时才能 direct; +padding bits 只能流向 result padding bits。否则需要 layout conversion、scratch materialization +或 target capability diagnostic。 + +当前 VPTO direct lowering 对逐元素算术、逻辑、比较和 select 还有一条共同硬约束:物理 element +width 必须能对应到 `pto.mask`。因此 VMI 语义层可以承载 `index` 或 `f64` +这类类型,但在没有独立 lowering contract 前,`vmi-to-vpto` 必须报 `VMI-UNSUPPORTED`, +不能让 OneToN conversion 或 residual gate 隐式失败。 + +这条共同约束不是唯一约束。某些目标 VPTO/VISA op 还有自己的 element type contract, +必须在 `vmi-to-vpto` preflight 中单独检查。当前 direct lowering 明确承诺: + +```text +addf/subf/mulf: f16/bf16/f32 +divf: f16/f32 +minf/maxf: f16/bf16/f32 +negf/absf: f16/f32 +sqrt/exp/ln: f16/f32 +relu: f16/f32 +absi: signless/signed i8/i16/i32 +cmpf: f16/bf16/f32 +cmpi: signless/signed/unsigned i8/i16/i32 +``` + +因此 bf16/f8 
虽然可能是合法 VMI float-like type 且能 materialize b16/b8 predicate mask, +但只要目标 direct op 不承诺该 element type,`vmi-to-vpto` 就必须先报 +`VMI-UNSUPPORTED`,直到定义对应 materialization 或 VPTO 目标能力。 + +当前 direct lowering 将 `pto.vmi.fma %lhs, %rhs, %acc` 映射为每个 physical part 上的 +`pto.vmula %acc_part, %lhs_part, %rhs_part, %all_true_mask`。该路径只承诺 f16/bf16/f32 +floating-point fused multiply-add;整数 multiply-accumulate、带 rounding/fastmath 变体或需要 +不同 accumulator 精度的形式必须单独建模,不能复用这个 op 偷换语义。 + +### Permutation, Mask, Reduction, Channel + +| Semantic | VMI op | +|---|---| +| static lane map | `pto.vmi.shuffle` | +| dynamic indexed lane map | `pto.vmi.permute` | +| logical interleave/deinterleave | `pto.vmi.interleave/deinterleave` | +| shape metadata change | `pto.vmi.shape_cast/reshape/transpose` | +| subvector update | `pto.vmi.slice/insert_slice/insert_element` | +| predicate logic | `pto.vmi.mask_and/or/xor/not` | +| prefix active index | `pto.vmi.active_prefix_index` | +| register compaction/expansion | `pto.vmi.compress/expand` | +| reduction/scan | `pto.vmi.reduction/scan` | +| contraction | `pto.vmi.contract/outerproduct` | +| channel split/merge | `pto.vmi.channel_split/channel_merge` | + +`pto.vmi.shuffle` 表达完整 static lane map。当前 VPTO direct lowering 先识别 physical chunk +forwarding:每个 result physical chunk 的所有非 padding lanes 必须来自同一个 source chunk, +且 source lane number 等于 result lane number;result padding lanes 不参与证明,forward 过来的 +物理 padding lanes 仍然不可观察。否则在每个 result physical chunk 都来自同一个 source chunk、 +result chunk 没有 padding lane、且 source lane index 是 ASC/DESC 连续序列时,用 `pto.vci` +生成 index vector 并发 `pto.vselr`。任意非 affine permutation、以及需要 tail lane 重排但无法安全 +materialize tail index vector 的场景,仍然需要通用 index-vector materialization、scratch fallback +或 target capability diagnostic。 + +`channel_split/channel_merge` 是 PTO-specific semantic op。它们表达用户按 channel 编程时的 +多个 logical VMI values,不能降格成 +`#pto.vmi.layout` kind。它们必须拥有 static shuffle 等价定义,canonicalization 可以双向进行: +识别出的 shuffle pattern 可以变成 channel 
op,channel op 也可以合法展开回 shuffle。 +Direct lowering 还必须证明 physical group 完整;否则即使 logical shuffle 语义成立,也要报 +target capability/materialization diagnostic,而不是让 OneToN pattern 在中途失败。 + +### Internal Layout Helpers + +这些 op 只允许存在于 VMI lowering 的中间阶段,不能作为 VMI semantic surface,也不能残留到 +physical VPTO 之后: + +| Op | 语义 | +|---|---| +| `pto.vmi.ensure_layout` | data vreg layout-preserving conversion | +| `pto.vmi.ensure_mask_layout` | mask layout-preserving conversion | +| `pto.vmi.ensure_mask_granularity` | logical predicate-preserving granularity conversion | +| `pto.vmi.unpack` | layouted VMI value projection to physical VPTO parts | +| `pto.vmi.pack` | physical VPTO parts materialized as one layouted VMI value | + +`active_prefix_index` 语义是: + +```text +idx[i] = popcount(mask[0 .. i)) +``` + +VMI surface 不暴露 VPTO `vusqz` 的无意义 source operand;需要 type/ABI carrier 时在 +`vmi-to-vpto` late materialize。 + +当前直接 lowering 只覆盖 contiguous 单物理 chunk。这个 case 可以用 `pto.vusqz` 精确承接: +`vmi-to-vpto` 先 materialize 一个 zero vreg 作为 VPTO `vusqz` 的 source carrier,再把 VMI mask +作为 governing predicate 传入。多物理 chunk 需要把前一 chunk 的 active count carry 到后一 chunk; +deinterleaved layout 还需要按逻辑 lane 顺序重建 prefix,因此不能逐物理 part 独立发 `vusqz`。 + +`vmi.compress(source, mask)` 语义是按 logical lane order 保留 active source lane 并压缩到结果前缀。 +当前直接 lowering 只覆盖 contiguous 单个 full physical chunk,可以用 `pto.vsqz(source, mask)` 承接。 +partial/tail chunk 不能直接走 `vsqz`,因为 padding mask lane 如果为 true,padding source lane 可能被 +压缩到可观察的 result 前缀。多物理 chunk 需要跨 chunk compaction;`compress_store` 还涉及 +`VSQZ #st=1` 与 `VSTUR`/`SQZN` 的配对约束,不能由 register `compress` 自动推出。 + +`vmi.compress_store(value, base[idx], mask)` 语义是按 logical lane order 把 active lane 写成连续 +memory stream。当前直接 lowering 只覆盖 contiguous、单个 full physical chunk 和 UB pointer +destination,并发出 `pto.vsqz -> pto.vstur POST_UPDATE -> pto.vstar` 的完整 store-state chain。非 full +chunk 暂不直接 lowering,因为 padding mask lane 可能被硬件 squeeze 成额外写出;multi-chunk 需要 +跨 chunk active count 和 SQZN FIFO/VSTUR 配对计划。 + 
+`shape_cast/reshape/transpose` 必须区分 metadata change 和 lane movement: + +```text +shape_cast / reshape: + preserve row-major flattened lane order + produce explicit result logical_shape attr + +transpose / flat_transpose: + changes logical lane order according to permutation + must lower through shuffle/permute/layout conversion/direct transpose capability +``` + +这些 op 的 source/result shape、permutation 和 broadcast map 都是 op attrs。VMI lowering 不能从 +producer defining op 或 side table 推断缺失 shape。 + +低 rank vector 到高 rank vector 的 broadcast 也不能当成 scalar broadcast 免费重物化。它必须 +保存 broadcast map: + +```text +result[indices] = source[broadcast_map(indices)] +``` + +只有 scalar-to-vector broadcast 可以按 consumer layout 任意重物化。 + +`iota` 是 lane index generation 的 VMI 表达: + +```text +iota(base, ASC): + result[i] = base + i + +iota(base, DESC): + result[i] = base - i +``` + +第一版 `iota` 的 `T` 跟随 VPTO `vci` 能承接的元素类型:integer 8/16/32 和 f16/f32。 +可变 step 不是 surface op 语义的一部分;如果 producer 需要 `base + i * step`,应表达为 +`iota(base=0) -> muli/vmi arithmetic -> addi/addf` 组合,或后续单独引入带 step 的 op。 +tail physical chunk 的 padding lane 可以承接 iota 的自然延续值,但这些 lane 不是 logical lane; +后续 memory/mask/reduction 等有外部效果的 consumer 必须继续按 valid logical lane 保护。 +deinterleaved layout 下的 physical part 需要 strided index materialization: + +```text +part p contains logical lanes p, p + factor, p + 2 * factor, ... 
+ASC value = base + p + factor * local_lane +DESC value = base - p - factor * local_lane +``` + +因此 direct `vci` 只覆盖 contiguous full-chunk path;deinterleaved path 必须额外物化 +`vci(0) * factor + base +/- p`,不能误降成每个 part 内连续的 `vci(base + p)`。当前 lowering +按 physical part 生成 `vci(0) + vmuls(factor) + vadds/vdup/vsub` 序列;padding/tail chunk +仍然需要独立的 padding-safe materialization plan。 + +`slice/insert_slice` 都按 logical lane order 定义,不读取或写入 padding lane: + +```text +slice(offset, size, stride): + result[j] = source[offset + j * stride] + +insert_slice(offset, stride): + result = dest + result[offset + j * stride] = update[j] + +insert_element(pos): + result = dest + result[pos] = scalar +``` + +`reduction/scan` 的 logical iteration 只覆盖 active logical lanes,padding lanes 不参与: + +```text +reduction(op, init, value, mask): + acc = init + for i in 0 .. N: + if mask is absent or mask[i]: + acc = op(acc, value[i]) + result = acc + +scan(op, init, value, mask): + acc = init + for i in 0 .. N: + if mask is absent or mask[i]: + acc = op(acc, value[i]) + result[i] = acc + else: + result[i] = passthru_or_identity +``` + +Current direct reduction support starts with integer add: + +```mlir +%r = pto.vmi.reduce_addi %value, %init, %mask + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> + +%rf = pto.vmi.reduce_addf %value, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + +%rmax = pto.vmi.reduce_maxf %value, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + +%rmin = pto.vmi.reduce_minf %value, %init, %mask + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> +``` + +`reduce_addi` preserves integer wraparound addition semantics. 
The direct +lowering requires contiguous layout, full 32-bit source physical chunks, +matching mask chunks, and one rank-0 init/result chunk. It emits `pto.vcadd` +for each masked source chunk, then serially accumulates each chunk result into +the rank-0 accumulator with `pto.vadd` under a `PAT_VL1` predicate. Padding +source lanes are rejected instead of being allowed to participate. + +`reduce_addf` is legal only with an explicit `{reassoc}` contract because the +ISA documents pair-wise FP reduction order. The direct lowering supports only +f32, contiguous layout, full source physical chunks, matching b32 mask chunks, +and one rank-0 init/result chunk. It uses the same per-chunk `vcadd` plus +serial `PAT_VL1 vadd` accumulation shape. Without `{reassoc}`, the verifier +rejects the op instead of silently changing ordered floating-point semantics. + +`reduce_maxf` and `reduce_minf` preserve VPTO-compatible floating-point min/max +reduction semantics. Direct lowering supports f16/f32, contiguous layout, full +source physical chunks, matching mask chunks, and one rank-0 init/result chunk. +For each physical source chunk, lowering emits `pto.vcmax` or `pto.vcmin`. +The chunk result's lowest lane is then accumulated into the rank-0 accumulator +with `pto.vmax` or `pto.vmin` under a `PAT_VL1` predicate. The index value that +`vcmax/vcmin` writes to the second lane is intentionally not part of the VMI op +result and is discarded by only observing lane 0. Inactive lane identities, +signed zero handling, and NaN behavior follow the underlying `vcmax/vcmin` and +`vmax/vmin` VPTO instructions. Padding source lanes are rejected, because the +logical reduction must not allow padding to become an inactive-lane identity or +a NaN-producing participant. 
+ +lowering 可以选择 VPTO reduction/scan 指令、tree decomposition、scratch memory 或 scalarized +ordered fallback,但必须保持 numeric contract。没有目标能力时使用 `VMI-ELEMENT-TYPE` 或 +`VMI-LAYOUT-CONTRACT`,不能让未 lower 的逻辑向量 op 残留到 VPTO。 + +`contract/outerproduct` 在 VMI 中保留 indexing maps、iterator types、accumulator、mask 和 +element type,并且不允许绕过 VMI 直接回到其它向量 IR。如果目标有直接 matrix/vector contract +能力,lower 到直接 VPTO sequence;否则按 iterator space 分解成 VMI arithmetic + +reduction/scan,再走普通 VMI lowering。只有当 element type、accumulator 精度或 iterator +semantics 无法由目标表达时,才报 target capability diagnostic。 + +如果 producer 的 extract-like operation 结果仍是 logical vector,应表达成 `pto.vmi.slice`、 +`pto.vmi.shuffle` 或 `pto.vmi.shape_cast`。如果结果是 scalar,则属于 vector-to-scalar boundary, +不进入 VMI vector path,也不产生 `pto.vmi.extract`: + +```text +VMI-SCALAR-EXTRACT-BOUNDARY +``` + +## End-To-End Examples + +### f16 Widen Add Store + +Semantic VMI: + +```mlir +%a = pto.vmi.load %A[%i] + : memref<?xf16> -> !pto.vmi.vreg<128xf16> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%s = pto.vmi.addf %w, %bias + : !pto.vmi.vreg<128xf32> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<128xf32>, memref<?xf32> +``` + +Layout-assigned VMI: + +```mlir +%a = pto.vmi.load %A[%i] + : memref<?xf16> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout<contiguous>> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16, #pto.vmi.layout<contiguous>> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved=2>> +%s = pto.vmi.addf %w, %bias + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved=2>> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved=2>>, memref<?xf32> +``` + +Physical lowering 可以生成 EVEN/ODD `vcvt`、两路 `vadd`,并在 store sink 使用 interleave +store 或显式 layout conversion。 + +### f8 To f32 + +```mlir +%a = pto.vmi.load %A[%i] + : memref<?xf8> -> !pto.vmi.vreg<256xf8> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<256xf32> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<256xf32>, memref<?xf32> +``` + +layout assignment 可把 `%w/%s` 设为 
`#pto.vmi.layout`。contiguous store 必须使用 +已验证的 layout sink 或先 materialize contiguous representation,不能把 p0/p1/p2/p3 part 当成连续内存写出。 + +### Block-Strided Tile Read + +```mlir +%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> +``` + +如果 access plan 证明每 row 是 32B contiguous block,row 间 stride 可落到 ISA stride 字段, +且 mask block-uniform,lowering 可以选择 `vsldb`。如果 padding 非零,仍需在 load 后用 +valid mask 修正 invalid lane。 + +## Risk Closure Matrix + +| 风险 | 设计闭环 | 测试出口 | +|---|---|---| +| producer 直接绕过 VMI 生成 physical VPTO | VMI Producer Boundary Contract + Verifier Gates | `vmi_producer_boundary.mlir`, `vmi_pipeline_hard_gates.mlir` | +| arith numeric contract 被 VPTO 快速路径改写 | fastmath/rounding/overflow/cmp predicate preservation | `vmi_arith_numeric_contract.mlir` | +| layout 设计泛化失控 | closed `contiguous/deinterleaved=2/4` layout set + source contract | `vmi_f16_ext_add_store_deinterleaved2.mlir`, `vmi_f8_ext_add_store_deinterleaved4.mlir` | +| layout assignment 局部贪心导致控制流/多 use 错误 | region/SCC constraint solver + deterministic tie-break | `vmi_layout_assignment_constraint_solver.mlir`, `vmi_cf_and_call_layout_boundary.mlir` | +| 1:N physicalization arity 漂移 | Physical Arity helper + hard gate | `vmi_physical_arity_non_full_deinterleaved.mlir` | +| `deinterleaved=4` materialization 错 lane | registered preserving materialization path | `vmi_ensure_layout_materialization_contract.mlir` | +| mask granularity 过早固化 | surface `mask` + consumer-driven granularity assignment | `vmi_mask_granularity_width_change.mlir` | +| non-scalar broadcast / transpose 被当成 metadata | explicit broadcast map and lane-movement semantics | `vmi_shape_broadcast_semantics.mlir` | +| transfer padding / OOB read 写成 full load/store | `validMask` / `paddingValue` / `safeReadProof` / `writeMask` decision tree | 
`vmi_tile_read_padding_decision_tree.mlir`, `vmi_tile_write_oob_no_effect.mlir` | +| index/address 单位或宽度被误用 | index/address legalization contract | `vmi_index_address_legalization.mlir` | +| reduction/scan/contract 回退成 residual logical-vector op | VMI semantic op + direct/decompose/scratch lowering contract | `vmi_reduction_scan_contract_coverage.mlir` | +| shape 信息依赖 hidden side table | flat VMI value + shape-sensitive op attrs | `vmi_shape_broadcast_semantics.mlir`, `vmi_pipeline_hard_gates.mlir` | +| fallback 缺资源时退化成残缺 lowering | explicit fallback resource contract + `VMI-FALLBACK-RESOURCE` | `vmi_fallback_resource_diagnostics.mlir` | +| tensor/debug/scalar boundary 混入 VMI | explicit boundary diagnostics | `vmi_tensor_transfer_boundary.mlir`, `vmi_debug_boundary.mlir`, `vmi_extract_boundary.mlir` | + +## Diagnostics + +| Code | 场景 | +|---|---| +| `VMI-SCALAR-EXTRACT-BOUNDARY` | scalar lane extract 不是 VMI vector op,必须在进入 VMI 前消除或退出 PTO 路线 | +| `VMI-SCALABLE-VECTOR` | scalable vector 未在进入 VMI 前 specialize 成固定 logical lane count | +| `VMI-ELEMENT-TYPE` | target registry 缺 storage/compute/convert capability | +| `VMI-LAYOUT-CONTRACT` | VMI layout、mask granularity 或控制流/调用边界约束冲突 | +| `VMI-MEMORY-ACCESS` | access plan 无 direct/fallback path | +| `VMI-LAYOUT-CONTRACT` | layout conversion 或 sink 未被 target registry 支持 | +| `VMI-FALLBACK-RESOURCE` | scratch、guard、index buffer 或 fallback index width 资源不可用 | +| `VMI-TENSOR-BOUNDARY` | tensor transfer 必须在进入 VMI 前 bufferize 或退出 PTO 路线 | +| `VMI-DEBUG-BOUNDARY` | debug op 必须在进入 VMI 前消费、剥离或退出 PTO 路线 | +| `VMI-PASS-INVARIANT` | pipeline hard gate 被破坏,例如 hidden side table、残留 conversion cast 或 layout 缺失 | +| `VMI-RESIDUAL-OP` | physicalization 后仍有非法 VMI op/type 或 helper | + +diagnostic payload 至少包含 source op、semantic reason、failed contract、available paths、 +missing capability 或 disabled fallback option。 + +## Implementation Plan + +具体文件布局、Slice 切分、ODS/type/op/pass/test 落地步骤见 +`docs/designs/vmi-implementation-manual.md`。本节只保留高层任务顺序。 
+ +1. 定义 `!pto.vmi.vreg<NxT>`、`!pto.vmi.vreg<NxT, #pto.vmi.layout<...>>`、 + `!pto.vmi.mask<Nxpred>`、`!pto.vmi.mask<Nxb8/b16/b32, #pto.vmi.layout<...>>`。 +2. 定义 layout 目录:`#pto.vmi.layout<contiguous>`、 + `#pto.vmi.layout<deinterleaved = 2>`、 + `#pto.vmi.layout<deinterleaved = 4>`, + 并实现统一 lane-map / physical-arity helper。 +3. 定义 VMI semantic op families:construction、memory、arith、conversion、mask、 + permutation、active-prefix、compress/expand、channel split/merge、reduction/scan/contract。 +4. 实现 VMI producer boundary verifier,禁止 producer 直接生成 physical VPTO 或依赖 hidden state。 +5. 实现 `vmi-layout-assignment`,包含 op transfer function、cost model、mask granularity + conversion、control-flow join。 +6. 实现 VMI memory lowering:access plan、safe-read/write proof、tile padding materialization、 + transfer mask coordinate remap、masked/guarded/scratch fallback。 +7. 实现 `vmi-to-vpto` 1:N type conversion,包含 `pack/unpack` materialization 和 structural + conversion。 +8. 加 target element-type / layout-sink / ISA contract / fallback resource registry。 +9. 加 VMI hard gate verifier:覆盖 VMI producer boundary、`vmi-layout-assignment`、 + `vmi-to-vpto` 后的残留 op/type、layout、mask granularity、conversion cast 和 hidden-state + invariant。 +10. 加 VMI diagnostic code registry 和 lit tests。 + +## Test Checklist + +1. `vmi_f16_ext_add_store_deinterleaved2.mlir` + - `extf` 后 result 是 `vreg<128xf32, deinterleaved=2>`,store 保持 contiguous logical order。 +2. `vmi_f8_ext_add_store_deinterleaved4.mlir` + - `deinterleaved=4` p0/p1/p2/p3 不被误写成 contiguous memory。 +3. `vmi_non_full_tile_padding_lanes.mlir` + - `vreg<100xf32>` padding lane 不可观察。 +4. `vmi_mask_granularity_width_change.mlir` + - surface `mask<Nxpred>` 被不同 width consumer 使用时,正确生成 `mask<Nxb16>` / + `mask<Nxb32>` 并保持 data layout。 +5. `vmi_control_flow_layout_join.mlir` + - `scf.if/scf.for` layouted VMI type join 稳定。 +6. `vmi_tile_read_padding_safe_footprint.mlir` + - full physical load unsafe 时不偷读 invalid lane。 +7. `vmi_block_strided_rows_vsldb.mlir` + - `tile_read/tile_write` 识别 32B block rows,并拒绝 per-lane mask direct path。 +8.
`vmi_active_prefix_index_compress.mlir` + - arbitrary mask compaction 使用 logical prefix order。 +9. `vmi_extract_boundary.mlir` + - scalar extract 输出 `VMI-SCALAR-EXTRACT-BOUNDARY`。 +10. `vmi_channel_split_merge_semantic_op.mlir` + - interleaved channel data 按用户语义拆成多个 VMI values,再通过 merge 写回。 +11. `vmi_producer_boundary.mlir` + - producer boundary 后只有 VMI semantic op/type,不出现 physical VPTO 或 hidden-state 依赖。 +12. `vmi_mask_threading.mlir` + - region-style mask 被 thread 到 masked VMI op 或 `vmi.select` merge,不残留 region mask。 +13. `vmi_gather_scatter_memory_semantics.mlir` + - inactive gather/scatter lane 不读写内存,scatter duplicate-index case 不走非法 direct path。 +14. `vmi_reduction_scan_contract_coverage.mlir` + - reduction/scan/contract 不回退成 residual logical-vector op,按 VMI lowering contract 处理。 +15. `vmi_cf_and_call_layout_boundary.mlir` + - `cf.br/cond_br` block arguments 和 internal call signatures 选择稳定 layout,external ABI 不泄露 layout。 +16. `vmi_iota_bitcast_insert_extract_coverage.mlir` + - lane index、bitcast、vector-result extract-like 和 insert-like 语义都有 VMI 承接。 +17. `vmi_memory_view_normalization.mlir` + - producer-specific vector element view 先规范化为 element view 和 access plan。 +18. `vmi_debug_boundary.mlir` + - debug-only op 不进入 VMI;未被 producer 消费时输出 `VMI-DEBUG-BOUNDARY`。 +19. `vmi_arith_numeric_contract.mlir` + - VMI arithmetic constant、fastmath、cmp predicate、integer signedness/overflow flags 保真。 +20. `vmi_shape_broadcast_semantics.mlir` + - `shape_cast/reshape` 只改 explicit op shape attrs,`transpose/flat_transpose` 和非 scalar broadcast 保持 lane map 语义且不依赖 shape side table。 +21. `vmi_physical_arity_non_full_deinterleaved.mlir` + - 非整 tile 下 `contiguous/deinterleaved=2/4` 的 physical value 个数和 valid lane map 一致。 +22. `vmi_ensure_layout_materialization_contract.mlir` + - `ensure_layout` 保持 logical lane 值,`deinterleaved=4` 只使用 registry 证明过的 materialization path。 +23. 
`vmi_tile_read_padding_decision_tree.mlir` + - safe full-read + non-all-true valid mask 生成 padding materialization + select;unsafe path 不读 invalid address。 +24. `vmi_tile_write_oob_no_effect.mlir` + - `tile_write` 的 writeMask=false lane 没有 memory effect,不被 lower 成 predicate-ignored store。 +25. `vmi_transfer_mask_coordinate_remap.mlir` + - non-minor-identity `tile_read/tile_write` 的 explicit mask 先映射到 result/source logical lane。 +26. `vmi_tile_read_vector_element_padding.mlir` + - vector-element padding 按 suffix coordinate 展开,invalid lane 使用对应 padding element。 +27. `vmi_index_address_legalization.mlir` + - `vreg`、gather/scatter indices、active-prefix offset 使用 element units 且宽度合法。 +28. `vmi_fallback_resource_diagnostics.mlir` + - scratch、guarded fallback、index-buffer fallback 缺资源时输出 `VMI-FALLBACK-RESOURCE`。 +29. `vmi_tensor_transfer_boundary.mlir` + - tensor transfer-style producer op 不伪装成 VMI memory op,未 bufferize 时输出 `VMI-TENSOR-BOUNDARY`。 +30. `vmi_pipeline_hard_gates.mlir` + - 各 pass 边界拒绝残留 VMI helper/unrealized cast/hidden state,且 final lowering 不残留 VMI op/type。 +31. `vmi_layout_assignment_constraint_solver.mlir` + - 多 use、rematerializable producer、control-flow join、layout conversion cost 冲突时选择稳定 layout 或输出精确 diagnostic。 diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md new file mode 100644 index 0000000000..772194f64d --- /dev/null +++ b/docs/designs/vmi-implementation-manual.md @@ -0,0 +1,4233 @@ +# VMI 实现手册 + +本文是 `docs/designs/vmi-dialect-design.md` 的落地手册。设计文档回答“为什么这样设计”,本文回答 +“按什么顺序改哪些文件、每一步做到什么程度才算完成”。 + +本文不替代最终 ODS / C++ verifier / lit 测试。实现时如果发现本文和 ODS 或 verifier 冲突,以 +更精确的 verifier 约束为准,并同步刷新本文。 + +## 0. 
当前仓库约束 + +当前仓库只有一个 MLIR dialect: + +```text +dialect name: pto +cpp namespace: ::mlir::pto +``` + +VPTO 低层 op/type 也在同一个 `pto` dialect 里,通过 `VPTOOps.td`、`VPTOTypeDefs.td` 等文件组织。 +因此第一版 VMI 不新建独立 dialect,采用同一 dialect 下的嵌套 mnemonic: + +```text +types: + !pto.vmi.vreg<...> + !pto.vmi.mask<...> + +attrs: + #pto.vmi.layout<...> + +ops: + pto.vmi.addf + pto.vmi.subf + pto.vmi.mulf + pto.vmi.ensure_layout +``` + +落地方式是:`PTO_Dialect` 仍是唯一 dialect,VMI 只是 `pto` dialect 内的一组 type/attr/op。 +如果后续要拆成真正独立的 `pto.vmi` dialect,必须先保证所有 pass、type converter、parser 测试 +和公开文档同步迁移;第一版不要做这个拆分。 + +风险点:带点 mnemonic 例如 `vmi.vreg`、`vmi.addf` 必须在 Slice 0 先用 parser round-trip 测试 +证明。如果 TableGen 的默认 type/attr parser 不接受该 spelling,就在 VMI type/attr 上实现 +custom assembly format,而不是改公开 spelling。 + +## 1. 文件布局 + +新增文件: + +```text +include/PTO/IR/VMIAttrs.td +include/PTO/IR/VMITypeDefs.td +include/PTO/IR/VMIOps.td +lib/PTO/IR/VMI.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +lib/PTO/Transforms/PTOValidateVMIIR.cpp +test/lit/vmi/ +``` + +修改文件: + +```text +include/PTO/IR/PTOAttrs.td +include/PTO/IR/PTOTypeDefs.td +include/PTO/IR/PTOOps.td +include/PTO/IR/CMakeLists.txt +lib/PTO/IR/CMakeLists.txt +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/CMakeLists.txt +``` + +推荐 include 关系: + +```text +PTOAttrs.td + include "PTO/IR/VMIAttrs.td" + +PTOTypeDefs.td + include "PTO/IR/VMITypeDefs.td" + +PTOOps.td + include "PTO/IR/VMIOps.td" +``` + +放置顺序: + +```text +VMIAttrs.td: + include PTODialect.td, AttrTypeBase.td, EnumAttr.td + must not include PTOAttrs.td + +VMITypeDefs.td: + include PTODialect.td and can rely on PTOAttrs.td having included VMIAttrs.td + +VMIOps.td: + include after PTO_Op is defined in PTOOps.td + do not include VPTOOps.td from VMIOps.td +``` + +这样现有 `LLVM_TARGET_DEFINITIONS PTOOps.td` 的 TableGen 生成路径可以继续覆盖 VMI type、attr +和 op。只有当 TableGen 生成目标不能正确收集新增 td 时,才单独新增 `mlir_tablegen` 目标。 + +`lib/PTO/IR/VMI.cpp` 放 VMI type/attr/op verifier、parse/print helper 
和公共 lane-map helper。 +不要把 VMI verifier 塞进 `VPTO.cpp`。 + +Pass 注册要求: + +```text +include/PTO/Transforms/Passes.td: + add VMILayoutAssignment + add VMIToVPTO + add PTOValidateVMIIR + +include/PTO/Transforms/Passes.h: + add explicit create*Pass declarations if generated declarations are not enough + +lib/PTO/Transforms/CMakeLists.txt: + add the three new .cpp files to PTOTransforms + keep DEPENDS PTOPassesIncGen and PTOOpsIncGen + add missing MLIR dialect libraries only when a new source actually includes them +``` + +Driver wiring is explicit and opt-in. `ptoas --enable-vmi` runs the VMI semantic pipeline before the VPTO backend +pipeline: + +```text +pto-validate-vmi-ir +vmi-layout-assignment +pto-validate-vmi-layout-ir +vmi-to-vpto +``` + +`--enable-vmi` requires `--pto-backend=vpto` or `pto.backend = "vpto"` because the pipeline produces physical VPTO +values and ops. It is not part of the default PTOAS pipeline; existing PTO/VPTO inputs keep their previous behavior +unless the flag is set. + +The `ptoas --enable-vmi` user-facing entry also rejects public functions whose signature contains `!pto.vmi.*`. +Internal/private VMI-typed functions may still be specialized by `vmi-layout-assignment` and physicalized by +`vmi-to-vpto`, but a public VMI ABI requires an explicit materialization plan and must not be inferred from the +layout solver. 
+ +CLI coverage: + +```text +vmi_ptoas_cli_pipeline.pto: + --pto-backend=vpto + --enable-vmi lowers the VMI pipeline + pto.backend = "vpto" also selects the VPTO-compatible path + explicit --pto-backend=emitc with --enable-vmi is rejected + +vmi_ptoas_backend_required_invalid.pto: + default emitc backend with --enable-vmi and no pto.backend = "vpto" is rejected + +vmi_ptoas_public_abi_invalid.pto / vmi_ptoas_public_result_abi_invalid.pto: + public VMI argument/result signatures are rejected before layout assignment +``` + +## MLIR Framework Usage + +三个核心 pass 不应该用同一种 MLIR 机制硬套。这里先定义实现框架选择,避免后续把 layout +求解、结构化控制流改写和 1:N physicalization 混在一个 pattern pass 里。 + +当前实现框架按下面的职责切开: + +```text +pto-validate-vmi-ir: + Operation::walk verifier。只看 IR 是否满足阶段不变量,不改 IR,不使用 conversion framework。 + +vmi-layout-assignment: + module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, + 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 + +vmi-to-vpto: + MLIR OneToNTypeConversion。每个 layout-assigned VMI value 按统一 physical ordering 展开成多个 + VPTO value,并依靠 OneToN structural patterns 重写函数、return、region result、block argument 和 + branch operand。 +``` + +这三个 pass 的边界必须通过 IR 可见状态传递:layout 写在 `!pto.vmi.*` type 上,必要 materialization +写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 +`unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 + +源码级实现应该进一步拆成五个独立层次: + +```text +IR layer: + include/PTO/IR/VMIAttrs.td + include/PTO/IR/VMITypeDefs.td + include/PTO/IR/VMIOps.td + lib/PTO/IR/VMI.cpp + + 只定义语义、parse/print、type/op verifier 和公共 lane-map helper。 + 这一层不能知道 layout assignment 的全局选择,也不能直接依赖 VPTO lowering pass。 + +Semantic validation layer: + lib/PTO/Transforms/PTOValidateVMIIR.cpp + + 只检查阶段输入/输出是否满足 contract。它是 hard gate,不做 repair。 + +Layout solving layer: + lib/PTO/Transforms/VMILayoutAssignment.cpp + + 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, + 然后把结果写回 type 或 
ensure_* helper。 + +Physicalization layer: + lib/PTO/Transforms/VMIToVPTO.cpp + + 负责把 layout-assigned VMI value 通过 OneToNTypeConversion 展成 VPTO physical values, + 并把每个 pto.vmi.* semantic op 改写成 VPTO op 序列。 + +Driver/test layer: + tools/ptoas/ptoas.cpp + tools/pto-test-opt/ + test/lit/vmi/ + + ptoas 只暴露 opt-in pipeline;pto-test-opt 保留单 pass 和中间 IR 的调试入口。 +``` + +每层的 MLIR 框架选择如下: + +```text +ODS/TableGen: + 定义 type/attr/op surface 和 verifier hook。 + +Operation::walk: + 用于 validation 和 layout constraint collection。 + +Union-find + DenseMap: + 用于 layout assignment 的 per-SSA-value 等价类求解。 + +IRRewriter/RewriterBase: + 用于 layout assignment 之后的 type rewrite、helper insertion、cheap producer rematerialization。 + +OneToNTypeConverter + OneToNOpConversionPattern: + 只用于 vmi-to-vpto,把一个 logical VMI value 展成多个 VPTO value。 + +Upstream OneToN structural helpers: + func.func / func.call / func.return / common SCF region-result conversion。 + +Project-local OneToN structural patterns: + cf.br / cf.cond_br / cf.switch / scf.execute_region / scf.index_switch。 +``` + +不要把这些层次合并成一个万能 pattern pass。特别是: + +```text +layout assignment 不能依赖 OneToNTypeConverter: + 因为 layout 不是 type-only 决策,同一个 !pto.vmi.vreg<128xf32> 的不同 SSA value + 可能因 producer/consumer/control-flow 约束得到不同 layout。 + +vmi-to-vpto 不能重新做 layout solving: + 它只消费已经写在 type/helper 上的 layout 决策。遇到未 assignment 的 VMI type 必须失败。 + +structural OneToN pattern 不能知道 VMI 语义: + 它们只负责 flatten/rebuild operands、results、successor operands 和 block arguments。 + 具体 lane 语义只属于 pto.vmi.* op lowering pattern。 + +verifier 不能偷偷修 IR: + 否则后续 pass 会依赖 verifier 的隐式 repair 行为,导致 pipeline 顺序不可推理。 +``` + +一个可以直接对照代码的 pass 边界表: + +```text +pass input output +--------------------------- ---------------------------- ---------------------------- +pto-validate-vmi-ir surface VMI IR same IR, or hard failure +vmi-layout-assignment surface/layout-partial VMI layout-assigned VMI IR +pto-validate-vmi-layout-ir layout-assigned VMI IR same IR, or hard failure +vmi-to-vpto 
layout-assigned VMI IR physical VPTO IR +final residual verifier physical VPTO candidate no pto.vmi.*, no !pto.vmi.* +``` + +### 代码级落点 + +当前实现应该能按文件直接审计。每个 pass 的核心类、MLIR 机制和失败边界如下: + +```text +lib/PTO/Transforms/PTOValidateVMIIR.cpp + pass: + PTOValidateVMIIRPass + PTOValidateVMILayoutIRPass + public helpers: + validateVMIProducerBoundaryIR + validateVMILayoutAssignedIR + MLIR API: + Operation::walk + func::FuncOp function type inspection + recursive TypeAttr / TypedAttr / ArrayAttr / DictionaryAttr scan + must not: + rewrite IR + create unrealized_conversion_cast + create ConversionTarget + repair illegal helper/type leakage + +lib/PTO/Transforms/VMILayoutAssignment.cpp + pass: + VMILayoutAssignmentPass + core object: + LayoutSolver + state: + DenseMap + SmallVector + SmallVector + SmallVector + SmallVector + MLIR API: + Operation::walk for fact collection + SymbolTable for direct internal calls + concrete cf/scf handlers for control-flow equivalence + IRRewriter/OpBuilder only after solving + must not: + use TypeConverter as the layout decision model + rewrite while collecting constraints + hide chosen layout in a pass-private side table + infer external VMI ABI + +lib/PTO/Transforms/VMIToVPTO.cpp + pass: + VMIToVPTOPass + converter: + VMIToVPTOTypeConverter : OneToNTypeConverter + pattern families: + OneToNOpConversionPattern for pto.vmi.* semantic ops + upstream func/scf OneToN structural patterns + project-local cf/scf structural OneToN patterns + MLIR API: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + applyPartialOneToNConversion + final residual walk + must not: + redo layout solving + inspect defining ops to recover physical parts + allow pto.vmi.pack/unpack/ensure_* to survive final output + allow unrealized_conversion_cast to survive final output +``` + +这里最重要的分界是:`vmi-layout-assignment` 解决的是 value-level layout,`vmi-to-vpto` +解决的是 type/value 1:N physicalization。前者的结果必须已经写回 `!pto.vmi.*` type 或显式 
+`pto.vmi.ensure_*`;后者只能消费这些 IR-visible facts。 + +这也回答了“有没有充分利用 MLIR 自带能力”:结构化 1:N signature/control-flow conversion 必须用 +MLIR OneToN conversion;layout assignment 则不能强行塞进 converter,因为 converter 看不到 +producer natural layout、consumer request、CFG join 和 call-return slot 这些 value-level facts。 + +### Pass 级实现细则 + +这几个 pass 对 MLIR 自带能力的使用方式应该是“各用其长”,而不是都套成 converter pattern。 +实现时按下面的判断标准拆: + +```text +只检查阶段不变量: + 用 Operation::walk。不要创建 ConversionTarget,也不要 rewrite。 + +需要根据 SSA value、CFG join、call boundary 和 consumer request 决策 layout: + 用 module-level solver。MLIR conversion framework 没有 per-value layout 决策模型。 + +需要把一个 logical value 展成多个 physical value,并同步改 function/block/control-flow signature: + 用 OneToNTypeConversion。这里是 converter framework 最应该发挥作用的地方。 +``` + +#### Pass 框架细化 + +第一版实现按下面的源码和 MLIR infra 对齐。这个表是实现时的边界,不只是文档分层: + +```text +source file pass primary MLIR facility +----------------------------------------- --------------------------- --------------------------------------------- +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter +lib/PTO/Transforms/VMIToVPTO.cpp vmi-to-vpto OneToNTypeConverter + OneToNOpConversionPattern +``` + +这意味着每个 pass 的输入输出 contract 是固定的: + +```text +pto-validate-vmi-ir: + input: + surface VMI IR + legal: + pto.vmi semantic ops + !pto.vmi.vreg + !pto.vmi.mask + func/scf/cf structural ops carrying those types + illegal: + layout-assigned !pto.vmi.* type + physical !pto.vreg / !pto.mask / !pto.align type + pto.vmi.ensure_* / pack / unpack helper + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-layout-assignment: + input: + verifier-clean surface VMI IR + legal work: + solve per-SSA layout/granularity constraints 
+ rewrite VMI value/function/block types with explicit layout + insert pto.vmi.ensure_* only for use-site materialization + rematerialize cheap producers instead of inserting ensure_* when semantics are replay-safe + illegal work: + physicalize to !pto.vreg / !pto.mask + introduce pto.vmi.pack / pto.vmi.unpack + keep layout only in a pass-private side table + output: + layout-assigned VMI IR, or failure + +pto-validate-vmi-layout-ir: + input: + layout-assigned VMI IR + legal: + pto.vmi semantic ops + pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity + !pto.vmi.vreg + !pto.vmi.mask + illegal: + surface !pto.vmi.vreg + surface !pto.vmi.mask + physical VPTO register types before vmi-to-vpto + pto.vmi.pack / pto.vmi.unpack + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-to-vpto: + input: + layout-assigned VMI IR + legal work: + convert each VMI value to an ordered list of physical VPTO values + rewrite function signatures, block arguments, branch operands, region results and calls + lower pto.vmi semantic/helper ops to VPTO ops + illegal work: + infer missing layouts + change a chosen layout because one pattern finds a cheaper lowering + leave pto.vmi.* / !pto.vmi.* / unrealized_conversion_cast in final IR + output: + physical VPTO IR, or failure +``` + +`vmi-layout-assignment` 和 `vmi-to-vpto` 的关键差异是:前者解决“这个 SSA value 应该是什么 layout”, +后者解决“这个已经有 layout 的 SSA value 展开成哪些 physical value”。同一个 surface type 不能用 +`TypeConverter` 得到唯一答案: + +```mlir +%a = pto.vmi.broadcast %s : f32 -> !pto.vmi.vreg<128xf32> +%b = pto.vmi.extf %x : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%c = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + scf.yield %b : !pto.vmi.vreg<128xf32> +} +``` + +这里 `%a` 可以按 consumer 需要 rematerialize 成 contiguous 或 deinterleaved;`%b` 的 natural layout 是 +`deinterleaved=2`;`%c` 的 layout 必须由两个 yield 和后续 consumer 共同约束。这个选择依赖 Value、 
+def-use、control-flow join 和 use-site request,不是 `!pto.vmi.vreg<128xf32> -> ...` 的 type-only 规则。 + +因此 layout pass 的代码形态应该固定为: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collectAllVMIValues())) + return failure(); + if (failed(collectEquivalenceConstraints())) + return failure(); + if (failed(collectProducerNaturalLayouts())) + return failure(); + if (failed(collectConsumerRequests())) + return failure(); + if (failed(rewriteDataTypes())) + return failure(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferAndRewriteMaskTypes())) + return failure(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionTypesFromSolvedValues(); + return validateVMILayoutAssignedIR(module); +} +``` + +其中 `collect*` 阶段只能记录事实,不能边 walk 边改 IR。原因是控制流和 call boundary 会把后面才遇到的 +operand/result 合并到前面的 value class;边收集边改 type 会让后续约束看到混合状态,错误诊断也会依赖 +walk 顺序。 + +`vmi-to-vpto` 则必须是 converter pass。第一版使用的是 `OneToNTypeConversion`,因为它要同时处理 +value type 和结构签名: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +func.func @f(%arg0: !pto.vmi.vreg<128xf32, layout>) -> !pto.vmi.vreg<128xf32, layout> + -> func.func @f(%arg0_0: !pto.vreg<64xf32>, %arg0_1: !pto.vreg<64xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这里不能用普通 1:1 `TypeConverter`,也不能靠每个 VMI op pattern 自己拆 operand。否则 `func.return`、 +`cf.br`、`scf.for` iter arg 这种没有 VMI defining op 的边界会漏转换。`OneToN` adaptor 才是 semantic +pattern 获取 physical parts 的唯一来源: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); +``` + +结构化转换的实现分工如下: + +```text +upstream helper: + populateFuncTypeConversionPatterns + covers func.func / func.return / direct func.call signature conversion + + scf::populateSCFStructuralOneToNTypeConversions + covers common SCF result/yield/block-argument structural conversions + +project-local 
OneToN patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只能做结构搬运: + +```text +1. read OneToNTypeMapping for each original operand/result +2. flatten successor operands or region result types +3. rebuild the same cf/scf op with converted types +4. inline/move original regions when required +``` + +它们不能做下面这些事: + +```text +infer layout from operand defining op +emit vadd/vcvt/vlds/vsts +decide contiguous vs deinterleaved +special-case pto.vmi semantic op +``` + +VMI 语义只能出现在 `OneToNOpConversionPattern` 里。这样才能保证 block argument、function +argument、loop-carried value 和 branch target argument 都按同一套 physical ordering 转换。 + +`vmi-to-vpto` 的 legality 由 preflight + conversion + final gate 三段组成,而不是单靠 +`ConversionTarget`: + +```text +preflight: + verifyVMIToVPTOInputIR + rejects layout-free VMI types + verifySupportedVMIToVPTOOps + rejects unsupported semantic/materialization cases before rewrite starts + +conversion: + applyPartialOneToNConversion + applies structural and semantic OneToN patterns + +final gate: + verifyNoResidualVMIIR + rejects pto.vmi.* + rejects !pto.vmi.* in operand/result/block/function/attribute type trees + rejects pto.vmi.pack/unpack materialization helpers + rejects unrealized_conversion_cast +``` + +这比只设置 `ConversionTarget` 更直接,因为当前 OneToN 工具链的重点是 type/value expansion 和 pattern +rewriter;最终合法性必须递归检查 attribute/type tree,防止 VMI type 被藏在 nested attr 里。 + +#### `pto-validate-vmi-ir` / `pto-validate-vmi-layout-ir` + +这两个 pass 是 hard gate,不是 legalization pass。 + +使用的 MLIR 能力: + +```text +Operation::walk: + 遍历 module 内所有 op、region、block argument、operand/result type 和 attribute。 + +TypeAttr / TypedAttr recursive scan: + 拒绝把 VMI/physical VPTO type 藏在 nested attribute 中。 + +func::FuncOp function type special case: + function_type attr 是签名本身,可以按当前阶段规则检查;其它 attr 不能携带 VMI/physical type。 +``` + +不使用 `ConversionTarget` 的原因: + +```text +ConversionTarget 适合表达“哪些 op/type legal,哪些 pattern 能改掉”。 +这里我们只想回答“当前 
IR 是否已经处在某个阶段边界”,失败后必须停机,而不是尝试 repair。 +如果 verifier 顺手改 IR,pipeline 的阶段不变量会变成隐式行为,后续 pass 很难审计。 +``` + +这两个 pass 的输出只能是原 IR 或 failure: + +```cpp +void runOnOperation() override { + if (failed(verifyStageInvariant(getOperation()))) + signalPassFailure(); +} +``` + +#### `vmi-layout-assignment` + +这个 pass 使用 MLIR 的 IR 遍历和 rewrite 基础设施,但不使用 `TypeConverter` 作为主模型。 + +核心原因: + +```text +TypeConverter 的输入是 Type。 +layout assignment 的输入是 Value。 + +同一个 !pto.vmi.vreg<128xf32> 可以因为不同 producer/consumer 关系得到不同 layout: + f16->f32 widen result -> deinterleaved=2 + f8 ->f32 widen result -> deinterleaved=4 + only contiguous store value -> contiguous +``` + +实现应拆成两个阶段,不要边 walk 边 rewrite: + +```text +collect: + 1. 收集所有 VMI data/mask SSA value 和 block argument。 + 2. 用 union-find 合并必须同 layout 的 value。 + 3. 记录 producer natural layout。 + 4. 记录 consumer layout/granularity request。 + 5. 记录 function return slot、call operand/result、branch operand/block argument 关系。 + +rewrite: + 1. 为每个 equivalence class 选 layout。 + 2. 改写 value/function/block/result type。 + 3. 对 use-site mismatch 插入 ensure_* 或 rematerialize cheap producer。 + 4. 
运行 pto-validate-vmi-layout-ir。 +``` + +建议的数据结构边界: + +```cpp +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +这里可以充分使用 MLIR 的接口,但它们只是 constraint source: + +```text +BranchOpInterface / concrete cf.* handlers: + successor operand[i] == destination block argument[i] + +RegionBranchOpInterface / concrete scf.* handlers: + region yield operand[i] == parent result[i] + loop init/result/iter_arg/yield 同 slot 等价 + +CallOpInterface + SymbolTable: + direct internal call operand/result 和 callee argument/return slot 等价 + external/indirect VMI call 先拒绝,因为缺 ABI materialization + +IRRewriter: + 只在 solve 完成后统一改 type、插 ensure_*、clone cheap producer。 +``` + +`vmi-layout-assignment` 的 pass invariant 是:所有 layout 决策必须写回 IR。后续 `vmi-to-vpto` +只能读取 `!pto.vmi.*` type 和显式 `pto.vmi.ensure_*`,不能依赖 layout solver 的 side table。 + +#### `vmi-to-vpto` + +这个 pass 应该充分使用 MLIR converter framework,具体是 `OneToNTypeConversion`,不是普通 +`DialectConversion`。 + +普通 1:1 dialect conversion 不够的地方: + +```text +!pto.vmi.vreg<128xf32, deinterleaved=2> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<256xf8, deinterleaved=4> + -> !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8> +``` + +函数参数、返回值、block argument、branch operand、region result 都必须做同样的 1:N 展开。 +这正是 `OneToNTypeConverter`、`OneToNOpConversionPattern` 和结构化 OneToN helper 的职责。 + +实现骨架: + +```cpp +void runOnOperation() override { + ModuleOp module = getOperation(); + + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + 
RewritePatternSet patterns(&getContext()); + + populateFuncTypeConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + populateProjectLocalCFOneToNPatterns(typeConverter, patterns); + populateVMISemanticOneToNPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +`VMIToVPTOTypeConverter` 只做一种事:把 layout-assigned VMI type 映射到 canonical physical value list。 +它不能重新推导 layout。 + +```text +contiguous: + chunk0, chunk1, ... in logical order + +deinterleaved=2: + part0 chunks for logical lanes 0,2,4,... + part1 chunks for logical lanes 1,3,5,... + +deinterleaved=4: + part0 chunks for lanes 0,4,8,... + part1 chunks for lanes 1,5,9,... + part2 chunks for lanes 2,6,10,... + part3 chunks for lanes 3,7,11,... +``` + +每个 semantic pattern 必须从 adaptor 拿 physical parts,不允许从 defining op 反推: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhs = adaptor.getLhs(); + ValueRange rhs = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhs.size() != rhs.size() || lhs.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, resultType] : llvm::enumerate(resultTypes)) { + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs[i], rhs[i]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这个约束对控制流是关键的:`scf.for` iter arg、branch target argument、function argument 都没有可用的 +defining op;它们的 physical parts 只能来自 OneToN signature/block argument conversion。 + +`vmi-to-vpto` 应有三层失败点,诊断不要混在一起: + +```text +preflight: + layout 未 assignment、unsupported semantic op、unsupported materialization path + 
+conversion: + pattern 缺失、arity mismatch、结构化控制流展开失败 + +final residual verifier: + 任何 pto.vmi.*、!pto.vmi.*、pto.vmi.pack/unpack/ensure_*、unrealized_conversion_cast 残留 +``` + +### `pto-validate-vmi-ir` + +`pto-validate-vmi-ir` 是边界 verifier,不使用 DialectConversion。 + +推荐使用: + +```text +Operation::walk +TypeSwitch / isa / dyn_cast +emitOpError / InFlightDiagnostic +SymbolTable, for function/call boundary checks +CallGraph or manual call graph collection, if recursive SCC needs diagnostics +DominanceInfo, if helper placement or resource dominance is checked +``` + +这个 pass 只检查 VMI producer boundary 和阶段不变量: + +```text +before layout assignment: + VMI data/mask values use surface type + no layout-assigned VMI type leaks in unless the test explicitly starts after assignment + no physical VPTO op appears in the semantic VMI region + no VMI helper op appears before the pass that is allowed to create it + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after layout assignment: + pass: pto-validate-vmi-layout-ir + every VMI data value has a layout + every VMI mask has layout and concrete granularity + control-flow joins have stable type/layout + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after VMI-to-VPTO: + no VMI op/type/helper remains + no unrealized_conversion_cast remains +``` + +不要把这个 pass 写成 rewrite pass。它可以收集 context 用于诊断,但不能通过局部修补让非法 IR +继续前进;否则后续 pass 会开始依赖 verifier 的隐式 repair 行为。 + +实现上要扫描的不只是 operand/result/block argument: + +```text +func.func function type: + 作为函数签名本身检查,允许出现当前阶段合法的 VMI type。 + +non-signature attributes: + module/op attribute 中只要递归包含 VMI type 或 physical VPTO type 都拒绝。这里包括 TypeAttr、 + TypedAttr,以及 ArrayAttr/DictionaryAttr 这类容器中的 nested attribute/type payload。 +``` + +这样可以堵住 hidden-state 形式的 side table,例如把 `!pto.vmi.vreg<...>` 偷存在 module attribute +里。`func.func` 的内建 `function_type` attr 是唯一例外,因为它只是函数签名的 MLIR 表达,不是额外 +隐藏状态。 + +### `vmi-layout-assignment` 
+ +`vmi-layout-assignment` 不以 MLIR `TypeConverter` 作为主机制。 + +原因是 layout 选择不是单纯的 `Type -> TypeRange` 映射: + +```text +same surface type: + !pto.vmi.vreg<128xf32> + +possible per-value decisions: + value produced by f16->f32 widen: deinterleaved=2 + value loaded only for contiguous store: contiguous + value feeding fp8-like->f32 consumer path: deinterleaved=4 +``` + +两个 SSA value 可以有完全相同的 surface type,但因为 producer natural layout、consumer demand、 +控制流 join 和 target capability 不同,得到不同 layout。因此主模型应该是 per-SSA-value 的约束图, +而不是类型转换表。 + +推荐内部结构: + +```text +DenseMap +DenseMap +DenseMap +SmallVector +SmallVector +``` + +推荐使用的 MLIR 基础能力: + +```text +RegionBranchOpInterface: + collect scf.if/scf.for-like region entry, yield, result relations + +BranchOpInterface: + collect cf.br/cf.cond_br predecessor operand -> block argument relations + +CallOpInterface, CallableOpInterface, FunctionOpInterface: + collect call operand/result and function argument/result relations + +SymbolTable: + resolve direct calls and reject unresolved VMI signature assumptions + +DominanceInfo: + choose legal insertion points for ensure_layout, mask conversion, and rematerialization + +IRRewriter / RewriterBase: + rewrite types, insert helper ops, clone rematerializable producers +``` + +求解结果必须 materialize 回 IR,不能留在 side table: + +```text +1. Rewrite every VMI value type to a layout-assigned type. +2. Rewrite mask type to layout + b8/b16/b32 granularity. +3. Insert pto.vmi.ensure_layout where a consumer requires a different layout. +4. Insert pto.vmi.ensure_mask_layout / ensure_mask_granularity where predicate layout or granularity differs. +5. Clone rematerializable producers such as constant, broadcast, create_mask, iota-like producers when cheaper. +6. Re-run the VMI stage verifier. 
+``` + +这个 pass 可以用 `RewritePatternSet` 辅助局部 canonicalization,例如删除同 layout 的 +`ensure_layout`,但不能让 greedy pattern driver 决定全局 layout。全局约束必须先收敛,再做改写。 + +更具体地说,这里不用 `TypeConverter` 的原因不是 MLIR converter 不好用,而是此阶段的问题不是 +“一个旧 type 机械变成一个新 type”: + +```text +%a : !pto.vmi.vreg<128xf32> // 只被 contiguous store 消费 +%b : !pto.vmi.vreg<128xf32> // 来自 f16->f32 widen,后续继续 vadd +%c : !pto.vmi.vreg<128xf32> // 控制流 join,两个 predecessor 必须统一 layout +``` + +这三个 value 的 surface type 完全相同,但 layout 决策分别可能是 contiguous、deinterleaved=2、 +以及由 join 两侧约束共同决定。`TypeConverter` 看不到“这个 SSA value 的 producer/consumer/CFG +关系”,所以它只能作为后续 physicalization 的工具,不能作为 layout assignment 的主算法。 + +该 pass 对 MLIR 基础能力的使用边界是: + +```text +Operation::walk: + 收集所有 VMI SSA value、block argument、函数签名和 op transfer facts。 + +Union-find / DenseMap: + 表达必须同 layout 的 equivalence class。 + +SymbolTable: + 解析 direct internal func.call;带 VMI type 的 external/indirect call 先拒绝。 + +IRRewriter: + 改写 function/block/result type,插入 ensure_*,必要时 rematerialize cheap producer。 + +verifyLayoutAssignedVMIIR: + pass 末尾 hard gate,确认所有决策已经 materialize 到 IR。 +``` + +### `vmi-to-vpto` + +`vmi-to-vpto` 应该使用 MLIR 的 1:N conversion framework,而不是普通 `DialectConversion`。 +这个 pass 的核心问题正是一个 logical VMI value physicalize 成多个 VPTO value: + +```text +!pto.vmi.vreg -> !pto.vreg... +!pto.vmi.mask -> !pto.mask... 
+``` + +普通 `DialectConversion` 的 `OpConversionPattern` 对 1:N fixed operand/result 支持不够直接: +pattern adaptor 可能拿到 source materialization,也可能拿到 flat converted operands;`func.return` +这类“一个 logical operand 展开成多个 physical operands”的场景也容易出现不完整展开。因此这里采用 +MLIR `OneToNTypeConversion` 工具: + +推荐组件: + +```text +OneToNTypeConverter +OneToNOpConversionPattern +OneToNPatternRewriter +OneToNTypeMapping +populateFuncTypeConversionPatterns +scf::populateSCFStructuralOneToNTypeConversions +applyPartialOneToNConversion +final residual verifier +``` + +`OneToNTypeConverter` 负责 layout-assigned VMI type 到 ordered physical VPTO value list: + +```cpp +typeConverter.addConversion([](VMIVRegType type, SmallVectorImpl<Type> &results) { + // Use getVMIPhysicalArity(type) and the shared lane-map helper. + // Append one physical !pto.vreg per part/chunk. +}); + +typeConverter.addConversion([](VMIMaskType type, SmallVectorImpl<Type> &results) { + // Use mask granularity and physical arity helper. + // Append one physical !pto.mask per part/chunk. 
+}); +``` + +source/target materialization 可以用 VMI helper 承接中间状态: + +```text +VMI value -> physical values: + pto.vmi.unpack + +physical values -> VMI value: + pto.vmi.pack +``` + +但它们只是 conversion materialization,不是最终 IR 的合法残留。final gate 必须拒绝: + +```text +pto.vmi.pack +pto.vmi.unpack +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +unrealized_conversion_cast +``` + +`applyPartialOneToNConversion` 本身不是 legality framework;它负责应用 1:N patterns 并替换内部 +`unrealized_conversion_cast`。因此 `vmi-to-vpto` 必须在 conversion 后运行 final residual verifier, +把下面这些全部作为 hard failure: + +```text +any pto.vmi.* op +any !pto.vmi.* type +any pto.vmi.pack/unpack materialization helper +any pto.vmi.ensure_* helper +any unrealized_conversion_cast +``` + +结构转换必须覆盖: + +```text +func arguments/results and return operands: + use populateFuncTypeConversionPatterns + +call operands/results: + convert callee signature and call sites together + +block arguments and branch operands: + convert target block arguments and predecessor operands in the same conversion + current implementation provides project-local OneToN patterns for cf.br, + cf.cond_br, and cf.switch because MLIR only provides the generic + BranchOpInterface helper for ordinary 1:1 dialect conversion, not for VMI + 1:N physicalization. 
+ +scf.if/scf.for region yields and results: + use scf::populateSCFStructuralOneToNTypeConversions + otherwise write explicit OneToN patterns around RegionBranchOpInterface relations +``` + +如果当前 LLVM/MLIR 版本没有提供对应 OneToN helper,就补项目内 custom `OneToNConversionPattern`。 +选择标准不是“少写代码”,而是能否正确处理 1:N result、block argument、region yield 和 +recursive/function SCC。 + +当前实现的结构转换分工如下: + +```text +upstream OneToN helper: + func.func / func.return / func.call + scf.if / scf.for / scf.while and common SCF structural cases + +project-local OneToN structural patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只做一件事:按照 `OneToNTypeMapping` 展平/重建 operand、result、 +successor operand 和 block argument。它们不能内嵌 VMI layout 语义,也不能通过 defining op +重新推导物理寄存器列表。VMI 语义只出现在各个 `pto.vmi.*` 的 `OneToNOpConversionPattern` 中。 + +OneToN conversion 的执行顺序: + +```text +1. Populate structural conversion patterns. +2. Populate VMI semantic op lowering patterns. +3. Populate helper lowering/materialization patterns. +4. applyPartialOneToNConversion on the module. +5. Run final residual verifier as the hard legality gate. 
+``` + +如果 conversion 或 final gate 失败,诊断必须区分: + +```text +unsupported VMI semantic op +unsupported layout materialization path +unconverted function/control-flow boundary +unexpected VMI helper residual +unexpected unrealized_conversion_cast +``` + +这样 pass 边界就是清楚的: + +```text +pto-validate-vmi-ir: + verifier/walk, no conversion + +vmi-layout-assignment: + global per-value layout solver, then IR materialization + +vmi-to-vpto: + OneToNTypeConversion-based 1:N physicalization and final legality gate +``` + +### Concrete Pass Skeleton + +整个 pipeline 按下面的 hard contract 串起来: + +```text +raw VMI producer + -> pto-validate-vmi-ir + -> vmi-layout-assignment + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto + -> final residual verifier +``` + +The `ptoas --enable-vmi` driver entry uses exactly this sequence before the existing VPTO backend pipeline. The +test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is wired +through the user-facing compiler driver. + +各阶段之间只通过 IR 传递状态,不通过 pass-private side table 传递语义。也就是说: + +```text +layout assignment output: + VMI value type already contains layout + VMI mask type already contains layout + concrete b8/b16/b32 granularity + required layout conversion already appears as pto.vmi.ensure_* or rematerialized producer + +vmi-to-vpto input: + may contain pto.vmi.* semantic ops and helper ops + must not contain layout-free VMI type + function signatures and op/module TypeAttr or TypedAttr payloads are part of this invariant, + not just SSA operands/results + +vmi-to-vpto output: + must not contain pto.vmi.* op/type/helper + must not contain unrealized_conversion_cast + function type attributes and any other op/module TypeAttr or TypedAttr payloads must not contain !pto.vmi.* +``` + +This prevents a fragile design where `vmi-to-vpto` has to rediscover layout decisions from defining ops. 
A VMI value +may be a function argument, block argument, `scf.if` result, `scf.for` carried value, or branch target argument; none +of those has a useful defining op. + +#### Layout Assignment State + +`vmi-layout-assignment` should be implemented as one module-level solver object: + +```cpp +struct DataValueState { + Value value; + VMIVRegType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr naturalLayout; // producer-preferred layout + SmallVector uses; // consumer requirements +}; + +struct MaskValueState { + Value value; + VMIMaskType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr requestedLayout; + StringRef requestedGranularity; // b8/b16/b32 after inference + SmallVector uses; // consumer layout/granularity requests +}; + +struct LayoutUseRequest { + Operation *consumer; + VMILayoutAttr layout; + StringRef reason; // add/select/store/widen-source/etc. +}; +``` + +The solver runs in phases: + +```text +1. collect all VMI data/mask SSA values, including block arguments +2. add equivalence constraints +3. add producer natural-layout constraints +4. add consumer layout/granularity requests +5. solve each equivalence class +6. insert ensure_* or rematerialize producers for non-class-compatible uses +7. rewrite value types and function signatures +8. 
run pto-validate-vmi-layout-ir +``` + +Equivalence is only for cases where two logical values must have the same physical lane order: + +```text +add/sub/mul: + lhs == rhs == result + +cmpf/cmpi: + lhs == rhs + result mask requests lhs layout + element-width granularity + +select: + true_value == false_value == result + mask operand gets a use-site request for result layout + element-width granularity + +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == region_iter_arg[i] == yield[i] == result[i] + +cf.br/cf.cond_br: + successor operand[i] == successor block argument[i] + +direct internal func.call: + call operand[i] == callee argument[i] + call result[i] == all callee return operand[i] +``` + +Natural layout is not equivalence. For example: + +```text +extf f16 -> f32: + result natural layout = deinterleaved=2 + +extf f8 -> f32: + result natural layout = deinterleaved=4 + +truncf f32 -> f16: + result natural layout = contiguous + +truncf f32 -> fp8-like: + result natural layout = contiguous + +store/tile_write: + consumer requests contiguous externally visible order +``` + +If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless a +defined rematerialization path can split the value before the conflict. The first version should only rematerialize +trivially replayable producers: + +```text +constant +broadcast +constant_mask +create_mask +``` + +For non-rematerializable producers, insert `pto.vmi.ensure_layout` immediately before the consumer that requested the +different layout. This is the conservative first implementation rule. It works for ordinary SSA values, block +arguments, loop-carried values, branch arguments, and call results because the helper is dominated by the value at the +use site and does not need to be hoisted across control flow. 
`DominanceInfo` may be used later to hoist duplicated +helpers as an optimization, but it must not be required for correctness in the first implementation. + +That helper is a real IR marker: if `vmi-to-vpto` cannot lower its requested conversion, the program fails with an +explicit unsupported materialization diagnostic. + +#### Layout Assignment Implementation Frame + +This pass is a normal `OperationPass`. It deliberately does not use `DialectConversion`, because there is +no stable `Type -> Type` rule until the pass has solved producer preference, consumer demand, and control-flow joins. +The implementation should look like this: + +```cpp +struct LayoutSolver { + ModuleOp module; + MLIRContext *ctx; + + DenseMap dataIds; + SmallVector dataNodes; + DenseMap maskIds; + SmallVector maskNodes; + + SmallVector dataUseRequests; + SmallVector maskUseRequests; + DenseMap> firstReturnOperandsByFunc; + + LogicalResult collectConstraints(); + LogicalResult rewriteIR(); +}; +``` + +The concrete state objects should carry only facts that are materialized back into IR: + +```cpp +struct DataNode { + Value value; + VMIVRegType surfaceType; + unsigned parent; + VMILayoutAttr naturalLayout; // null means no producer preference yet +}; + +struct MaskNode { + Value value; + VMIMaskType surfaceType; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; // empty until b8/b16/b32 is known +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +Do not store hidden layout state that `vmi-to-vpto` must rediscover. After this pass, a debugger should be able to read +the IR and know the chosen layout for every VMI value from its type alone. 
+ +The pass body should stay simple: + +```cpp +void runOnOperation() override { + LayoutSolver solver(getOperation()); + if (failed(solver.collectConstraints()) || + failed(solver.rewriteIR()) || + failed(verifyLayoutAssignedVMIIR(getOperation()))) + signalPassFailure(); +} +``` + +The current implementation should map directly to this phase order: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module); +} +``` + +This order is intentional: + +```text +collect: + only discovers VMI values and block arguments. + +addConstraints: + only records equivalence, natural layout and consumer request facts. + It must not rewrite IR, because later CFG/call constraints may still merge + two values that were already seen. + +rewriteDataTypes: + commits solved data layouts to !pto.vmi.vreg type. + +insertDataUseMaterializations: + repairs use-site layout mismatch after the producer's committed type is known. + +inferMaskRequests: + uses already committed data layouts and element widths to infer concrete mask + layout/granularity requests. + +rewriteMaskTypes: + commits mask layout and b8/b16/b32 granularity. + +insertMaskUseMaterializations: + repairs mask layout/granularity mismatch. + +rewriteFunctionType: + updates function signatures last, after argument/result value types have been + rewritten. +``` + +Do not move `rewriteFunctionType` before use-site materialization. A function signature is the public shape of the +solved value class; changing it early makes call/return diagnostics depend on walk order and can hide an unresolved +use-site mismatch. 
+ +Constraint collection is a module walk with explicit handlers. The important point is that each handler only records +facts; it must not rewrite while walking: + +```text +Data equivalence: + pto.vmi.addf/addi: lhs == rhs == result + pto.vmi.cmpf/cmpi: lhs == rhs + pto.vmi.select: true_value == false_value == result + pto.vmi.ensure_layout: source and result are not equivalent if layouts differ + +Data natural layout: + pto.vmi.extf f16->f32: result natural = deinterleaved=2 + pto.vmi.extf fp8-like->f32: result natural = deinterleaved=4 + pto.vmi.truncf: result natural = contiguous + pto.vmi.channel_merge with C inputs: result natural = deinterleaved=C + +Data use request: + pto.vmi.store/tile_write: value requested as contiguous + pto.vmi.channel_split with C results: source requested as deinterleaved=C + op requiring a common operand/result layout: request producer class layout + +Mask request: + cmp result: same data layout as operands, granularity from element width + select mask: same data layout as selected value, granularity from element width + store mask path: same data layout as stored value, granularity from element width +``` + +Control flow should be handled as equivalence, not as local op preference: + +```text +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == body iter_arg[i] == yield[i] == result[i] + +scf.while: + before argument[i] == condition forwarded operand[i] == after argument[i] + after yield[i] == result[i] + +scf.execute_region: + every nested scf.yield operand[i] == execute_region result[i] + +scf.index_switch: + every case/default yield operand[i] == index_switch result[i] + +cf.br: + operand[i] == destination block argument[i] + +cf.cond_br: + true operand[i] == true destination block argument[i] + false operand[i] == false destination block argument[i] + +cf.switch: + default operand[i] == default destination block argument[i] + case k operand[i] == case k destination block argument[i] + 
+func.call: + only direct internal callees are supported in the first implementation + call operand[i] == callee argument[i] + call result[i] == every corresponding callee return operand[i] +``` + +Function returns need one extra bookkeeping rule. A function result slot has one public layout in the function type, so +all `func.return` operands at the same index must be equivalent: + +```text +first return operand[i] == every later return operand[i] +function result type[i] is rewritten from the solved type of return operand[i] +call result[i] == every corresponding callee return operand[i] +``` + +If two return paths naturally produce incompatible layouts, the pass should report `VMI-LAYOUT-CONTRACT` instead of +silently choosing one path: + +```mlir +^a: + %x = pto.vmi.extf %f16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %x : !pto.vmi.vreg<128xf32> // natural deinterleaved=2 + +^b: + %y = pto.vmi.extf %f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + return %y : !pto.vmi.vreg<256xf32> // different result shape/layout, invalid by verifier/type first +``` + +For equal result shape but incompatible producer preferences, the same rule applies: + +```text +return slot 0 from f16->f32 path: natural deinterleaved=2 +return slot 0 from f8E4M3FN->f32 path with the same logical result shape: natural deinterleaved=4 +diagnostic: VMI-LAYOUT-CONTRACT: conflicting natural layouts ... +``` + +External declarations with VMI types are not a layout problem; they are ABI materialization. The first implementation +must reject them before rewriting: + +```text +VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan +``` + +The rewrite phase has three ordered steps: + +```text +1. Rewrite all data SSA value types to !pto.vmi.vreg. +2. Rewrite all mask SSA value types to !pto.vmi.mask. +3. Repair use-site mismatches by either rematerializing a cheap producer or inserting an explicit helper. 
+``` + +Rematerialization is allowed only when replaying the producer cannot change memory, control flow, or execution count +semantics: + +```text +allowed: + pto.vmi.constant splat + pto.vmi.broadcast + pto.vmi.constant_mask + pto.vmi.create_mask + +not allowed in the first implementation: + load/tile_read + arithmetic result + conversion result + shuffle/channel_split/channel_merge result + value crossing a call boundary or block argument +``` + +If rematerialization is not legal, insert: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +``` + +These helpers make the unresolved materialization explicit. `vmi-layout-assignment` is allowed to create them; +`vmi-to-vpto` is responsible for proving and lowering them. If lowering cannot prove the physical transform, the final +diagnostic should be an unsupported layout/materialization diagnostic, not silent incorrect code. + +Layout assignment completion checks: + +```text +1. No surface !pto.vmi.vreg remains. +2. No surface !pto.vmi.mask remains. +3. Every VMI function argument, result, block argument, branch operand, call operand, and return operand has the + layout-assigned type selected by the solved equivalence class. +4. Every consumer-specific mismatch is represented either by a rematerialized cheap producer or by an explicit + pto.vmi.ensure_* op immediately before that consumer. +5. External declarations with VMI types are rejected; they are not rewritten into an implicit ABI. +``` + +#### OneToN Conversion Details + +`vmi-to-vpto` should use MLIR `OneToNTypeConversion` for all structural rewriting that involves VMI values: + +```text +OneToNTypeConverter: + !pto.vmi.vreg -> !pto.vreg... + !pto.vmi.mask -> !pto.mask... 
+ +Patterns: + framework structural OneToN patterns for func/return/scf + explicit OneToNOpConversionPattern for each pto.vmi semantic op + explicit helper patterns for pack/unpack/ensure_* + +Final gate: + reject residual pto.vmi.*, !pto.vmi.*, function signatures containing !pto.vmi.*, and unrealized_conversion_cast +``` + +The implementation is an `OperationPass` with this shape: + +```cpp +struct VMIToVPTOTypeConverter final : OneToNTypeConverter { + VMIToVPTOTypeConverter() { + addConversion([](Type t) { return t; }); + addConversion(convertVMIVRegType); + addConversion(convertVMIMaskType); + + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(module.getContext()); + populateVMIOneToNConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +The type converter must define one canonical physical ordering and every pattern must use that ordering: + +```text +!pto.vmi.vreg + -> chunks in logical order: + chunk0 lanes [0..P-1], chunk1 lanes [P..2P-1], ... + +!pto.vmi.vreg + -> part-major chunks: + part0 chunk0 lanes [0,2,4,...] + part0 chunk1 next even lanes + part1 chunk0 lanes [1,3,5,...] + part1 chunk1 next odd lanes + +!pto.vmi.vreg + -> part-major chunks: + part0 lanes [0,4,8,...] + part1 lanes [1,5,9,...] + part2 lanes [2,6,10,...] + part3 lanes [3,7,11,...] 
+ +!pto.vmi.mask + -> same part/chunk ordering as its data layout, one !pto.mask per physical part/chunk +``` + +`materializeVPTOToVMI` and `materializeVMIToVPTO` should use only `pto.vmi.pack` and `pto.vmi.unpack`. These ops are +conversion scaffolding; they are never valid final output. This makes accidental framework materialization visible in +the IR and easy to reject. + +Pattern population should be explicit: + +```cpp +void populateVMIOneToNConversionPatterns(VMIToVPTOTypeConverter &converter, + RewritePatternSet &patterns) { + populateFuncTypeConversionPatterns(converter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); +} +``` + +Use upstream OneToN helpers where they exist: + +```text +func.func / func.return / func.call: + populateFuncTypeConversionPatterns + +scf.if / scf.for / scf.while and common structural SCF: + scf::populateSCFStructuralOneToNTypeConversions +``` + +Use project-local OneToN patterns where the current MLIR version does not provide a complete 1:N structural rewrite: + +```text +cf.br +cf.cond_br +cf.switch +scf.execute_region +scf.index_switch +``` + +These project-local structural patterns should not know VMI semantics. They only flatten operands/results according to +`OneToNTypeMapping`, convert successor block argument lists, and rebuild the same control-flow op. + +#### Pattern Authoring Checklist + +Every new `pto.vmi.*` lowering pattern should answer the same questions before it is added to +`populateVMIOneToNConversionPatterns`: + +```text +1. Does the op require all data operands/results to have identical physical arity? + If yes, check every ValueRange size against the result mapping before emitting VPTO ops. + +2. Does the op consume a mask? + If yes, the mask must already have concrete granularity and the same physical ordering expected by the data + operand. 
The pattern must not reinterpret a pred mask by lane count alone. + +3. Does the op observe contiguous logical order outside the register file? + If yes, require contiguous layout or explicitly lower the ensure_layout/materialization before using load/store + style VPTO ops. + +4. Does the op have padding lanes? + If yes, prove padding is unobservable. For load-like ops this requires a full-read safety proof or a fallback. + For store-like ops this requires a true predicate that disables padding writes. + +5. Does the op have target-specific side effects or ordering, such as squeeze/compact/store coupling? + If yes, put that check in verifySupportedVMIToVPTOOps before conversion starts, so the pass fails before partial + rewriting. + +6. Can it create pto.vmi.pack/unpack or unrealized_conversion_cast through framework materialization? + If yes, the semantic pattern still may be correct, but final residual verification must reject any leftover helper. +``` + +This gives a concrete division of labor: + +```text +verifySupportedVMIToVPTOOps: + shape/target/path support checks that should fail before any rewrite. + +OneToNOpConversionPattern: + mechanical lowering for a preflight-approved case. + +verifyNoResidualVMIIR: + final hard gate for missed patterns, illegal materializations and hidden VMI type payloads. +``` + +Do not put target capability probing in a structural pattern. For example, a `cf.br` pattern must never ask whether +`deinterleaved=4` can be materialized. It only converts successor operands. The semantic op that created or consumes +the value is responsible for proving the VPTO lowering path. + +#### Converter Use By Pass + +The implementation should be reviewable with the following rule: + +```text +pto-validate-vmi-ir: + no TypeConverter, no ConversionTarget, no rewrite. + +vmi-layout-assignment: + no TypeConverter for choosing layouts. + It may use RewriterBase after solving, but not DialectConversion as the solving model. 
+ +vmi-to-vpto: + must use OneToNTypeConverter for VMI types. + must use OneToNOpConversionPattern for semantic VMI ops. + should use upstream func/scf OneToN helpers when available. + may add project-local structural OneToN patterns only for missing framework coverage. +``` + +The main reason is not style. It is correctness across values without defining ops: + +```mlir +^bb0(%x: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + cf.br ^bb1(%x : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + +^bb1(%y: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + %z = pto.vmi.addf %y, %y + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + ... +``` + +`%y` has no defining VMI op. Its physical values are the converted block arguments produced by OneToN block signature +conversion. Any implementation that tries to recover physical parts from a defining op is therefore incomplete for +control flow, function arguments and loop-carried values. + +When writing semantic `OneToNOpConversionPattern`, do not infer physical parts from a defining op. Use the OneToN +adaptor's per-original-operand `ValueRange`: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + ... + rewriter.replaceOp(op, physicalResults, adaptor.getResultMapping()); +} +``` + +Every VMI semantic lowering then follows the same shape: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + +for each physical part index i: + emit physical VPTO op for lhsParts[i], rhsParts[i] -> resultTypes[i] + +replace op with all physical results using adaptor.getResultMapping() +``` + +This convention is mandatory for values crossing control flow. 
For example an `scf.for` iter arg has no defining op; +its physical parts are the converted block arguments created by OneToN signature conversion. + +The concrete pattern shape is: + +```cpp +LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange in0 = adaptor.getIn0(); + ValueRange in1 = adaptor.getIn1(); + TypeRange outTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (in0.size() != in1.size() || in0.size() != outTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, outType] : llvm::enumerate(outTypes)) { + results.push_back(rewriter.create(op.getLoc(), outType, + in0[i], in1[i]).getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +For non-VMI operands, use a helper like `getSingleValue(op, adaptor.getOffset(), "...")` and fail if the framework +unexpectedly expanded them. This catches malformed conversion rules early. + +#### Semantic Lowering Buckets + +The first implementation should split VMI op lowering into five buckets: + +```text +identity/helper: + pack, unpack, ensure_layout identity/materialization cases, ensure_mask_* identity case + +per-part elementwise: + addf, addi, subf, subi, mulf, muli, divf, minf, maxf, negf, absf, absi, sqrt, exp, ln, relu, andi, ori, xori, shli, shrui, not, cmpf, cmpi, select + +per-part predicate: + mask_and, mask_or, mask_xor, mask_not + +layout-producing conversion: + extf, truncf, bitcast + +externally ordered memory: + load, store, tile_read, tile_write +``` + +Per-part elementwise ops are straightforward only when all operands/results already share the same assigned layout: + +```text +logical deinterleaved=2 value: + part0 contains logical lanes 0, 2, 4, ... + part1 contains logical lanes 1, 3, 5, ... 
+ +vmi.addf/subf/mulf on two such values: + emit the matching VPTO per-part op for part0_lhs, part0_rhs + emit the matching VPTO per-part op for part1_lhs, part1_rhs +``` + +This preserves logical lane semantics because each physical part contains the same logical lane subset for all +operands and the result. + +Memory ops are different because their observable semantics are contiguous logical order: + +```text +vmi.store of deinterleaved=2: + cannot blindly store part0 then part1 as the final memory order + must use a store plan that writes logical lane 0,1,2,3,... order + or materialize source to contiguous before physical store +``` + +Therefore `store/tile_write` lowering must either: + +```text +1. consume contiguous layout directly, or +2. lower ensure_layout(deinterleaved -> contiguous), then store, or +3. use target store instructions whose dist mode proves contiguous external order +``` + +The first implementation uses option 2 for full physical chunks: + +```text +vmi.load: + emit contiguous physical vlds chunks in memory order + materialize contiguous -> assigned result layout + +vmi.masked_load: + only when the full physical read footprint is proven safe + emit contiguous physical vlds chunks in memory order + select loaded lanes against passthru with the VMI mask + if enable-stable-gather-masked-load is set, reject pto.vmi.masked_load with + a stable TODO diagnostic until the VGATHER2-based strict no-read path is + implemented + +vmi.store: + materialize assigned source layout -> contiguous + emit physical vsts chunks in memory order + +vmi.tile_read / vmi.tile_write: + follow the same externally ordered rule +``` + +Current direct memory lowering may only emit VPTO vector memory ops for +UB-backed memory. Concretely, a `!pto.ptr<..., ub>` is legal, a +`!pto.ptr<..., gm>` is not; a memref with `#pto.address_space` is legal, +and a memref without a memory-space attribute is treated as unknown/local to +this stage to preserve existing local-view tests. 
A memref explicitly marked +GM or another non-VEC space is rejected by `vmi-to-vpto`. + +GM-backed VMI memory is still a valid semantic source/sink before this pass, +but direct lowering does not perform GM<->UB movement. That must be represented +by an earlier/lower memory access plan, scratch materialization, or UB view +normalization before `vmi-to-vpto`; otherwise the diagnostic is +`VMI-UNSUPPORTED` and names the GM-backed source/destination. + +For `deinterleaved=2`, `vldsx2 DINTLV_B*` and `vstsx2 INTLV_B*` are valid optimization candidates because the ISA has +an explicit two-stream de/interleave memory distribution mode. This should be implemented only as a peephole inside +`vmi-to-vpto` after the generic plan is correct: + +```text +vmi.load result layout deinterleaved=2: + vldsx2 DINTLV_B* can directly produce part0/part1 chunks + +vmi.store source layout deinterleaved=2: + vstsx2 INTLV_B* can directly store part0/part1 chunks in logical memory order +``` + +Do not generalize this to `deinterleaved=4` unless the two-level dist composition is proven against the ISA. The +fallback for `deinterleaved=4` remains generic layout materialization plus ordinary memory ops. + +Partial/tail load-style memory is legal only when the lowering can prove the full physical read footprint is safe. The +current direct path supports this limited proof: + +```text +source is a statically shaped memref +offset is a constant non-negative index, or tile_read implicit offset 0 +offset + physical_arity(result) * lanes_per_physical_part <= static memref element count +``` + +When this proof holds, `vmi.load` / `vmi.tile_read` may still issue full `pto.vlds` chunks. The extra padding lanes are +not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, dynamic +offsets, dynamic memrefs, and insufficient static footprints remain unsupported: + +```text +VMI-UNSUPPORTED: pto.vmi. 
requires full physical chunks without padding lanes or a statically safe full-read +footprint (...; safe-read proof failed: ...) +VMI-UNSUPPORTED: pto.vmi.<op> ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering emits +pto.vlds/pto.vsts and requires UB-backed memory) +``` + +Store-style ops are different because inactive lanes can be made write-free with true predicates. `vmi.store`, +`vmi.masked_store`, and `vmi.tile_write` therefore support the explicit contiguous/deinterleaved tail-store +materialization paths described below. + +## 2. Slice 0: Type / Attr Bootstrap + +第一步只实现 VMI type、layout attr 和纯 helper,不实现任何 conversion pass。 + +### 2.1 `#pto.vmi.layout` + +定义 `VMILayoutAttr`: + +```mlir +#pto.vmi.layout<contiguous> +#pto.vmi.layout<deinterleaved = 2> +#pto.vmi.layout<deinterleaved = 4> +``` + +建议内部参数: + +```text +kind: enum { contiguous, deinterleaved } +factor: int64_t +``` + +Verifier: + +```text +contiguous: + factor must be 1 + +deinterleaved: + factor must be 2 or 4 +``` + +禁止接受其它 spelling,例如 `stride2`、`stride4`、`parity`、`mod_split`、`blocked`。 + +### 2.2 `!pto.vmi.vreg` + +定义 `VMIVRegType`: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.vreg<128xf32, #pto.vmi.layout<contiguous>> +!pto.vmi.vreg<128xf32, #pto.vmi.layout<deinterleaved = 2>> +``` + +建议参数: + +```text +elementCount: int64_t +elementType: Type +layout: Attribute // null means surface type before layout assignment +``` + +Verifier: + +```text +elementCount > 0 +elementType is scalar-like integer / float / index supported by VMI +layout is null or VMILayoutAttr +deinterleaved=4 only allowed when target registry later supports it; type verifier only checks shape +``` + +不要要求 `elementCount * bitwidth(elementType)` 是 256B 整数倍。 + +### 2.3 `!pto.vmi.mask` + +定义 `VMIMaskType`: + +```mlir +!pto.vmi.mask<128xpred> +!pto.vmi.mask<128xb32, #pto.vmi.layout<contiguous>> +!pto.vmi.mask<128xb32, #pto.vmi.layout<deinterleaved = 2>> +``` + +建议参数: + +```text +elementCount: int64_t +granularity: enum/string { pred, b8, b16, b32 } +layout: Attribute +``` + +Verifier: + +```text +elementCount > 0 +surface mask
may use pred and no layout +layout-assigned mask must use b8/b16/b32 and must have VMILayoutAttr +pred mask must not carry layout +``` + +### 2.4 Lane Map Helper + +在 C++ 中提供纯函数 helper,供 verifier、layout assignment、VMI-to-VPTO 和测试共用: + +```text +getDataLanesPerPart(elementType) +getMaskLanesPerPart(granularity) +getVMIPhysicalArity(type) +mapLogicalLaneToPhysical(type, logicalLane) +mapPhysicalLaneToLogical(type, part, chunk, lane) +isPaddingLane(type, part, chunk, lane) +``` + +这些 helper 是 hard dependency。任何 pass 不能重新手写一套 arity 公式。 + +Slice 0 完成条件: + +```text +1. VMI type/attr 能 parse/print round-trip。 + Covered by vmi_type_attr_parse.pto. +2. 非法 layout factor、非法 mask granularity、非法 element count 有 verifier diagnostic。 + Covered by vmi_layout_factor_invalid.pto, + vmi_mask_granularity_invalid.pto, vmi_type_element_count_invalid.pto, + and vmi_mask_concrete_without_layout_invalid.pto / + vmi_mask_pred_with_layout_invalid.pto. +3. helper 单测或 lit 测试覆盖 contiguous/deinterleaved=2/deinterleaved=4 和非整 tile。 + Covered by vmi_to_vpto_type_only.pto and + vmi_to_vpto_type_arity.pto. +``` + +## 3. Slice 1: Minimal VMI Op Set + +不要一次实现 75 个 semantic op。第一批只实现能跑通 widening + elementwise + store 的闭环。 + +### 3.1 必选 semantic op + +Construction: + +```text +pto.vmi.constant +pto.vmi.broadcast +pto.vmi.iota +pto.vmi.create_mask +pto.vmi.constant_mask +``` + +`pto.vmi.from_elements` belongs to the eventual construction surface, but it is +not part of Slice 1. Do not synthesize it from ad hoc scalar lane inserts until +there is an explicit vreg immediate, scalar-insert, or scratch materialization +contract. 
+ +Mask: + +```text +pto.vmi.mask_and +pto.vmi.mask_or +pto.vmi.mask_xor +pto.vmi.mask_not +``` + +Arithmetic / conversion: + +```text +pto.vmi.addf +pto.vmi.addi +pto.vmi.subf +pto.vmi.subi +pto.vmi.mulf +pto.vmi.muli +pto.vmi.fma +pto.vmi.divf +pto.vmi.minf +pto.vmi.maxf +pto.vmi.negf +pto.vmi.absf +pto.vmi.absi +pto.vmi.sqrt +pto.vmi.exp +pto.vmi.ln +pto.vmi.relu +pto.vmi.andi +pto.vmi.ori +pto.vmi.xori +pto.vmi.shli +pto.vmi.shrui +pto.vmi.not +pto.vmi.cmpf +pto.vmi.cmpi +pto.vmi.select +pto.vmi.extf +pto.vmi.truncf +pto.vmi.bitcast +``` + +`pto.vmi.shrui` represents logical right shift and lowers to `pto.vshr`. +`pto.vmi.shrsi` is intentionally not defined until VPTO exposes or documents +an arithmetic right-shift contract distinct from logical right shift. +Integer div/rem, integer casts, int-float casts, and index casts are also +intentionally outside the current VMI surface until signedness, rounding, +saturation, overflow/remainder, and target lowering contracts are explicit. + +Memory: + +```text +pto.vmi.load +pto.vmi.masked_load +pto.vmi.gather +pto.vmi.expand_load +pto.vmi.store +pto.vmi.masked_store +pto.vmi.scatter +pto.vmi.compress_store +pto.vmi.tile_read +pto.vmi.tile_write +``` + +Current implementation scope note: + +```text +pto.vmi.gather / scatter +pto.vmi.active_prefix_index / compress / compress_store +future scan / contract style ops +``` + +These families are not first-stage completion blockers. The dialect surface may +define them, and the lowering may keep narrow direct paths when the target VPTO +contract is already explicit. Full semantic coverage for these families remains +out of scope until cross-chunk state, duplicate-index ordering, prefix carry, +compaction state, or contraction accumulation contracts are explicitly designed. +Unsupported shapes must fail before OneToN rewrite with `VMI-UNSUPPORTED`; they +must not fall through to residual-op diagnostics. 
+ +Permutation: + +```text +pto.vmi.shuffle +pto.vmi.channel_split +pto.vmi.channel_merge +``` + +Internal helper: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +pto.vmi.unpack +pto.vmi.pack +``` + +### 3.2 Op Verifier Rules + +Construction op verifier: + +```text +constant value must be a dense elements attr, and its element type/count must match the result vreg +broadcast scalar type must match the result element type +constant_mask value must be a dense elements attr, must have i1 element type, and its element count must match the +result mask +create_mask may produce surface pred mask or concrete layout-assigned mask +mask_and/mask_or/mask_xor/mask_not require all mask operands/results to have the same logical lane count; if any +mask is layout-assigned, all masks must carry the same layout and granularity +``` + +Elementwise op verifier: + +```text +all data operands have same logical lane count +all data operands have same element type except documented conversion op +if any operand has layout, all layouted operands/results must agree +surface op may have no layout before vmi-layout-assignment +``` + +`select` verifier: + +```text +mask lane count == true/false/result lane count +mask layout must match data layout after layout assignment +mask granularity must match selected element width after layout assignment +``` + +`extf/truncf` verifier: + +```text +source/result lane count equal +source/result element types are float +bitwidth changes in the expected direction +``` + +Memory op verifier: + +```text +load/tile_read memory element type must match result VMI data element type when the source is PtrType or MemRefType +store/tile_write memory element type must match stored VMI data element type when the destination is PtrType or MemRefType +``` + +`shuffle` verifier: + +```text +static mask length == result lane count +each mask index selects an existing source logical lane +result element type == source 
element type +no padding lane may be selected +``` + +`channel_split` verifier: + +```text +result count C >= 2 +input lane count N == C * M +each result is vreg<MxT> +channel c result semantics: out[c][i] = input[i * C + c] +if any source/result carries layout, all must carry layout +for C=2/4, layout-assigned source must be contiguous or deinterleaved=C +layout-assigned results must be contiguous +``` + +`channel_merge` verifier: + +```text +operand count C >= 2 +all operands have same M and element type T +result is vreg<(C*M)xT> +result semantics: result[i * C + c] = input[c][i] +if any input/result carries layout, all must carry layout +layout-assigned inputs must be contiguous +for C=2/4, layout-assigned result must be contiguous or deinterleaved=C +``` + +`ensure_layout` verifier: + +```text +source/result are both VMIVRegType +same elementCount and elementType +source/result both layout-assigned +source layout may equal result layout; that is a canonical no-op +``` + +`ensure_mask_layout` verifier is identical except it uses `VMIMaskType` and preserves granularity. + +`ensure_mask_granularity` verifier: + +```text +source/result are both VMIMaskType +same elementCount +same layout +source/result granularity are b8/b16/b32 +logical predicate value must be preserved +``` + +`pack/unpack` verifier: + +```text +VMI side must be layout-assigned +physical operand/result count == getVMIPhysicalArity(VMI type) +physical data types are !pto.vreg +physical mask types are !pto.mask +ordering is the shared Physical Arity helper order +``` + +Slice 1 完成条件: + +```text +1. Every Slice 1 op parses, prints, and has negative verifier tests. + Arithmetic/mask/helper verifier coverage includes vmi_elementwise_kind_invalid.pto, + vmi_mask_logic_invalid.pto, vmi_ensure_layout_surface_invalid.pto, + vmi_unpack_arity_invalid.pto, and vmi_pack_arity_invalid.pto. +2. Helper ops are marked internal in docs and rejected by final VMI-to-VPTO gate if residual. +3.
`channel_split/channel_merge` have tests proving shuffle-equivalent lane order. +``` + +## 4. Slice 2: VMI Producer Boundary Verifier + +VMI core implementation starts from VMI IR. Producer-specific import is outside this manual's core path. + +实现 `PTOValidateVMIIR.cpp` 中的 VMI boundary verifier: + +```text +recommended pass name: pto-validate-vmi-ir +anchor: func::FuncOp or ModuleOp +source file: lib/PTO/Transforms/PTOValidateVMIIR.cpp +``` + +Boundary verifier checks: + +```text +all logical vector values use !pto.vmi.vreg / !pto.vmi.mask +all logical vector behavior is represented by pto.vmi semantic ops +surface VMI values before layout assignment do not carry layout +no physical VPTO op appears before vmi-to-vpto +no hidden side table is required to interpret VMI values +scalar/tensor/debug/transform boundary has already been resolved by producer +``` + +Slice 2 完成条件: + +```text +1. VMI-native positive tests pass boundary verification. + Covered by vmi_producer_boundary_valid.pto. +2. Physical VPTO op before VMI-to-VPTO is rejected. + Covered by vmi_producer_boundary_physical_invalid.pto, including both + physical function types and physical VPTO ops. +3. Layout-assigned type before layout assignment is rejected unless the test explicitly starts after layout assignment. + Covered by vmi_producer_boundary_layout_invalid.pto and + vmi_producer_boundary_mask_layout_invalid.pto. +4. Missing VMI type/op invariants produce `VMI-PASS-INVARIANT` or a more specific diagnostic. + Covered by vmi_producer_boundary_non_vmi_op_invalid.pto, + vmi_producer_boundary_helper_invalid.pto, and the producer-boundary + TypeAttr nested/surface/layout invalid tests. +``` + +## 5. 
Slice 3: `vmi-layout-assignment` + +推荐实现为 pass: + +```text +recommended pass name: vmi-layout-assignment +anchor: ModuleOp +source file: lib/PTO/Transforms/VMILayoutAssignment.cpp +``` + +`vmi-layout-assignment` 必须是 module 级 pass。函数参数、`func.return` operand、 +`func.call` operand/result 和 callee signature 需要在同一个约束图里求解;函数级 pass +只能看到局部 body,无法安全地同步 callsite 和 callee。 + +### 5.1 Internal Data Model + +Build one layout node per VMI SSA value: + +```text +Operation result +BlockArgument +Region yield operand +Function argument/result +Call operand/result +``` + +Each node records: + +```text +logical type: VMIVRegType or VMIMaskType +allowed layouts: bitset {contiguous, deinterleaved2, deinterleaved4} +required mask granularity: pred/b8/b16/b32 or unknown +natural layout preference +hard constraints +soft costs +``` + +No information required by later passes may live only in this data structure. After the pass, type/attr/op +operands must fully describe the result. + +### 5.2 Transfer Functions + +Minimum Slice 3 transfer functions: + +```text +constant/broadcast/create_mask/constant_mask: + rematerializable in any legal consumer layout + +mask_and/mask_or/mask_xor/mask_not: + all mask operands/results same layout and granularity + +addf/addi/subf/subi/mulf/muli/divf/minf/maxf/negf/absf/absi/sqrt/exp/ln/relu/andi/ori/xori/shli/shrui/not/cmpf/cmpi/select: + all data operands/results same layout + mask layout follows data layout + +extf f16 -> f32: + result natural layout = deinterleaved=2 + source requires contiguous layout for the direct vcvt part=EVEN/ODD path + partial/tail source chunks are supported when they still fit in one physical + source chunk and produce the natural two-part result; source padding lanes map + only to result padding lanes + +extf f8 -> f32: + result natural layout = deinterleaved=4 + source requires contiguous layout for the direct vcvt part=P0/P1/P2/P3 path + partial/tail source chunks are supported under the same one-source-chunk + contract; 
source padding lanes map only to result padding lanes + +truncf f32 -> f16: + can consume deinterleaved=2 and produce contiguous + current implementation records a deinterleaved=2 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + partial/tail source pairs are supported when the two deinterleaved source + parts pack into one contiguous result chunk; source padding lanes map only to + result padding lanes + +truncf f32 -> fp8-like: + can consume deinterleaved=4 and produce contiguous + current implementation records a deinterleaved=4 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + The lowering emits four pto.vcvt operations with part=P0/P1/P2/P3, then ORs + the mutually exclusive partial destination registers into one contiguous fp8 + result. This mirrors the hardware packed-4 contract: each source part owns + one quarter of the destination byte lanes, so the final externally visible + vector remains logical lane order 0..N-1 after the merge. + +bitcast: + source and result layouts must match + source/result total logical bits must match + current implementation supports identical physical arity when every source/result + physical chunk carries the same number of logical bits. This covers full chunks + and partial/tail chunks such as 65xf32 -> 130xi16, where the second physical + chunk carries 32 logical bits on both sides. Partial/tail bitcast remains + unsupported if source padding bits would become result logical bits. + +load/tile_read: + result layout chosen by consumers unless memory plan has a cheaper registered sink/source + +store/tile_write: + can consume any layout only if target registry has preserving store path + current implementation records a contiguous use-site request for vmi.store and + inserts pto.vmi.ensure_layout when the stored value class solved to a + non-contiguous layout. 
This makes externally visible memory order explicit in + IR before vmi-to-vpto. If explicit IR reaches vmi-to-vpto with a + deinterleaved=2/4 tail value, the direct lowering may still materialize it to + contiguous physical chunks first, but only when every deinterleaved part has + the same physical chunk count and therefore forms complete intlv groups. + +shuffle/channel_split/channel_merge: + default result layout contiguous unless target registry provides direct layout-preserving path + current implementation supports pto.vmi.shuffle when every result physical + chunk forwards one source physical chunk with identical lane positions for + all non-padding result lanes. Result padding lanes are ignored by the + forwarding proof and remain unobservable after physicalization. This allows + whole-chunk projection/reordering under contiguous or explicit deinterleaved + layouts, including tail-prefix projections such as `[0, 1, 2, 3] -> + !pto.vmi.vreg<4xf32>`. Arbitrary lane permutation remains unsupported unless + the vselr index-vector path below can materialize it. + current implementation supports channel_split/channel_merge for 2 or 4 + channels. channel_split consumes a natural deinterleaved=C source and produces + contiguous per-channel results; channel_merge consumes contiguous per-channel + inputs and produces a natural deinterleaved=C result. The direct path also + accepts partial/tail channel groups when the virtual deinterleaved=C channel + layout has the same physical arity as the source/result representation, so + every physical group can be materialized with complete intlv/dintlv pairs. + Arity-changing partial groups such as splitting 4xf32 into two 2xf32 channels + remain unsupported. If a producer/consumer + requires dense contiguous layout, pto.vmi.ensure_layout materializes the + pto.vdintlv/pto.vintlv tree explicitly. Non-matching layouts and other channel + counts remain unsupported. 
+``` + +### 5.3 Solver Order + +Implement deterministic solving: + +```text +1. Collect region/SCC constraints, including scf/cf/function/call boundaries. +2. Propagate impossible layouts and required mask granularities. +3. Pick a layout per node using minimum cost. +4. Tie-break: explicit layout already present on the VMI type, then natural layout, then contiguous. +5. Rewrite result/block/function types to layout-assigned VMI types. +6. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses that need conversion. +7. Clone rematerializable producers per use when cheaper than conversion. +8. Run verifier gate. +``` + +Current implementation status: + +```text +implemented: + extf source -> contiguous use-site request for supported f16/fp8-like to f32 paths + truncf f32->f16 source -> deinterleaved=2 use-site request + truncf f32->fp8-like source -> deinterleaved=4 use-site request + single-use pto.vmi.load / tile_read results can adopt a consumer-requested + layout before type rewrite; this covers direct memory producers such as + load -> truncf without inserting a redundant ensure_layout + vmi.store data operand -> contiguous use-site request + explicit VMI vreg layout is preserved as an initial solver constraint + explicit concrete VMI mask layout/granularity is preserved as an initial solver constraint + channel_split source -> deinterleaved=C use-site request + channel_split results -> contiguous natural layout + channel_merge inputs -> contiguous use-site request + channel_merge result -> deinterleaved=C natural layout + shuffle without explicit layouts -> contiguous source use-site request and contiguous result natural layout + shuffle with explicit source/result layouts -> preserve explicit layouts and let vmi-to-vpto prove chunk forwarding + pto.vmi.ensure_layout insertion for non-contiguous store operands + pto.vmi.ensure_layout insertion for truncf source materialization + pto.vmi.ensure_mask_layout / ensure_mask_granularity 
insertion for select mask operands + pto.vmi.create_mask / constant_mask rematerialization for select mask operands when the consumer needs a + different mask layout/granularity + splat pto.vmi.constant rematerialization for data operands when the consumer needs + a different layout + pto.vmi.broadcast rematerialization for data operands when the consumer needs + a different layout + scf.execute_region result/yield layout equivalence + scf.index_switch result/yield layout equivalence + scf.while state layout equivalence + +not yet implemented: + generic per-consumer layout request table for every VMI op + producer rematerialization for non-splat data constants and other cheap producers + cost model / target capability registry +``` + +Do not implement a local greedy pattern pass that ignores block arguments or function signatures. + +### 5.4 CFG Rules + +CFG 处理分两层。第一层是必须做的 layout equivalence:同一个控制流值在 +result、yield、region/block argument 之间必须形成同一个 layout/mask 约束组。第二层才是 +layout conflict resolution:当同一个 producer 的不同 consumers 希望不同 layout 时,插入 +`ensure_layout`、`ensure_mask_layout` 或 rematerialize producer。 + +当前可落地的最小实现先做第一层。它不尝试在 branch 边界自动插入 conversion,因此下面这些 +关系一旦因为 natural layout 或 mask granularity 冲突无法合并,必须报 `VMI-LAYOUT-CONTRACT`, +不能默默选择某一边。 + +`scf.if` equivalence: + +```text +for each result index i: + scf.if result[i] + == then scf.yield operand[i] + == else scf.yield operand[i] +``` + +如果 value 是 `!pto.vmi.vreg`,合并 data layout 约束;如果 value 是 +`!pto.vmi.mask`,合并 mask layout 和 granularity 请求。这样 `%m = scf.if ... 
-> +!pto.vmi.mask` 后被 `vmi.select` 消费时,select 对 `%m` 推出的 `b8/b16/b32 + layout` +会传播回两边 yield 的 mask producer。 + +`scf.for` equivalence: + +```text +for each iter_arg index i: + init_arg[i] + == region_iter_arg[i] + == scf.yield operand[i] + == scf.for result[i] +``` + +这条规则避免 loop-carried value 每次迭代改变 layout。对于 `extf f16->f32` 作为 init、 +loop body 内部 `addf` 并 yield 的 case,`extf` 的 natural layout `deinterleaved=2` +必须稳定传递到 `%acc` region arg、`scf.yield` 和 loop result。 + +`cf.br` / `cf.cond_br` equivalence: + +```text +for each successor operand index i: + branch successor operand[i] + == successor block argument[i] +``` + +当前实现覆盖标准 `cf.br`、`cf.cond_br` 和 `cf.switch`。其中 `cf.switch` 的 default operands +与 default destination block arguments 按 index 建 layout 等价关系;每个 case operand segment +与对应 case destination block arguments 按 index 建 layout 等价关系。更泛化的 +`BranchOpInterface` op 如果携带 VMI type,后续要么补对应 mapping,要么在 layout assignment +阶段明确 diagnostic,不能让 hidden default layout 穿过去。 + +当前实现支持携带 VMI value 的 `scf.execute_region`:execute_region result 与直属 region terminator +`scf.yield` operands 按 result index 合并到同一个 layout 等价类。嵌套 region 内属于其他 op 的 +`scf.yield` 不参与 execute_region 的等价关系。 + +当前实现支持携带 VMI value 的 `scf.index_switch`:default/case region `scf.yield` operands 与 +index_switch results 按 result index 合并到同一个 layout 等价类。 + +当前实现支持携带 VMI value 的 `scf.while`:init operand、before region argument、`scf.condition` +forwarded operand、after region argument、after region `scf.yield` operand 和 while result 按状态 +index 合并到同一个 layout 等价类。`scf.condition` 的 i1 condition 本身不参与 VMI layout 约束。 + +Function boundary: + +```text +internal functions may get specialized layouted signatures +external ABI must not expose VMI layout +recursive SCC requires fixed-point signature layout +``` + +当前实现支持 direct `func.call` 到同一 module 内带 body 的 `func.func`: + +```text +call operand[i] == callee argument[i] +call result[i] == every callee return operand[i] +same-result-index return operands inside one callee are 
equivalent +``` + +如果携带 VMI type 的 call 无法解析到带 body 的 direct callee,layout assignment 必须报 +`VMI-LAYOUT-CONTRACT`。后续如需支持 public/external ABI,必须先定义 VMI 值如何在 ABI +边界 materialize,不能把 layouted VMI type 暴露出去。 +当前实现明确拒绝携带 VMI type 的 `func.call_indirect`,因为它没有可解析的 direct internal +callee signature/body 可参与 layout constraint solving。 + +当前实现对携带 VMI type 的 external function declaration 报 `VMI-LAYOUT-CONTRACT`,因为还没有 +定义 VMI value 的外部 ABI materialization plan。没有 VMI type 的 external declaration 必须在 +`rewriteFunctionType` 中保持原签名,不能因为没有 entry block arguments 被改写成空签名。 + +`ptoas --enable-vmi` 额外拒绝 public `func.func` 的 VMI-typed signature: + +```text +VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan +``` + +这样 test-opt 仍可覆盖 internal/private function signature physicalization,用户入口则不会把 +layout-assigned VMI 值隐式暴露成 public ABI。 + +Slice 3 完成条件: + +```text +1. All VMI values have layout-assigned types after the pass. +2. All masks have b8/b16/b32 granularity after the pass. +3. CFG and call tests prove branch/yield/signature layout equality. +4. Multi-use rematerializable producer tests prove broadcast, constant, iota, + create_mask, and constant_mask rematerialization vs ensure_layout / + ensure_mask_* is deterministic. +5. The pass runs the layout-assigned VMI hard gate before returning, including + recursive TypeAttr/TypedAttr rejection; covered by + vmi_layout_assignment_post_gate_type_attr_invalid.pto. +``` + +## 6. 
Slice 4: `vmi-to-vpto` + +推荐实现为 pass: + +```text +recommended pass name: vmi-to-vpto +anchor: ModuleOp +source file: lib/PTO/Transforms/VMIToVPTO.cpp +``` + +第一步实现必须先落地 MLIR OneToN conversion 框架: + +```text +VMIToVPTOTypeConverter : OneToNTypeConverter: + !pto.vmi.vreg -> ordered !pto.vreg list + !pto.vmi.mask -> ordered !pto.mask list + +Structural patterns: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + project-local OneToN patterns for cf.br/cf.cond_br/cf.switch + project-local OneToN patterns for scf.execute_region/scf.index_switch + +VMI patterns: + OneToNOpConversionPattern for pack/unpack/ensure_*/semantic ops + +Final residual gate: + reject pto.vmi.*, !pto.vmi.*, unrealized_conversion_cast + scan SSA types, block argument types, function signatures, and op/module TypeAttr or TypedAttr payloads +``` + +这一步可以先支持 type-only physicalization 和 `pack/unpack` helper physicalization,但不能让未实现的 VMI semantic op 静默通过。 +如果还有 `pto.vmi.*` 或 VMI type 残留,必须报 `VMI-RESIDUAL-OP`。 + +当前 slice 支持 VMI function/input/block argument 展开成 physical arguments,并支持: + +```text +pto.vmi.unpack(layouted VMI aggregate) -> physical parts: + replace with OneToN adaptor source parts + +pto.vmi.pack(physical parts) -> layouted VMI aggregate: + replace with the physical parts through resultMapping + +pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity: + ensure_layout must compare the original VMI source/result layout attrs, not only the converted physical type list. + If source/result layouts are identical, replace with source parts. This identity case supports partial/tail physical + chunks because no lane reordering or packing is performed. + If deinterleaved=2 -> contiguous, emit one pto.vintlv. + If contiguous -> deinterleaved=2, emit one pto.vdintlv. + If deinterleaved=4 -> contiguous, emit the two-level pto.vintlv tree. + If contiguous -> deinterleaved=4, emit the reverse two-level pto.vdintlv tree. 
+ ensure_mask_layout supports the same contiguous <-> deinterleaved=2/4 layout conversions with predicate + rearrange ops: + deinterleaved=2 -> contiguous: pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=2: pto.pdintlv_b8/b16/b32 + deinterleaved=4 -> contiguous: two-level pto.pintlv_b8/b16/b32 tree + contiguous -> deinterleaved=4: two-level pto.pdintlv_b8/b16/b32 tree + ensure_mask_granularity supports concrete b8/b16/b32 logical predicate-preserving conversion: + widening b8 -> b16 -> b32: split each physical chunk with pto.punpack LOWER/HIGHER + narrowing b32 -> b16 -> b8: pack physical chunk pairs with pto.ppack LOWER/HIGHER and merge halves with pto.por + b8 <-> b32 conversions are lowered as two adjacent steps through b16. + +pto.vmi.broadcast: + current direct lowering requires the physical result element width to be 8, + 16, or 32 bits, because the vdup is predicated by pto.mask. + Other semantic element types need a dedicated materialization contract before + vmi-to-vpto may lower them. + for each physical result part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical result element width + emit pto.vdup(scalar, all_true_mask) + This is layout-independent because every logical lane has the same scalar value. A deinterleaved layout simply + receives one identical vdup per partition/chunk; no vintlv/vdintlv is needed. + +pto.vmi.iota: + semantics: + ASC: result[lane] = base + lane + DESC: result[lane] = base - lane + supported element types follow pto.vci: + integer 8/16/32 and f16/f32 + contiguous full-chunk direct path: + for each physical chunk c: + chunk_base = base +/- c * lanes_per_part + emit pto.vci chunk_base {order = ASC|DESC} + deinterleaved layout requires strided index materialization because physical part p contains logical lanes: + p, p + factor, p + 2 * factor, ... 
+ The required formula is: + ASC: base + p + factor * local_lane + DESC: base - p - factor * local_lane + The current lowering materializes this per physical chunk: + local = pto.vci 0 + scaled = pto.vmuls local, factor + ASC: result = pto.vadds scaled, base + part_offset + DESC: result = pto.vsub pto.vdup(base - part_offset), scaled + Partial/tail chunks are allowed. The physical padding lanes receive the natural continuation of the generated iota + sequence and remain padding/undef at the VMI semantic level; memory writes, masks, reductions, and other + externally-visible consumers must still obey the VMI padding rules. + +pto.vmi.constant_mask: + support dense bool constants for concrete b8/b16/b32 masks. For each physical chunk: + if the active lanes form a prefix: + emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + otherwise decompose the static bitset into active runs: + run [lo, hi) = prefix(hi) & ~prefix(lo) + combine runs with pto.por under an all-true predicate + pred-only masks remain unsupported until they have a concrete b8/b16/b32 consumer granularity. + +pto.vmi.mask_and / mask_or / mask_xor / mask_not: + for each physical predicate part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical mask granularity + mask_and emits pto.pand(lhs_part, rhs_part, all_true_mask) + mask_or emits pto.por(lhs_part, rhs_part, all_true_mask) + mask_xor emits pto.pxor(lhs_part, rhs_part, all_true_mask) + mask_not emits pto.pnot(source_part, all_true_mask) + +pto.vmi.addf / addi / subf / subi / mulf / muli / divf / minf / maxf / negf / absf / absi / sqrt / exp / ln / relu / andi / ori / xori / shli / shrui / not: + current direct lowering requires the physical element width to be 8, 16, or + 32 bits, because every emitted VPTO op is predicated by a materialized + pto.mask. 
VMI types such as index or f64 remain valid semantic + surface types only after a dedicated lowering contract exists; until then + vmi-to-vpto must report VMI-UNSUPPORTED before OneToN conversion. + This common predicate-maskability rule is necessary but not sufficient for + every target op. Direct lowering must also preflight the concrete VPTO/VISA + element contract before OneToN rewriting: + addf/subf/mulf -> pto.vadd/vsub/vmul support f16/bf16/f32 floating types + divf -> pto.vdiv supports f16/f32 floating types + minf/maxf -> pto.vmin/vmax support f16/bf16/f32 floating types + negf/absf/sqrt/exp/ln/relu -> pto.vneg/vabs/vsqrt/vexp/vln/vrelu support f16/f32 floating types + absi -> pto.vabs supports signless/signed i8/i16/i32 integer types + bf16/f8 remain legal VMI float-like semantic types for the ops whose VMI + semantics allow them, but vmi-to-vpto must report VMI-UNSUPPORTED until a + materialization plan or wider target contract exists. + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical element width + addf/addi emit pto.vadd(lhs_part, rhs_part, all_true_mask) + subf/subi emit pto.vsub(lhs_part, rhs_part, all_true_mask) + mulf/muli emit pto.vmul(lhs_part, rhs_part, all_true_mask) + divf emits pto.vdiv(lhs_part, rhs_part, all_true_mask) + minf emits pto.vmin(lhs_part, rhs_part, all_true_mask) + maxf emits pto.vmax(lhs_part, rhs_part, all_true_mask) + negf emits pto.vneg(source_part, all_true_mask) + absf/absi emit pto.vabs(source_part, all_true_mask) + sqrt emits pto.vsqrt(source_part, all_true_mask) + exp emits pto.vexp(source_part, all_true_mask) + ln emits pto.vln(source_part, all_true_mask) + relu emits pto.vrelu(source_part, all_true_mask) + andi emits pto.vand(lhs_part, rhs_part, all_true_mask) + ori emits pto.vor(lhs_part, rhs_part, all_true_mask) + xori emits pto.vxor(lhs_part, rhs_part, all_true_mask) + shli emits pto.vshl(lhs_part, rhs_part, all_true_mask) + shrui emits pto.vshr(lhs_part, rhs_part, 
all_true_mask) + not emits pto.vnot(source_part, all_true_mask) + +pto.vmi.fma: + semantic: + result = fused_multiply_add(lhs, rhs, acc) + It must not be decomposed to pto.vmi.mulf + pto.vmi.addf because VPTO VMULA + may produce different floating-point results from separate multiply and add. + layout assignment: + lhs, rhs, acc, and result belong to one data layout equivalence class. + current direct lowering: + source/result element type must be f16, bf16, or f32 + for each physical part: + materialize pto.pset_b16/b32 "PAT_ALL" from the physical element width + emit pto.vmula(acc_part, lhs_part, rhs_part, all_true_mask) + The VMI operand order is lhs, rhs, acc; the VPTO operand order is acc, lhs, rhs. + +pto.vmi.cmpf / cmpi: + current direct lowering has the same 8/16/32-bit physical element-width + precondition as elementwise arithmetic, so the result predicate can be + materialized as b8/b16/b32. + target element contract: + cmpf: f16/bf16/f32, matching VISA VCMP floating-point element types + cmpi: signless/signed/unsigned i8/i16/i32, matching VISA VCMP integer element types + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" as the seed predicate + canonicalize predicate to VPTO cmp_mode eq/ne/lt/le/gt/ge + emit pto.vcmp(lhs_part, rhs_part, seed_mask, cmp_mode) + supported cmpf ordered aliases: + oeq -> eq + one -> ne + olt -> lt + ole -> le + ogt -> gt + oge -> ge + supported cmpi signed aliases: + slt -> lt + sle -> le + sgt -> gt + sge -> ge + unsupported floating-point predicates such as ord/uno/ult/ule/ugt/uge must emit VMI-UNSUPPORTED until NaN-aware + predicate construction is designed. + unsupported unsigned integer predicates ult/ule/ugt/uge must emit VMI-UNSUPPORTED until VPTO integer signedness + materialization is explicit. + +pto.vmi.active_prefix_index: + semantic: + idx[i] = popcount(mask[0 .. i)) + result element type must be signless i8/i16/i32, and concrete mask granularity must match the result element width. 
+ current direct lowering: + only contiguous layout + only one physical result/mask chunk + result and mask chunks must be full, with no padding logical lanes + materialize a zero vreg carrier with pto.vdup + emit pto.vusqz(carrier, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could affect the observable prefix + multi-chunk contiguous values need cross-chunk prefix carry + deinterleaved layouts need logical-lane-order prefix reconstruction + all of the above must report VMI-UNSUPPORTED before OneToN conversion + +pto.vmi.compress: + semantic: + keep source lanes whose mask lane is true and compact them in logical lane order; inactive tail lanes are zero/undef + at the VMI semantic level unless consumed by an operation that defines them. + current direct lowering: + source/result/mask must be contiguous + source/result/mask must each materialize to one physical chunk + source chunk must be full, with no padding logical lanes + emit pto.vsqz(source, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could be squeezed into the observable result prefix + multi-chunk values need cross-chunk compaction and SQZN/carry planning + deinterleaved layouts need logical-lane-order compaction before physical part placement + compress_store is not implied by register compress; store-coupled VSQZ #st=1 and VSTUR require a separate + producer/consumer pairing plan + +pto.vmi.compress_store: + semantic: + store source lanes whose mask lane is true as a dense logical memory stream: + k = 0 + for lane in logical order: + if mask[lane]: + base[offset + k] = value[lane] + k += 1 + layout assignment: + value use is requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct lowering: + value and mask must be contiguous + value and mask must each materialize to one physical chunk + the value chunk must be full, with no padding logical lanes + destination must be a UB !pto.ptr 
because pto.vstur is pointer-only and UB-only + lower as: + store_base = pto.addptr destination, offset + squeezed = pto.vsqz(value, mask) + align0 = pto.init_align + align1 = pto.vstur align0, squeezed, store_base, "POST_UPDATE" + pto.vstar align1, store_base + The pto.vstur user is the required consumer that lets the VPTO LLVM emitter + set VSQZ #st=1. A plain register pto.vsqz must not be assumed to enqueue + SQZN for store. + unsupported cases: + memref or GM destination until an explicit pointer/materialization plan exists + partial/tail physical chunks, because padding mask lanes could be squeezed into memory + multi-chunk values, because they need cross-chunk active-count compaction and SQZN/VSTUR state planning + deinterleaved layouts, because compaction must be in logical lane order + +pto.vmi.reduce_addi: + semantic: + acc = init[0] + for lane in logical order: + if mask[lane]: + acc = acc + source[lane] // integer wraparound addition + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element width must be 32 bits; narrower vcadd widens its result and needs a separate result type plan + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + i8/i16 until widening result and init conversion are designed + partial/tail source chunks because padding lanes must not participate + floating-point add reduction 
without pto.vmi.reduce_addf {reassoc} + +pto.vmi.reduce_addf: + semantic: + requires {reassoc}; without it the verifier rejects the op + acc = init[0] + for lane in any reassociated tree over active logical lanes: + acc = acc + source[lane] + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of b32 physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + missing reassoc attr + f16 until accumulator precision and rounding contract are designed + partial/tail source chunks because padding lanes must not participate + +pto.vmi.reduce_maxf / pto.vmi.reduce_minf: + semantic: + acc = init[0] + for each active logical lane in logical lane order: + reduce_maxf: acc = max(acc, source[lane]) + reduce_minf: acc = min(acc, source[lane]) + result[0] = acc + inactive lanes inside each physical chunk follow VPTO identities: + reduce_maxf uses pto.vcmax, where inactive FP lanes behave as -INF + reduce_minf uses pto.vcmin, where inactive FP lanes behave as +INF + NaN and signed-zero behavior follows pto.vcmax/pto.vcmin for the chunk + reduction and pto.vmax/pto.vmin for serial chunk accumulation. The index + lane produced by pto.vcmax/pto.vcmin is ignored because VMI exposes only the + rank-0 value result. 
+ layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f16 or f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower reduce_maxf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmax(source_chunk, mask_chunk) + acc = pto.vmax(reduced, acc, first_lane) + result = acc + lower reduce_minf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmin(source_chunk, mask_chunk) + acc = pto.vmin(reduced, acc, first_lane) + result = acc + unsupported cases: + bf16/fp8/f64 until VPTO reduction and combine semantics are designed + partial/tail source chunks because padding lanes must not participate + integer min/max until signed/unsigned and inactive identity contracts are explicit + +pto.vmi.select: + current direct lowering is a storage-width select rather than a semantic + arithmetic op: source/result physical elements must be b8/b16/b32-maskable, + but signedness and float-vs-integer interpretation are not inspected. 
+ for each physical part: + consume the corresponding physical predicate part + emit pto.vsel(true_part, false_part, predicate_part) + +pto.vmi.extf, direct path: + support 16-bit float-like contiguous source part -> f32 deinterleaved=2 result parts + materialize pto.pset_b16 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=EVEN/ODD) + partial/tail is valid when the logical lanes fit in the one physical source + part; PAT_ALL may convert padding lanes, but those lanes remain padding in + the deinterleaved result + support 8-bit contiguous source part -> f32 deinterleaved=4 result parts + materialize pto.pset_b8 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=P0/P1/P2/P3) + the same padding rule applies + reject other extf width/layout shapes until their exact part plan is implemented + +pto.vmi.truncf, direct path: + support f32 deinterleaved=2 source parts -> 16-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(even_f32_part, mask, rnd=R, sat=SAT, part=EVEN) + emit pto.vcvt(odd_f32_part, mask, rnd=R, sat=SAT, part=ODD) + materialize pto.pset_b16 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the two source parts pack into one physical + result part; converted padding lanes remain result padding + support f32 deinterleaved=4 source parts -> 8-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(p0_f32_part, mask, rnd=R, sat=SAT, part=P0) + emit pto.vcvt(p1_f32_part, mask, rnd=R, sat=SAT, part=P1) + emit pto.vcvt(p2_f32_part, mask, rnd=R, sat=SAT, part=P2) + emit pto.vcvt(p3_f32_part, mask, rnd=R, sat=SAT, part=P3) + materialize pto.pset_b8 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the four source parts pack into one physical + result part; converted padding lanes remain result padding + reject other truncf width/layout shapes until their exact pack plan is 
implemented + +pto.vmi.bitcast: + for each physical part: + emit pto.vbitcast(source_part) -> result_part_type + source/result layouts must match, physical arity must match, and every + corresponding physical chunk must carry the same number of logical bits. + Padding bits may map only to result padding bits; any shape where source + padding would become result logical data remains unsupported. + +pto.vmi.channel_split / pto.vmi.channel_merge: + support 2-way and 4-way channel transforms for contiguous per-channel values + and matching deinterleaved=C merged values. + + channel_split C=2: + if the source layout is already deinterleaved=2, forward physical chunks + directly to the two contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 2*N contiguous chunks. For each pair of dense chunks: + %ch0_i, %ch1_i = pto.vdintlv %dense_2i, %dense_2i_plus_1 + Results are returned in per-channel order: + channel0 chunks..., channel1 chunks... + + channel_split C=4: + if the source layout is already deinterleaved=4, forward physical chunks + directly to the four contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 4*N contiguous chunks. The lowering is the same two-level pto.vdintlv + tree used by contiguous -> deinterleaved=4 materialization, but the + partition-major output is interpreted as four separate contiguous channel + results. + + channel_merge C=2/C=4: + inputs are consumed as per-channel contiguous chunks. + If the result layout is deinterleaved=C, the physical chunks are forwarded + directly in partition-major order. + If the result layout is contiguous, the lowering uses the reverse + pto.vintlv tree and returns dense contiguous chunks for the merged result. 
+ + Unsupported: + channel counts other than 2 or 4 + non-matching channel input/result layouts + arity-changing or uneven partial physical channel groups that cannot form + complete intlv/dintlv groups + +pto.vmi.shuffle: + first try whole physical chunk forwarding cases: + source/result layouts are assigned + every non-padding lane in a result physical chunk maps to the same source physical chunk + source lane number equals result lane number inside the physical chunk + result padding lanes are ignored and remain semantically unobservable + + If forwarding fails, try vci-materializable vselr per physical chunk: + every result physical chunk has no padding lane + every lane in a result physical chunk maps to the same source physical chunk + source lane indices inside the chunk form one ASC or DESC consecutive sequence + materialize the index vector with pto.vci(base_lane, ASC|DESC) + emit pto.vselr(source_chunk, index_vector) + + Examples: + identity 128xf32 -> 128xf32: + indices = [0, 1, ..., 127] + forward dense chunks 0 and 1 + + second physical chunk 128xf32 -> 64xf32: + indices = [64, 65, ..., 127] + forward dense chunk 1 + + tail prefix 128xf32 -> 4xf32: + indices = [0, 1, 2, 3] + forward dense chunk 0 + lanes 4..63 of the physical result are padding lanes and are not part of + the logical vmi value + + chunk swap 128xf32 -> 128xf32: + indices = [64, 65, ..., 127, 0, 1, ..., 63] + forward dense chunks in order 1, 0 + + reverse one 64xf32 chunk: + indices = [63, 62, ..., 0] + index = pto.vci 63 {order = DESC} : i32 -> !pto.vreg<64xi32> + result = pto.vselr source_chunk, index + + Unsupported: + partial physical chunk projection whose observable result lanes are not + padding-safe forwarding, e.g. 
[1, 2, 3, 4] -> 4xf32 when it would require + shifting lanes rather than forwarding a whole physical chunk + broadcast, duplicate lanes, arbitrary non-affine permutation + current implementation emits VMI-UNSUPPORTED for these cases before + OneToN conversion, instead of leaving a generic residual VMI op. +``` + +`func.return` 携带 VMI operand 时必须通过 OneToN func/return structural pattern 展开成 physical +return operands。不能只取第一个 physical part;这种错误会导致函数类型已经返回两个 physical value, +但 `func.return` 只返回一个 value。 + +### 6.1 Type Conversion + +Use one shared physicalization helper: + +```text +VMIVRegType -> N physical !pto.vreg +VMIMaskType -> N physical !pto.mask +``` + +Physical result ordering must be: + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=K: + p0_chunk0, p0_chunk1, ..., p1_chunk0, ..., p(K-1)_chunkN +``` + +### 6.2 Structural Conversion + +The pass must convert: + +```text +operation results +block arguments +branch operands +cf.br / cf.cond_br successor block signatures +scf.if results and yields +scf.for iter_args and yields +func arguments/results +call operands/results +return operands +cf.br / cf.cond_br / cf.switch block arguments and successor operands +scf.execute_region results and yields: + current implementation uses a project-local OneToN structural pattern. +scf.index_switch results and yields: + current implementation uses a project-local OneToN structural pattern. +``` + +Do not rely on a defining op to recover parts. Any VMI value may come from a block argument or function +argument, so `unpack` must be valid on arbitrary layout-assigned VMI SSA values before final lowering. 
+ +### 6.3 Op Lowering + +Internal helper lowering: + +```text +unpack: + replace with physical values in helper ordering + +pack: + materialize one logical VMI aggregate before it is immediately consumed by another VMI helper + must not remain after final gate + +ensure_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + identity source/result layouts do not require full chunks + if source/result layouts differ, either: + every source/result physical chunk is full, with no padding lanes; or + source/result both have complete contiguous/deinterleaved=2/4 materialization groups and their materialized + physical arity still equals the original VMI physical arity + arity-changing partial/tail layout conversion remains unsupported because it would need an explicit padding + packing/drop plan + otherwise report VMI-UNSUPPORTED before OneToN conversion + + compare the original VMI source/result layout attrs: + same layout: + forward the converted source parts + deinterleaved=2 -> contiguous: + %d0, %d1 = pto.vintlv %p0, %p1 + contiguous -> deinterleaved=2: + %p0, %p1 = pto.vdintlv %d0, %d1 + deinterleaved=4 -> contiguous: + %a0, %a1 = pto.vintlv %p0, %p2 + %b0, %b1 = pto.vintlv %p1, %p3 + %d0, %d1 = pto.vintlv %a0, %b0 + %d2, %d3 = pto.vintlv %a1, %b1 + contiguous -> deinterleaved=4: + %a0, %b0 = pto.vdintlv %d0, %d1 + %a1, %b1 = pto.vdintlv %d2, %d3 + %p0, %p2 = pto.vdintlv %a0, %a1 + %p1, %p3 = pto.vdintlv %b0, %b1 + + It is a bug to treat layout conversion as identity merely because both sides convert to the same + number of physical !pto.vreg values with the same type. For example: + !pto.vmi.vreg<128xf32, deinterleaved=2> + !pto.vmi.vreg<128xf32, contiguous> + both physicalize to two !pto.vreg<64xf32> values, but their logical lane order differs. 
+ +ensure_mask_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + if source/result layouts differ, every source/result physical predicate chunk must be full, with no padding lanes + identity source/result layouts do not require full chunks + otherwise report VMI-UNSUPPORTED before OneToN conversion + + same-layout: + forward source parts + deinterleaved=2 -> contiguous: + use pto.pintlv_b8/b16/b32 on each partition pair + contiguous -> deinterleaved=2: + use pto.pdintlv_b8/b16/b32 on each dense pair + deinterleaved=4 -> contiguous: + use the same two-level tree as data layout conversion, replacing pto.vintlv with pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=4: + use the reverse two-level tree, replacing pto.vdintlv with pto.pdintlv_b8/b16/b32 + source/result granularity must be identical; granularity conversion belongs to ensure_mask_granularity. + +ensure_mask_granularity: + source/result layout and logical lane count must match. + source/result granularity must be concrete b8/b16/b32. + identity conversion forwards physical parts. + widening conversion: + b8 -> b16 or b16 -> b32 uses pto.punpack LOWER/HIGHER for each source physical chunk. + each source physical mask chunk can produce up to two result chunks in logical order. + narrowing conversion: + b32 -> b16 or b16 -> b8 uses pto.ppack LOWER for the low source chunk. + if a high source chunk exists, use pto.ppack HIGHER and merge the two partial masks with pto.por under PAT_ALL. + this handles odd tail groups because the missing high half is padding and remains zero. + multi-step conversion: + b8 -> b32 is b8 -> b16 -> b32. + b32 -> b8 is b32 -> b16 -> b8. 
+``` + +Elementwise lowering: + +```text +for each physical part: + lower add/cmp/select to corresponding VPTO op sequence + preserve source/result physical ordering + cmp predicates must be canonicalized before creating pto.vcmp: + eq/ne/lt/le/gt/ge pass through + ordered FP aliases oeq/one/olt/ole/ogt/oge map to eq/ne/lt/le/gt/ge + signed integer aliases slt/sle/sgt/sge map to lt/le/gt/ge + unordered/NaN-sensitive FP predicates are unsupported until represented explicitly + unsigned integer predicates are unsupported until signedness is represented explicitly +``` + +Producer lowering: + +```text +broadcast: + TypeConverter gives the ordered result physical types. + For each result physical vreg: + create all-true mask with the vreg element width + emit pto.vdup scalar -> that physical vreg + + This is valid for contiguous and deinterleaved layouts because splat has no lane-order dependence. + +constant: + Splat dense constants use the same path as broadcast: + create scalar arith.constant from the splat attribute + emit pto.vdup per physical result part + require the same 8/16/32-bit physical result element-width precondition as + broadcast + Non-splat dense constants need an explicit constant materialization strategy or must remain unsupported with a + precise diagnostic; do not synthesize an arbitrary lane sequence by scalar inserts unless that path is designed. 
+ +create_mask / constant_mask: + constant active_lanes create_mask lowers per physical mask part: + clamp active_lanes to [0, logical lane count] + compute active prefix count for each physical mask chunk with the VMI lane-map helper + emit pto.pge_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a chunk prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + Dynamic active_lanes with contiguous layout lowers by chaining pto.plt_b8/b16/b32 over the physical chunks: + active_i32 = arith.index_cast active_lanes : index to i32 + active_i32 = minui(maxsi(active_i32, 0), logical_lane_count) + mask0, remaining0 = pto.plt_b* active_i32 + mask1, remaining1 = pto.plt_b* remaining0 + ... + Dynamic active_lanes with deinterleaved layout remaps one logical prefix into per-part dynamic lane counts before + chaining pto.plt_b*: + active_i32 = minui(maxsi(index_cast(active_lanes), 0), logical_lane_count) + part_count(part) = (active_i32 + factor - 1 - part) / factor + then chain pto.plt_b* independently for each partition in VMI physical order: + p0 chunks..., p1 chunks..., ... + dense constant_mask lowers per physical mask part: + first map logical lanes to physical predicate lanes using the assigned VMI layout + prefix chunks emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, emit pto.plt_b8/b16/b32 with a constant i32 count + non-prefix chunks are decomposed into static active runs: + prefix(hi) = pto.pge/plt for the run end + prefix(lo) = pto.pge/plt for the run begin + run = prefix(hi) & ~prefix(lo) using pto.pnot + pto.pand + chunk = run0 | run1 | ... 
using pto.por + +Unsupported diagnostics: + unexpected residual dynamic pto.vmi.create_mask after OneToN conversion: + VMI-UNSUPPORTED: dynamic pto.vmi.create_mask active_lanes could not be lowered by the current runtime predicate + generation plan + This is a final-gate diagnostic for malformed or newly unsupported dynamic shapes. The supported dynamic + contiguous/deinterleaved=2/deinterleaved=4 paths above must lower before this residual gate. + + non-splat pto.vmi.constant: + VMI-UNSUPPORTED: non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan + + partial/tail pto.vmi.load/tile_read: + VMI-UNSUPPORTED: pto.vmi.<op> requires full physical chunks without padding lanes or a statically safe + full-read footprint (...; safe-read proof failed: ...) + GM-backed direct pto.vmi.load/masked_load/expand_load/tile_read: + VMI-UNSUPPORTED: pto.vmi.<op> ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering + emits pto.vlds/pto.vsts and requires UB-backed memory) + unsupported partial/tail pto.vmi.store/masked_store/tile_write: + VMI-UNSUPPORTED: pto.vmi.<op> 
requires an 8/16/32-bit predicate-maskable element type and either full + physical chunks or contiguous/deinterleaved tail-store materialization, with UB-backed destination; unsupported + cases include values such as f64/index that have no b64 predicate representation, GM-backed destinations that + still need a memory movement/materialization plan, and uneven deinterleaved physical groups that cannot form + complete intlv groups + + unsupported non-identity partial/tail pto.vmi.ensure_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_layout cannot materialize the requested data layout conversion; unsupported cases + include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete intlv + groups + If the helper has a single consumer, the main diagnostic is emitted on the + consumer op and operand, including both the actual operand VMI type and the + required VMI type. For example, pto.vmi.truncf operand #0 can report + `!pto.vmi.vreg<128xf32, contiguous>` vs. + `!pto.vmi.vreg<128xf32, deinterleaved=4>` for f32->fp8. The failed + pto.vmi.ensure_layout conversion is attached as a note. + + unsupported non-identity partial/tail pto.vmi.ensure_mask_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion; unsupported + cases include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete + predicate intlv groups + + unsupported pto.vmi.ensure_mask_granularity: + VMI-UNSUPPORTED: non-identity mask granularity materialization requires concrete b8/b16/b32 masks with matching + lane count and layout (...) 
+ + unsupported pto.vmi.extf direct path shape: + VMI-UNSUPPORTED: pto.vmi.extf supports only one contiguous 16-bit float-like or fp8-like physical source chunk to f32 + deinterleaved=2/4 results; partial/tail is allowed only when source padding maps to result padding + + unsupported pto.vmi.truncf direct path shape: + VMI-UNSUPPORTED: pto.vmi.truncf supports only f32 deinterleaved=2 source parts to one contiguous f16 result chunk + or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk + + unsupported pto.vmi.bitcast shape: + VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical arity and matching + per-chunk logical bit footprints (...) + + unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: + VMI-UNSUPPORTED: pto.vmi.channel_split supports only 2 or 4 channels + VMI-UNSUPPORTED: pto.vmi.channel_merge supports only 2 or 4 channels + unsupported pto.vmi.channel_split / pto.vmi.channel_merge layout: + VMI-UNSUPPORTED: pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel + layout, and every result layout to be contiguous + VMI-UNSUPPORTED: pto.vmi.channel_merge requires every input layout to be contiguous and result layout to be + contiguous or matching deinterleaved channel layout +``` + +Width conversion lowering: + +```text +f16 -> f32: + supported direct path when source is contiguous and result is deinterleaved=2: + pto.vcvt part=EVEN produces logical lanes 0,2,4,... + pto.vcvt part=ODD produces logical lanes 1,3,5,... 
+ source/result physical arity must be 1 -> 2 + +f8 -> f32: + supported direct path when source is contiguous and result is deinterleaved=4: + pto.vcvt part=P0/P1/P2/P3 produces the four modulo-4 lane partitions + source/result physical arity must be 1 -> 4 + +f32 -> f16: + supported direct path when source is deinterleaved=2 and result is contiguous: + pto.vcvt part=EVEN consumes even/source part 0 + pto.vcvt part=ODD consumes odd/source part 1 + pto.vor merges mutually exclusive f16 part results into one contiguous vreg + source/result physical arity must be 2 -> 1 + current default conversion attrs are rnd=R, sat=SAT +``` + +Memory lowering: + +```text +vmi.load: + current direct memory path first reads contiguous physical chunks. The logical lane count must be an exact multiple + of the physical vreg lane count. + For each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + dense_i = pto.vlds base[offset_i] + + If the requested VMI result layout is contiguous, return the dense chunks directly. + If the requested VMI result layout is deinterleaved=2: + prefer pto.vldsx2 "DINTLV_B8/B16/B32" per physical chunk group: + %p0_i, %p1_i = pto.vldsx2 base[offset_i], "DINTLV_B*" + return results in VMI partition-major order: + p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ... + If the requested VMI result layout is deinterleaved=4 with exactly four physical parts: + use dense pto.vlds chunks followed by the reverse two-level pto.vdintlv tree. + + For larger multi-chunk deinterleaved=4 loads, apply the same conversion per contiguous chunk group and return + physical parts in VMI partition-major order: + deinterleaved=4: p0_chunks..., p1_chunks..., p2_chunks..., p3_chunks... + +vmi.store: + direct lowering requires value element width to be 8, 16, or 32 bits so the + emitted pto.vsts/pto.vstsx2 predicate can be materialized as b8/b16/b32. 
+ contiguous layout with full physical chunks: + offset_i = base_offset + i * lanesPerPart + mask_i = pto.pset_b8/b16/b32 "PAT_ALL" + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a final partial physical chunk: + full chunks still use PAT_ALL + the final chunk computes valid_lanes = logical_lane_count - chunk_i * lanesPerPart + tail_mask_i = pto.plt_b8/b16/b32(valid_lanes) + pto.vsts tail_value_i, base[offset_i], tail_mask_i + padding lanes therefore have no externally visible store effect. + +deinterleaved store: + deinterleaved=2 with full physical chunks: + prefer pto.vstsx2 "INTLV_B8/B16/B32" per physical chunk group: + pto.vstsx2 p0_i, p1_i, base[offset_i], "INTLV_B*", all_true_mask + offset_i = base_offset + i * 2 * lanesPerPart + the vstsx2 dist mode writes logical lane 0,1,2,3,... order externally. + + current safe path lowers through proven register materialization before store: + deinterleaved=4 with exactly four physical parts: + use the two-level pto.vintlv tree, then store %d0/%d1/%d2/%d3 as contiguous chunks + + Larger multi-chunk deinterleaved=4 values use the same conversion per chunk group. The final store order is dense + chunk order, so external memory observes logical lane 0,1,2,... order. 
+vmi.masked_load: + semantics: + if mask[lane] is true, result[lane] = memory[base + lane] + if mask[lane] is false, result[lane] = passthru[lane] + inactive mask lanes do not by themselves permit unsafe memory reads + current direct path: + result, passthru, and mask are requested as contiguous + full physical chunks can always use pto.vlds because every loaded lane is logical + partial/tail chunks require the same statically safe full-read proof as vmi.load + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + non-contiguous layouts + unsafe partial/tail read footprints + target true masked/non-faulting load and guarded/scratch fallback + +vmi.gather: + semantics: + if mask[lane] is true, result[lane] = memory[base + indices[lane]] + if mask[lane] is false, result[lane] = passthru[lane] and no memory read occurs for that lane + indices are interpreted in element units, not bytes + layout assignment: + result natural layout is contiguous + indices and passthru uses are requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + source must be !pto.ptr<T, ub> + T must be a 32-bit element type + indices must be signless or unsigned i32 + result / indices / passthru / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + gathered_i = pto.vgather2_bc source, indices_i, mask_i + result_i = pto.vsel gathered_i, passthru_i, mask_i + reason for vsel: + VGATHER2_BC false predicate lanes do not read memory but produce zero; VMI false lanes preserve passthru. 
+ unsupported cases: + f16/b16/f8/i8 result element types + partial/tail chunks + non-contiguous layouts + memref/gm source + guarded/scratch fallback + +vmi.scatter: + semantics: + if mask[lane] is true, memory[base + indices[lane]] = value[lane] + if mask[lane] is false, no memory write occurs for that lane + indices are interpreted in element units, not bytes + if two active lanes have the same index, VMI logical semantics require an ordered conflict policy or an explicit + no-conflict proof before direct target lowering + layout assignment: + value and indices uses are requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct path: + op must carry {indices_unique} + destination must be !pto.ptr<T, ub> + T must be a 32-bit element type + indices must be signless or unsigned i32 + value / indices / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + pto.vscatter value_i, destination, indices_i, mask_i + reason for indices_unique: + VSCATTER false predicate lanes do not write, but duplicate active indices have target-defined/undefined grant + behavior. VMI cannot lower duplicate-index logical order semantics to VSCATTER without a proof or fallback. 
+ unsupported cases: + missing indices_unique proof + f16/b16/f8/i8 value element types + partial/tail chunks + non-contiguous layouts + memref/gm destination + ordered duplicate-index fallback + +vmi.expand_load: + semantics: + k = 0 + for lane in logical order: + if mask[lane]: + result[lane] = memory[base + k] + k += 1 + else: + result[lane] = passthru[lane] + layout assignment: + result natural layout is contiguous + passthru use is requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + static all-active path: + pto.vmi.create_mask with constant active_lanes >= logical lane count + dense all-true pto.vmi.constant_mask + in that case expand_load degenerates to ordinary vmi.load: + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = loaded_i + partial/tail chunks still require the same statically safe full-read proof as vmi.load. + runtime-mask path: + source must be !pto.ptr<T, ub> + T must be a 32-bit element type + result / passthru / mask must each be contiguous and materialize to exactly one full physical chunk + mask granularity must be b32 + base_i = pto.addptr source, offset + indices_i = pto.vusqz(zero_i32_carrier, mask_i) + loaded_i = pto.vgather2_bc base_i, indices_i, mask_i + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + runtime masks across multiple physical chunks + runtime masks on non-32-bit element types + non-contiguous layouts + unsafe partial/tail read footprints + guarded load or scratch fallback + +vmi.masked_store: + semantics: + if mask[lane] is true, store value[lane] + if mask[lane] is false, no memory write occurs for that logical lane + current full-footprint path: + value and mask are requested as contiguous at the use site + mask granularity is derived from value element width + for each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a 
final partial physical chunk: + full chunks store with the user mask directly + the final chunk computes tail_valid_i with pto.plt_b8/b16/b32(valid_lanes) + store_mask_i = pto.pand user_mask_i, tail_valid_i, all_true_mask_i + pto.vsts tail_value_i, base[offset_i], store_mask_i + padding lanes and user-inactive lanes therefore both have no write effect. + If the incoming value/mask are deinterleaved, layout assignment inserts + ensure_layout/ensure_mask_layout or the vmi-to-vpto pattern materializes the same contiguous representation before + emitting stores. This preserves logical memory order and keeps inactive lanes write-free. + +non-full chunks: + vmi.store, vmi.masked_store, and vmi.tile_write support contiguous tail chunks by predicating the final pto.vsts with + a prefix valid mask. masked_store additionally ANDs the user mask with the tail-valid mask. + deinterleaved=2/4 tail store/masked_store/tile_write is supported only through explicit layout materialization to + contiguous chunks first. This requires every deinterleaved part to have the same physical chunk count, so the + materializer can build complete vintlv/pintlv groups. After materialization, each contiguous chunk is predicated by + the logical tail-valid mask; chunks whose active logical lane count is zero are not emitted as stores. Uneven + deinterleaved groups, such as 129xf32 with deinterleaved=2, remain unsupported until a padding/scratch plan can + assemble only the observable contiguous chunks. + vmi.load and tile_read support partial/tail chunks only when the direct full physical read is statically safe: + statically shaped memref source, constant non-negative offset (or tile_read offset 0), and enough elements for the + whole physical read footprint. Padding lanes must never become observable. Other partial/tail load cases still need + scratch/guarded/true-masked load planning. 
+ +vmi.tile_read / vmi.tile_write, current direct full-footprint path: + This is not transfer_read padding lowering. It is only the tile/memref equivalent of the full-chunk direct memory + path above. + + tile_read: + source must lower to one VPTO buffer-like value. + logical lane count must be an exact multiple of the physical lanes per part, or pass the static safe full-read proof. + use offset 0 as the tile base offset. + contiguous result layout reads physical chunks with pto.vlds. + deinterleaved=2 result layout prefers pto.vldsx2 "DINTLV_B8/B16/B32" with offset 0. + other supported layouts materialize the requested result layout after contiguous reads. + + tile_write: + destination must lower to one VPTO buffer-like value. + use offset 0 as the tile base offset. + value element width must be 8, 16, or 32 bits so pto.vsts/pto.vstsx2 can receive a materialized predicate. + contiguous source layout stores every physical chunk with pto.vsts and an all-true mask. + if the final contiguous chunk is partial, store it with a prefix valid-lane mask. + deinterleaved=2 source layout prefers pto.vstsx2 "INTLV_B8/B16/B32" with offset 0. + other supported layouts materialize the source value to contiguous layout first. + deinterleaved=2/4 tail source layouts are supported through this materialization path only when every + deinterleaved part has the same physical chunk count; zero-active materialized chunks are skipped. + + Unsupported: + padding value semantics + partial/tail tile footprints other than the tail paths listed above + transfer_read-style out-of-bounds reads + user-provided write masks + non-identity tile indexing/permutation + any path that would expose padding lanes or reorder externally visible memory +``` + +Final hard gate: + +```text +no pto.vmi op remains +no !pto.vmi.* type remains, including in function signatures +no UnrealizedConversionCastOp remains +physical arity matches helper for every lowered value +``` + +Slice 4 完成条件: + +```text +1. `f16 -> f32 -> add -> store` lowers with deinterleaved=2 and stores contiguous logical order.
+ Covered by vmi_to_vpto_e2e_widen_add_store.pto. +2. `f8 -> f32 -> add -> store` lowers with deinterleaved=4 and stores contiguous logical order. + Covered by vmi_to_vpto_e2e_widen_add_store.pto. +3. Non-full memory physical arity and valid lane map are tested. + Covered by vmi_to_vpto_load_nonfull_invalid.pto, vmi_to_vpto_store_deint_invalid.pto, + vmi_to_vpto_load_safe_tail_memref.pto, + vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_masked_load_safe_tail_memref.pto, + vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_expand_load_all_active.pto, + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto, and multi-chunk load/store layout tests. +4. Full-footprint tile_read/tile_write direct path lowers through pto.vlds/pto.vsts or deinterleaved=2 x2 dist + instructions with offset 0. + Covered by vmi_to_vpto_tile_read_write.pto. +5. Internal func.call boundaries expand callee signatures, call operands/results, and returned VMI values together. + Covered by vmi_layout_assignment_call_boundary.pto, vmi_layout_assignment_indirect_call_invalid.pto, + and vmi_to_vpto_call_boundary.pto. +6. Structured control-flow carrying VMI values expands iter args, yields, results, masks, and returns together. + Covered by vmi_layout_assignment_cf_switch.pto, + vmi_layout_assignment_scf_execute_region.pto, + vmi_layout_assignment_scf_index_switch.pto, + vmi_layout_assignment_scf_while.pto, vmi_to_vpto_cf_branch.pto, + vmi_to_vpto_scf_for.pto, vmi_to_vpto_scf_if.pto, and the user-facing + vmi_ptoas_cli_control_flow.pto. +7. Final gate rejects residual VMI helper and unrealized casts. 
+ Covered by vmi_to_vpto_ensure_identity.pto, + vmi_to_vpto_ensure_layout_partial_invalid.pto, + vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto, + vmi_to_vpto_ensure_mask_layout_partial_invalid.pto, + vmi_to_vpto_unsupported_op_invalid.pto, + vmi_to_vpto_unrealized_cast_residual_invalid.pto, + vmi_to_vpto_type_attr_residual_invalid.pto, and per-feature unsupported + tests. +8. Same-family indirect memory ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_gather_scatter_shape_invalid.pto together with the existing gather/scatter positive and + per-feature negative tests. +9. Same-family reduction ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_reduce_shape_invalid.pto together with the existing reduce add/min/max positive and + per-feature negative tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto and + vmi_to_vpto_reduce_addf_f16_invalid.pto. +10. Target-specific element contracts are checked before OneToN rewriting for direct VPTO ops. + Covered by vmi_to_vpto_bf16_arith.pto, vmi_to_vpto_math_element_type_invalid.pto, + vmi_to_vpto_cmp_select.pto, vmi_to_vpto_cmp_element_type_invalid.pto, + vmi_to_vpto_fma.pto, vmi_to_vpto_fma_element_type_invalid.pto, and + vmi_to_vpto_unary_math.pto for negf/absf/absi/sqrt/exp/ln/relu, plus + vmi_to_vpto_relu_element_type_invalid.pto. +11. Same-family mask logic ops lower through the physical mask granularity instead of assuming b32 masks. + Covered by vmi_to_vpto_mask_logic.pto for mask_and/mask_or/mask_xor/mask_not on b32 masks produced by + cmpf and on direct b8/b16 mask operands. +``` + +## 7. Slice 5: Tile Memory And Padding + +The Slice 4 direct path may lower full-footprint `tile_read/tile_write` with offset 0. For partial `tile_read`, it may +also lower to plain `pto.vlds` only when the static safe-read proof above succeeds. Do not lower any other partial or +padded `tile_read` as a plain load until a richer access plan proves it is safe. 
+ +Implement an internal `VMIMemoryAccessPlan`: + +```text +base +logical lane count +logical_shape +permutation_map +lane-to-address map in element units +validMask +paddingValue +safeReadProof +writeMask +target capability decision +fallback resource decision +``` + +Current implementation status: + +```text +lib/PTO/Transforms/VMIToVPTO.cpp + VMIMemoryAccessPlan + VMIMemorySafeReadProof + VMIMemoryLogicalShape + VMIMemoryLaneAddressMap + VMIMemoryFallbackDecision + +currently routed through the plan: + contiguous identity logical_shape/permutation/lane-to-address map in element units + explicit rejection of non-identity memref layouts until subview/affine lane maps are represented + covered by vmi_to_vpto_memref_layout_invalid.pto, including a memref.subview-produced strided view + subview diagnostics name the missing normalized base/offset/stride lane-to-address plan + target true masked/non-faulting load capability query + current result is missing capability because pto.vlds has no mask operand + covered by vmi_to_vpto_masked_load_nonfull_invalid.pto + stable gather masked-load option + covered by vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto + currently emits a TODO diagnostic instead of lowering through VGATHER2 + direct pto.vmi.load partial/tail safe full-read proof + pto.vmi.masked_load partial/tail safe full-read proof + pto.vmi.expand_load static all-active safe full-read proof + VMI-to-VPTO rewrite match guard for load/tile_read full-or-safe reads + pto.vmi.store/tile_write direct write target decision with all-true writeMask kind + pto.vmi.masked_store direct write target decision with explicit writeMask kind + unsafe partial/tail read fallback decision as RequiredUnavailable diagnostic + covered by vmi_to_vpto_load_nonfull_invalid.pto, + vmi_to_vpto_masked_load_nonfull_invalid.pto, and + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto + +currently not implemented by the plan: + paddingValue materialization (intentionally 
unsupported in the first implementation stage) + non-all-true validMask direct masked/non-faulting load lowering + scratch/guarded fallback lowering or allocation + lowering for non-identity logical_shape/permutation_map/lane-to-address maps, including subview or affine lane maps + writeMask fallback planning beyond the existing contiguous tail-store predicate path +``` + +Important first-stage contract: + +```text +VMI physical tail lanes and transfer paddingValue are different concepts. + +Physical tail lanes: + arise because pto.vreg is fixed at 256 bytes + are outside the logical VMI lane count + may be read/computed only when the extra lanes remain unobservable + +transfer_read-style paddingValue: + is an observable logical result for invalid/OOB transfer lanes + cannot be dropped or replaced by arbitrary physical tail contents + is not materialized by the first-stage VMI implementation + +Therefore any frontend path that still needs transfer_read paddingValue +semantics must stop before direct VMI-to-VPTO lowering with VMI-UNSUPPORTED, +unless it has already canonicalized to an all-valid load/masked_load subset +whose invalid lanes are proven absent. 
+``` + +`tile_read` decision tree: + +```text +safeReadProof full && validMask all true: + direct load + +safeReadProof full && validMask not all true: + first-stage: VMI-UNSUPPORTED because paddingValue materialization is not implemented + future: full load + padding materialization + select + +target true masked/non-faulting load: + first-stage: VMI-UNSUPPORTED because true masked/non-faulting load and paddingValue materialization are not implemented + future: masked load + padding materialization + +otherwise: + first-stage: VMI-UNSUPPORTED with the missing fallback reason + future: split safe regions, scratch fill/copy/load, guarded fallback, or diagnostic +``` + +`tile_write` decision tree: + +```text +writeMask all true && full footprint safe-writable: + direct store + +target true masked store: + masked store + +otherwise: + split/guarded/scatter-like fallback or diagnostic +``` + +Slice 5 完成条件: + +```text +1. Unsafe partial/tail read-like ops never lower to a potentially invalid full + read unless the physical footprint is statically proven safe. +2. PaddingValue materialization is not required in the first implementation + stage. Any path that would require paddingValue, true masked/non-faulting + load, scratch fill/copy/load, or guarded fallback must report + `VMI-UNSUPPORTED` with the missing fallback reason. +3. Non-identity logical_shape/permutation_map/lane-to-address maps, including + subview or affine lane maps, are explicitly rejected before lowering. +4. Store-like partial/tail writes are supported only by the existing + full-chunk or contiguous/deinterleaved tail-store predicate paths. Other + writeMask fallback paths must report `VMI-UNSUPPORTED`. +``` + +## 8. 
Target Capability Registry + +Add one explicit registry object, passed into layout assignment and VMI-to-VPTO: + +```text +supportsElementType(type, purpose) +getNaturalLayout(op) +supportsLayoutConversion(srcLayout, dstLayout, elementType) +getLayoutMaterializationPlan(srcLayout, dstLayout, elementType) +supportsMaskGranularityConversion(srcG, dstG) +supportsMemoryAccessPlan(plan) +supportsPrefixPopcount(maskType) +supportsReductionScanContract(op) +getScratchResource(plan) +``` + +The registry returns structured results: + +```text +supported +unsupported_missing_capability +unsupported_disabled_by_option +unsupported_resource +``` + +Diagnostics must expose that reason. A pass must not silently choose scalar fallback when fallback is disabled. + +Current implementation status: + +```text +include/PTO/Transforms/VMITargetCapabilities.h + VMITargetCapabilityRegistry + VMICapabilityResult { status, reason } + +currently routed through the registry: + element-type purpose checks for predicate-maskable vregs and direct elementwise/cmp/fma/relu VPTO lowering + reduction-family element-type contracts for reduce_addi/reduce_addf/reduce_maxf/reduce_minf + direct pto.vlds/vsts memory source/destination support + missing target true masked/non-faulting load capability for the current pto.vlds surface + pointer-only UB memory support for pto.vgather2_bc/pto.vscatter/pto.vstur based VMI paths + supported source/result layout conversion pairs + supported b8/b16/b32 mask granularity conversion pairs + pto.vmi.channel_split/channel_merge supported channel count + +still legacy helper-based and should migrate into the registry as follow-up: + full layout materialization plans and padding-safety checks + adjacent ppack/punpack mask granularity materialization plans + prefix popcount and full reduction/scan/contract shape capability checks +``` + +## 9. 
Diagnostics + +Centralize diagnostic codes in one header or utility file: + +```text +VMI-UNSUPPORTED +VMI-LAYOUT-CONTRACT +VMI-PASS-INVARIANT +VMI-RESIDUAL-OP +``` + +Current implementation defines these codes and their `": "` prefixes in `include/PTO/IR/VMIUtils.h`. Transform and +CLI code must reference those constants instead of spelling the diagnostic code strings locally; a source grep for the +four code strings should find only the central definitions. + +Every diagnostic should include: + +```text +source op +logical VMI type +producer natural layout, if any +consumer required layout, if any +missing capability or disabled option +available materialization paths, if known +``` + +## 10. Lit Test Layout + +Use a dedicated directory: + +```text +test/lit/vmi/ +``` + +Minimum test files: + +```text +vmi_type_attr_parse.mlir +vmi_type_attr_invalid.mlir +vmi_op_verifier_basic.mlir +vmi_producer_boundary.mlir +vmi_layout_assignment_widen.mlir +vmi_layout_assignment_cfg.mlir +vmi_layout_assignment_broadcast_remat.mlir +vmi_layout_assignment_iota_remat.mlir +vmi_layout_assignment_mask_remat.mlir +vmi_to_vpto_deinterleaved2.mlir +vmi_to_vpto_deinterleaved4.mlir +vmi_to_vpto_compaction_deint_invalid.mlir +vmi_to_vpto_non_full_tile.mlir +vmi_tile_read_padding.mlir +vmi_tile_write_mask.mlir +vmi_pipeline_hard_gates.mlir +``` + +Each pass test must use `FileCheck` to prove both positive output and negative absence: + +```text +CHECK: pto.vmi.addf +CHECK-NOT: pto.vadd +CHECK-NOT: unrealized_conversion_cast +``` + +Final lowering tests must check: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +## 11. Implementation Order + +Recommended merge order: + +```text +1. VMI type/attr + helper + parse/verify tests. +2. Slice 1 op shells + verifier tests. +3. VMI producer boundary verifier. +4. layout assignment for straight-line code. +5. layout assignment for scf/cf/function boundaries. +6. 
vmi-to-vpto type conversion + pack/unpack/unpackable block args. +7. deinterleaved=2 f16 widen end-to-end. +8. deinterleaved=4 f8 widen end-to-end. +9. tile_read/tile_write padding-safe lowering. +10. remaining semantic op families. +``` + +Do not merge a pass that leaves hidden side tables as a required interpretation mechanism. Temporary internal +analysis structures are fine only if the pass materializes the final state into IR before returning. + +## 12. Review Checklist Before Coding Each Slice + +Before implementation: + +```text +1. Is the op/type syntax written in ODS and tested by parser round-trip? +2. Does every verifier rule have a negative test? +3. Does every pass have a post-pass hard gate? +4. Are CFG block arguments and function signatures covered? +5. Does any lowering rely on a defining op that block arguments do not have? +6. Does memory lowering prove safe footprint separately from valid lane mask? +7. Does mask granularity follow consumer element width? +8. Does final VPTO lowering leave zero VMI op/type/helper or unrealized-cast residuals? +``` + +If any answer is no, the slice is not ready to be treated as complete. + +## 13. Adding One VMI Op End To End + +新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的六个落点, +否则很容易出现 verifier 能过、layout pass 不知道怎么约束、或控制流 physicalization 后残留 VMI type。 + +```text +1. ODS surface: + include/PTO/IR/VMIOps.td + +2. semantic verifier: + lib/PTO/IR/VMI.cpp + +3. layout facts: + lib/PTO/Transforms/VMILayoutAssignment.cpp + +4. vmi-to-vpto preflight: + lib/PTO/Transforms/VMIToVPTO.cpp::verifySupportedVMIToVPTOOps + +5. OneToN lowering pattern: + lib/PTO/Transforms/VMIToVPTO.cpp::populateVMIOneToNConversionPatterns + +6. 
focused lit tests: + test/lit/vmi/ +``` + +这六个落点的职责不同: + +```text +ODS: + 只定义 op 形状、operand/result type 类别、assembly format、interface 和 verifier hook。 + +VMI.cpp verifier: + 检查局部语义,例如元素类型、rank、lane count、predicate 字符串、source/result bit 数关系。 + 不能依赖 def-use 图,不能决定 layout。 + +LayoutAssignment: + 只收集 value-level layout/granularity 事实: + - producer natural layout + - operands that must share layout with result + - consumer required layout + - mask consumer required granularity + 不能在 collect 阶段改 IR。 + +VMIToVPTO preflight: + 在 rewrite 前拒绝当前 lowering 不支持但语义合法的 case。 + 典型例子是 partial physical chunk、non-prefix mask constant、dynamic create_mask、unsupported shuffle。 + +OneToN pattern: + 从 adaptor 读取 physical parts,按已经确定的 layout 发 VPTO op。 + 不能重新推断 layout,也不能通过 defining op 找 physical parts。 + +lit: + 至少覆盖 parser/verify、layout assignment、positive lowering、negative unsupported diagnostic。 +``` + +### Layout Fact Template + +新增 op 时先给它归类,再写 layout 约束。不要从 VPTO 指令形态反推 VMI layout;layout 的来源必须是 +logical vector 语义和当前物理指令的天然限制。 + +```text +elementwise same-shape op: + examples: + addf/addi/subf/mulf/andi/shli/shrui/absf/absi/sqrt + layout rule: + all data operands and result are in one equivalence class + lowering rule: + emit one VPTO op per physical part + +compare op: + examples: + cmpf/cmpi + layout rule: + lhs/rhs data layout unified + result mask requested to the same data layout + result mask granularity comes from lhs/rhs element width + lowering rule: + emit one vcmp per data part, producing corresponding mask part + +mask logical op: + examples: + mask_and/mask_or/mask_xor/mask_not + layout rule: + all mask operands/results share layout and granularity + lowering rule: + emit one predicate op per physical mask part + +layout-changing producer: + examples: + extf f16->f32, extf f8->f32, truncf f32->f16, truncf f32->fp8-like + layout rule: + source/request side follows instruction input contract + result natural layout follows instruction output contract + lowering rule: + emit 
the instruction sequence that preserves logical lane order under that layout + +memory consumer/producer: + examples: + load/store/tile_read/tile_write + layout rule: + load/tile_read result natural layout is chosen by memory dist capability + store/tile_write value operand requests the layout that memory dist can consume + lowering rule: + direct path only when every physical chunk has no padding lane and footprint is safe + +structural boundary: + examples: + scf.if result/yield, scf.for iter args, cf.br successor operands, func.call + layout rule: + semantically identical incoming/outgoing values are unified + lowering rule: + handled by OneToN structural patterns, not by op semantic lowering +``` + +代码里 `LayoutSolver::addConstraints()` 应该只表达上面的事实。例如一个普通 elementwise binary op +只需要: + +```cpp +if (auto addf = dyn_cast<VMIAddFOp>(op)) { + if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || + failed(unite(addf.getLhs(), addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +一个 layout-changing op 不应该把 source/result 直接 `unite`,而是明确写 producer/consumer 合同: + +```cpp +if (auto extf = dyn_cast<VMIExtFOp>(op)) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, factor), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +### OneToN Pattern Template + +`vmi-to-vpto` pattern 的输入不再是 logical VMI value,而是 adaptor 里已经 flatten 好的 physical parts。 +pattern 只做三件事: + +```text +1. 从 adaptor 取每个 logical operand 的 physical part list。 +2. 从 resultMapping 取每个 logical result 对应的 physical result type list。 +3.
按 part 顺序创建 VPTO op,并用 resultMapping replace 原 op。 +``` + +普通 elementwise binary op 的代码形态应该接近: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhsParts.size() != rhsParts.size() || lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector<Value> results; + for (auto [lhs, rhs, resultType] : llvm::zip_equal(lhsParts, rhsParts, resultTypes)) + results.push_back(rewriter.create<pto::VaddOp>(op.getLoc(), resultType, lhs, rhs)); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这里不能调用 `op.getLhs().getDefiningOp()` 去找物理寄存器。原因是 VMI value 可以来自: + +```text +function argument +block argument +scf.for iter arg +scf.if result +cf.br successor argument +func.call result +``` + +这些 value 很多没有 VMI defining op。physical parts 的唯一合法来源是 OneToN adaptor 和 +OneToNTypeMapping。 + +### Control-Flow Checklist + +每新增一个 op,不一定要写新的控制流 pattern;但必须检查它的结果或 operand 是否可能跨边界。 +如果只是普通 VMI value,那么已有 structural OneToN pattern 应该负责边界 physicalization: + +```text +func.func / func.call / func.return: + upstream func OneToN conversion + +scf.if / scf.for / scf.while / scf.yield: + upstream SCF OneToN structural conversion plus layout solver equivalence constraints + +cf.br / cf.cond_br / cf.switch: + project-local OneToN patterns flatten successor operands and rewrite destination block signatures + +scf.execute_region / scf.index_switch: + project-local OneToN patterns flatten region results +``` + +新增 op 的测试要至少放一个跨边界用例,证明 op 的 result 不是只在 straight-line IR 中工作: + +```mlir +%r = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.addf %a, %b : ...
-> !pto.vmi.vreg<128xf32> + scf.yield %x : !pto.vmi.vreg<128xf32> +} else { + scf.yield %c : !pto.vmi.vreg<128xf32> +} +pto.vmi.store %r, %ptr, %off : ... +``` + +对应 lowering test 必须检查: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: !pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +如果这个测试失败,通常不是该 op 的 VPTO pattern 本身错,而是 layout assignment 没有把 yield/result/consumer +约束统一,或者 OneToN structural pattern 漏了某种 region/control-flow op。 + +### Preflight Versus Pattern Failure + +语义合法但当前还没有物理实现的 case,应该在 `verifySupportedVMIToVPTOOps()` 里给稳定 diagnostic, +不要让 pattern 随机 `notifyMatchFailure()` 后落成 generic conversion failure。 + +```text +use verifier failure: + op 本身语义非法,任何 target 都不应该接受。 + examples: + absf on integer element + shrui on signed integer element + bitcast total bits mismatch + +use VMI-LAYOUT-CONTRACT: + 多个 producer/consumer/control-flow 约束互相冲突。 + examples: + one value simultaneously required as contiguous and deinterleaved=2 + one mask simultaneously required as b16 and b32 + +use VMI-UNSUPPORTED in preflight: + VMI semantics are valid, but current VPTO materialization is not implemented. + examples: + partial/tail memory access + pred-only constant mask without concrete b8/b16/b32 granularity + shuffle that requires vselr index-vector materialization + bitcast across partial physical chunks + +use VMI-RESIDUAL-OP: + conversion framework finished but VMI op/type/helper/cast remains. + This is a pass bug or missing pattern, not a user semantic error. +``` + +Pattern-local `notifyMatchFailure()` is still useful for debugging competing patterns, but it must not be the only +user-visible explanation for a known unsupported VMI semantic case. 
diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 429b5e232e..ed8166ad12 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -38,6 +38,8 @@ class PTO_Attr traits = []> let mnemonic = attrMnemonic; } +include "PTO/IR/VMIAttrs.td" + //===----------------------------------------------------------------------===// // Address Space //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 97d6ff6a9b..ce2a780edd 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -76,6 +76,7 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VMIOps.td" include "PTO/IR/VPTOOps.td" //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 5fbe9d8d45..a6ac0ad106 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -338,4 +338,5 @@ def F4E2M1x2Type : TypeDef { + let summary = "VMI logical vector register layout"; + let parameters = (ins + StringRefParameter<"layout kind">:$kind, + "int64_t":$factor + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); + static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, + int64_t factor); + + bool isContiguous() const { return getKind() == "contiguous"; } + bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIATTRS diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td new file mode 100644 index 0000000000..6f567bb8a5 --- /dev/null +++ b/include/PTO/IR/VMIOps.td @@ -0,0 +1,562 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMIOps.td - PTO VMI semantic operations -------------*- tablegen -*-===// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VMIOPS +#define MLIR_DIALECT_PTO_IR_VMIOPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def VMI_VRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType>($_self)">, + "VMI logical vector register type">; + +def VMI_MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIMaskType>($_self)">, + "VMI logical mask type">; + +def VMI_ValueTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType, ::mlir::pto::VMIMaskType>($_self)">, + "VMI logical vector or mask type">; + +def PTO_PhysicalVRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VRegType>($_self)">, + "PTO physical vector register type">; + +def PTO_PhysicalMaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self)">, + "PTO physical mask type">; + +def PTO_PhysicalVMIPartTypeConstraint : AnyTypeOf< + [PTO_PhysicalVRegTypeConstraint, PTO_PhysicalMaskTypeConstraint], + "PTO physical vector register or mask type">; + +class VMI_Op traits = []> + : PTO_Op<"vmi." 
# mnemonic, traits>; + +def VMIConstantOp : VMI_Op<"constant"> { + let summary = "VMI logical vector constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIBroadcastOp : VMI_Op<"broadcast"> { + let summary = "Broadcast one scalar or rank-0 VMI vector to a VMI logical vector"; + let arguments = (ins AnyType:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)"; +} + +def VMIIotaOp : VMI_Op<"iota"> { + let summary = "Create a VMI logical index vector from a scalar base"; + let arguments = (ins + AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$base, + OptionalAttr:$order + ); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$base attr-dict `:` type($base) `->` type($result)"; +} + +def VMICreateMaskOp : VMI_Op<"create_mask"> { + let summary = "Create a VMI logical prefix predicate mask"; + let arguments = (ins Index:$active_lanes); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; +} + +def VMIConstantMaskOp : VMI_Op<"constant_mask"> { + let summary = "VMI logical predicate mask constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIMaskAndOp : VMI_Op<"mask_and"> { + let summary = "VMI logical predicate mask and"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskOrOp : VMI_Op<"mask_or"> { + let summary = "VMI logical predicate mask or"; + let arguments = (ins 
VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskXOrOp : VMI_Op<"mask_xor"> { + let summary = "VMI logical predicate mask xor"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskNotOp : VMI_Op<"mask_not"> { + let summary = "VMI logical predicate mask not"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAddFOp : VMI_Op<"addf"> { + let summary = "VMI floating-point elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIAddIOp : VMI_Op<"addi"> { + let summary = "VMI integer elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubFOp : VMI_Op<"subf"> { + let summary = "VMI floating-point elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubIOp : VMI_Op<"subi"> { + 
let summary = "VMI integer elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulFOp : VMI_Op<"mulf"> { + let summary = "VMI floating-point elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulIOp : VMI_Op<"muli"> { + let summary = "VMI integer elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIFmaOp : VMI_Op<"fma"> { + let summary = "VMI fused floating-point multiply-add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs, + VMI_VRegTypeConstraint:$acc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)"; +} + +def VMIDivFOp : VMI_Op<"divf"> { + let summary = "VMI floating-point elementwise divide"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMinFOp : VMI_Op<"minf"> { + let summary = "VMI floating-point elementwise minimum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs 
VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaxFOp : VMI_Op<"maxf"> { + let summary = "VMI floating-point elementwise maximum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINegFOp : VMI_Op<"negf"> { + let summary = "VMI floating-point elementwise negate"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsFOp : VMI_Op<"absf"> { + let summary = "VMI floating-point elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsIOp : VMI_Op<"absi"> { + let summary = "VMI integer elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMISqrtOp : VMI_Op<"sqrt"> { + let summary = "VMI floating-point elementwise square root"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExpOp : VMI_Op<"exp"> { + let summary = "VMI floating-point elementwise exponential"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 
1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILnOp : VMI_Op<"ln"> { + let summary = "VMI floating-point elementwise natural logarithm"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIReluOp : VMI_Op<"relu"> { + let summary = "VMI floating-point elementwise ReLU"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAndIOp : VMI_Op<"andi"> { + let summary = "VMI integer elementwise bitwise and"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIOrIOp : VMI_Op<"ori"> { + let summary = "VMI integer elementwise bitwise or"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIXOrIOp : VMI_Op<"xori"> { + let summary = "VMI integer elementwise bitwise xor"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShLIOp : VMI_Op<"shli"> { + let summary = "VMI integer elementwise left shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + 
let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShRUIOp : VMI_Op<"shrui"> { + let summary = "VMI unsigned integer elementwise right shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINotOp : VMI_Op<"not"> { + let summary = "VMI integer elementwise bitwise not"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMICmpFOp : VMI_Op<"cmpf"> { + let summary = "VMI floating-point elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMICmpIOp : VMI_Op<"cmpi"> { + let summary = "VMI integer elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISelectOp : VMI_Op<"select"> { + let summary = "VMI elementwise select"; + let arguments = (ins VMI_MaskTypeConstraint:$mask, VMI_VRegTypeConstraint:$true_value, + VMI_VRegTypeConstraint:$false_value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask `,` $true_value `,` $false_value attr-dict `:` type($mask) `,` type($true_value) `,` type($false_value) `->` 
type($result)"; +} + +def VMIActivePrefixIndexOp : VMI_Op<"active_prefix_index"> { + let summary = "VMI per-lane active-prefix index from a predicate mask"; + let arguments = (ins VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask attr-dict `:` type($mask) `->` type($result)"; +} + +def VMICompressOp : VMI_Op<"compress"> { + let summary = "VMI compact active source lanes according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMICompressStoreOp : VMI_Op<"compress_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI store active source lanes contiguously according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIReduceAddIOp : VMI_Op<"reduce_addi"> { + let summary = "VMI masked integer add reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceAddFOp : VMI_Op<"reduce_addf"> { + let summary = "VMI masked floating-point add reduction with explicit reassociation permission"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask, + 
OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMaxFOp : VMI_Op<"reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMinFOp : VMI_Op<"reduce_minf"> { + let summary = "VMI masked floating-point minimum reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIExtFOp : VMI_Op<"extf"> { + let summary = "VMI floating-point elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncFOp : VMI_Op<"truncf"> { + let summary = "VMI floating-point elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIBitcastOp : VMI_Op<"bitcast"> { + let summary = "VMI bitwise vector reinterpretation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs 
VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector load with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIGatherOp : VMI_Op<"gather", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed gather with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($indices) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIExpandLoadOp : VMI_Op<"expand_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load a dense active-lane stream into masked logical lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru 
attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, Index:$offset); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed scatter"; + let arguments = (ins VMI_VRegTypeConstraint:$value, + PtrOrMemRef:$destination, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask, + OptionalAttr:$indices_unique); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $indices `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($indices) `,` type($mask)"; +} + +def VMITileReadOp : VMI_Op<"tile_read", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical tile read"; + let arguments = (ins AnyType:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITileWriteOp : VMI_Op<"tile_write", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical tile write"; + let arguments = (ins VMI_VRegTypeConstraint:$value, AnyType:$destination); + let results = (outs); + let hasVerifier = 1; + let 
assemblyFormat = "$value `,` $destination attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIShuffleOp : VMI_Op<"shuffle"> { + let summary = "VMI static lane shuffle"; + let arguments = (ins VMI_VRegTypeConstraint:$source, DenseI64ArrayAttr:$indices); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIChannelSplitOp : VMI_Op<"channel_split"> { + let summary = "VMI split interleaved logical channels"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs Variadic:$results); + let hasVerifier = 1; +} + +def VMIChannelMergeOp : VMI_Op<"channel_merge"> { + let summary = "VMI merge logical channels by interleaving"; + let arguments = (ins Variadic:$inputs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIEnsureLayoutOp : VMI_Op<"ensure_layout"> { + let summary = "Internal VMI data layout materialization helper"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout"> { + let summary = "Internal VMI mask layout materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskGranularityOp : VMI_Op<"ensure_mask_granularity"> { + let summary = "Internal VMI mask granularity materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIUnpackOp : 
VMI_Op<"unpack"> { + let summary = "Internal VMI value projection to physical parts"; + let arguments = (ins VMI_ValueTypeConstraint:$source); + let results = (outs Variadic:$parts); + let hasVerifier = 1; +} + +def VMIPackOp : VMI_Op<"pack"> { + let summary = "Internal physical parts materialized as one VMI value"; + let arguments = (ins Variadic:$parts); + let results = (outs VMI_ValueTypeConstraint:$result); + let hasVerifier = 1; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIOPS diff --git a/include/PTO/IR/VMITypeDefs.td b/include/PTO/IR/VMITypeDefs.td new file mode 100644 index 0000000000..4ec6bb5009 --- /dev/null +++ b/include/PTO/IR/VMITypeDefs.td @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+
+//===- VMITypeDefs.td - PTO VMI type definitions -----------*- tablegen -*-===//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTO_IR_VMITYPEDEFS
+#define MLIR_DIALECT_PTO_IR_VMITYPEDEFS
+
+include "PTO/IR/PTODialect.td"
+include "PTO/IR/PTOAttrs.td"
+
+def VMIVRegType : TypeDef<PTO_Dialect, "VMIVReg"> { // NOTE(review): TypeDef arguments were missing in this copy; dialect def name and type name are reconstructed - confirm against PTODialect.td.
+  let mnemonic = "vmi.vreg";
+  let summary = "A VMI logical vector register value";
+
+  let parameters = (ins
+    "int64_t":$elementCount,
+    "Type":$elementType,
+    "mlir::Attribute":$layout // Generic attribute; hasLayout()/getLayoutAttr() below treat a null value as "no layout".
+  );
+
+  let hasCustomAssemblyFormat = 1; // Parser/printer are hand-written, not derived from an assemblyFormat string.
+  let genVerifyDecl = 1; // Requests a verify() declaration whose definition lives in C++.
+
+  let extraClassDeclaration = [{
+    bool hasLayout() const { return static_cast<bool>(getLayout()); } // True when the layout parameter is non-null.
+    VMILayoutAttr getLayoutAttr() const { // Null when the layout is unset or is not a VMILayoutAttr.
+      return ::llvm::dyn_cast_or_null<VMILayoutAttr>(getLayout());
+    }
+  }];
+}
+
+def VMIMaskType : TypeDef<PTO_Dialect, "VMIMask"> { // NOTE(review): TypeDef arguments reconstructed, same caveat as VMIVRegType.
+  let mnemonic = "vmi.mask";
+  let summary = "A VMI logical predicate mask value";
+
+  let parameters = (ins
+    "int64_t":$elementCount,
+    StringRefParameter<"mask granularity view">:$granularity, // Compared below against "pred", "b8", "b16" and "b32".
+    "mlir::Attribute":$layout // Generic attribute; a null value means "no layout".
+  );
+
+  let hasCustomAssemblyFormat = 1; // Parser/printer are hand-written.
+  let genVerifyDecl = 1; // Requests a verify() declaration whose definition lives in C++.
+
+  let extraClassDeclaration = [{
+    static bool isSupportedGranularity(::llvm::StringRef granularity); // Declared only; defined out of line.
+    static bool isConcreteGranularity(::llvm::StringRef granularity); // Declared only; defined out of line.
+
+    bool hasLayout() const { return static_cast<bool>(getLayout()); } // True when the layout parameter is non-null.
+    bool isPred() const { return getGranularity() == "pred"; }
+    bool isB8() const { return getGranularity() == "b8"; }
+    bool isB16() const { return getGranularity() == "b16"; }
+    bool isB32() const { return getGranularity() == "b32"; }
+    VMILayoutAttr getLayoutAttr() const { // Null when the layout is unset or is not a VMILayoutAttr.
+      return ::llvm::dyn_cast_or_null<VMILayoutAttr>(getLayout());
+    }
+  }];
+}
+
+#endif // MLIR_DIALECT_PTO_IR_VMITYPEDEFS
diff --git a/include/PTO/IR/VMIUtils.h b/include/PTO/IR/VMIUtils.h
new file mode 100644
index 0000000000..e55e558034
--- /dev/null
+++ b/include/PTO/IR/VMIUtils.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+//===- VMIUtils.h - PTO VMI shared helpers ----------------------*- C++ -*-===//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTO_IR_VMIUTILS_H
+#define PTO_IR_VMIUTILS_H
+
+#include "PTO/IR/PTO.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir::pto {
+
+inline constexpr StringLiteral kVMIDiagUnsupported = "VMI-UNSUPPORTED"; // Bare diagnostic tags (no separator).
+inline constexpr StringLiteral kVMIDiagLayoutContract =
+    "VMI-LAYOUT-CONTRACT";
+inline constexpr StringLiteral kVMIDiagPassInvariant = "VMI-PASS-INVARIANT";
+inline constexpr StringLiteral kVMIDiagResidualOp = "VMI-RESIDUAL-OP";
+
+inline constexpr StringLiteral kVMIDiagUnsupportedPrefix = // Same tags followed by ": ", ready to prepend to a message.
+    "VMI-UNSUPPORTED: ";
+inline constexpr StringLiteral kVMIDiagLayoutContractPrefix =
+    "VMI-LAYOUT-CONTRACT: ";
+inline constexpr StringLiteral kVMIDiagPassInvariantPrefix =
+    "VMI-PASS-INVARIANT: ";
+inline constexpr StringLiteral kVMIDiagResidualOpPrefix = "VMI-RESIDUAL-OP: ";
+
+struct VMIPhysicalLane { // (part, chunk, lane) coordinate triple; every field defaults to 0.
+  int64_t part = 0;
+  int64_t chunk = 0;
+  int64_t lane = 0;
+};
+
+FailureOr<int64_t> getDataLanesPerPart(Type elementType); // NOTE(review): every FailureOr payload type in this header was missing in this copy and is reconstructed from names/parameters - confirm against lib/PTO/IR/VMI.cpp.
+FailureOr<int64_t> getMaskLanesPerPart(StringRef granularity);
+FailureOr<int64_t> getVMIPhysicalArity(Type type);
+FailureOr<VMIPhysicalLane> mapLogicalLaneToPhysical(Type type, // presumably the inverse of mapPhysicalLaneToLogical - verify.
+                                                    int64_t logicalLane);
+FailureOr<int64_t> mapPhysicalLaneToLogical(Type type, int64_t part, // takes the three VMIPhysicalLane fields as separate arguments.
+                                            int64_t chunk, int64_t lane);
+FailureOr<bool> isPaddingLane(Type type, int64_t part, int64_t chunk, // NOTE(review): FailureOr payload type was missing in this copy; bool is inferred from the name - confirm.
+                              int64_t lane);
+
+} // namespace mlir::pto
+
+#endif // PTO_IR_VMIUTILS_H
diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h
index 59ad36c932..15a247594b 100644
--- a/include/PTO/Transforms/Passes.h
+++ b/include/PTO/Transforms/Passes.h
@@ -100,6 +100,14 @@ LogicalResult validateVPTOEmissionIR(ModuleOp module,
                                      llvm::raw_ostream *diagOS = nullptr);
 std::unique_ptr<Pass> createPTOValidateVPTOIRPass();
 std::unique_ptr<Pass> createPTOValidateVPTOEmissionIRPass();
+LogicalResult validateVMIProducerBoundaryIR(ModuleOp module, // Same signature as validateVPTOEmissionIR above; diagOS is optional.
+                                            llvm::raw_ostream *diagOS = nullptr);
+LogicalResult validateVMILayoutAssignedIR(ModuleOp module, // Same signature; diagOS is optional.
+                                          llvm::raw_ostream *diagOS = nullptr);
+std::unique_ptr<Pass> createPTOValidateVMIIRPass(); // NOTE(review): the unique_ptr element type was missing in this copy; Pass is reconstructed - confirm. Named as constructor of PTOValidateVMIIR in Passes.td.
+std::unique_ptr<Pass> createPTOValidateVMILayoutIRPass(); // Constructor of PTOValidateVMILayoutIR in Passes.td.
+std::unique_ptr<Pass> createVMILayoutAssignmentPass(); // Constructor of VMILayoutAssignment in Passes.td.
+std::unique_ptr<Pass> createVMIToVPTOPass(); // Constructor of VMIToVPTO in Passes.td.
 std::unique_ptr<Pass> createExpandTileOpPass();
 std::unique_ptr<Pass> createExpandTileOpPass(const ExpandTileOpOptions &options);
 std::unique_ptr<Pass> createFoldTileBufIntrinsicsPass();
diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td
index a897034d15..25ec3324b9 100644
--- a/include/PTO/Transforms/Passes.td
+++ b/include/PTO/Transforms/Passes.td
@@ -629,6 +629,77 @@ def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> {
                            "mlir::scf::SCFDialect"];
 }
 
+def PTOValidateVMIIR : Pass<"pto-validate-vmi-ir", "ModuleOp"> { // Pre-layout-assignment validator.
+  let summary = "Validate VMI producer-boundary semantic IR";
+  let description = [{
+    Checks that VMI producer-boundary IR uses only surface VMI data/mask types,
+    native pto.vmi semantic ops, and structural control-flow/function ops. This
+    pass runs before layout assignment, so layout-assigned VMI types, VMI helper
+    ops, and physical VPTO register types are rejected.
+  }];
+  let constructor = "mlir::pto::createPTOValidateVMIIRPass()"; // Declared in PTO/Transforms/Passes.h.
+  let dependentDialects = ["mlir::cf::ControlFlowDialect",
+                           "mlir::func::FuncDialect",
+                           "mlir::pto::PTODialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::scf::SCFDialect"];
+}
+
+def PTOValidateVMILayoutIR // Post-layout-assignment validator.
+    : Pass<"pto-validate-vmi-layout-ir", "ModuleOp"> {
+  let summary = "Validate layout-assigned VMI IR";
+  let description = [{
+    Checks the post-layout-assignment VMI stage: every VMI data value must have
+    a concrete VMI layout, every VMI mask must have concrete b8/b16/b32
+    granularity and layout, physical VPTO register values must not appear yet,
+    and VMI typed values must stay inside VMI semantic/helper or structural ops.
+  }];
+  let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; // Declared in PTO/Transforms/Passes.h.
+  let dependentDialects = ["mlir::cf::ControlFlowDialect",
+                           "mlir::func::FuncDialect",
+                           "mlir::pto::PTODialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::scf::SCFDialect"];
+}
+
+def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { // Surface VMI IR -> layout-assigned VMI IR.
+  let summary = "Assign concrete VMI layouts and mask granularities";
+  let description = [{
+    Solves VMI layout constraints and materializes the chosen layout and mask
+    granularity into VMI types. This pass is the boundary between surface VMI
+    semantic IR and layout-assigned VMI IR.
+  }];
+  let constructor = "mlir::pto::createVMILayoutAssignmentPass()"; // Declared in PTO/Transforms/Passes.h.
+  let dependentDialects = ["mlir::cf::ControlFlowDialect",
+                           "mlir::func::FuncDialect",
+                           "mlir::pto::PTODialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::scf::SCFDialect"];
+}
+
+def VMIToVPTO : Pass<"vmi-to-vpto", "ModuleOp"> { // Layout-assigned VMI IR -> physical VPTO IR.
+  let summary = "Convert layout-assigned VMI IR to physical VPTO IR";
+  let description = [{
+    Converts layout-assigned VMI aggregate data/mask types to ordered physical
+    VPTO register and mask value lists using MLIR OneToNTypeConversion. This
+    pass is responsible for VMI 1:N type conversion, structural control-flow
+    and function/call signature conversion, and VMI semantic op physicalization.
+  }];
+  let constructor = "mlir::pto::createVMIToVPTOPass()"; // Declared in PTO/Transforms/Passes.h.
+  let options = [
+    Option<"enableStableGatherMaskedLoad", // C++ member name of the option.
+           "enable-stable-gather-masked-load", "bool", // Command-line flag and value type.
+           /*default=*/"false",
+           "Reserve the stable VGATHER-based lowering path for VMI masked "
+           "loads; currently emits a TODO diagnostic when used.">
+  ];
+  let dependentDialects = ["mlir::cf::ControlFlowDialect",
+                           "mlir::func::FuncDialect",
+                           "mlir::pto::PTODialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::scf::SCFDialect"];
+}
+
 def PTOValidateVPTOEmissionIR
     : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> {
   let summary =
diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h
new file mode 100644
index 0000000000..15b4f19f1d
--- /dev/null
+++ b/include/PTO/Transforms/VMITargetCapabilities.h
@@ -0,0 +1,318 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+ +//===- VMITargetCapabilities.h - VMI target capability registry -*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMITARGETCAPABILITIES_H +#define PTO_TRANSFORMS_VMITARGETCAPABILITIES_H + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/Twine.h" + +#include + +namespace mlir::pto { + +enum class VMICapabilityStatus { + supported, + unsupported_missing_capability, + unsupported_disabled_by_option, + unsupported_resource, +}; + +enum class VMIElementPurpose { + PredicateMask, + F16F32, + F16BF16F32, + SignlessOrSignedI8I16I32, + AnyI8I16I32, + VMula, + VRelu, +}; + +enum class VMIReductionKind { + AddI, + AddF, + MaxF, + MinF, +}; + +enum class VMIFallbackResourceKind { + ScratchMemory, + GuardedControlFlow, +}; + +struct VMICapabilityResult { + VMICapabilityStatus status = VMICapabilityStatus::supported; + std::string reason; + + static VMICapabilityResult supported() { return {}; } + + static VMICapabilityResult missingCapability(const Twine &reason) { + VMICapabilityResult result; + result.status = VMICapabilityStatus::unsupported_missing_capability; + result.reason = reason.str(); + return result; + } + + bool isSupported() const { + return status == VMICapabilityStatus::supported; + } + + LogicalResult toLogicalResult(std::string *outReason = nullptr) const { + if (isSupported()) + return success(); + if (outReason) + *outReason = reason; + return failure(); + } +}; + +class VMITargetCapabilityRegistry { +public: + VMICapabilityResult supportsElementType(Type type, + VMIElementPurpose purpose) const { + switch (purpose) { + case VMIElementPurpose::PredicateMask: { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type); + if (elementBits == 8 || elementBits == 16 || elementBits == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + 
"requires an 8/16/32-bit element type so VPTO b8/b16/b32 " + "predicate masks can be materialized"); + } + case VMIElementPurpose::F16F32: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::F16BF16F32: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/bf16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::SignlessOrSignedI8I16I32: + if (isSignlessOrSignedI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed i8/i16/i32 element type for direct VPTO " + "lowering"); + case VMIElementPurpose::AnyI8I16I32: + if (isAnyI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed/unsigned i8/i16/i32 element type for " + "direct VPTO lowering"); + case VMIElementPurpose::VMula: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16, bf16, or f32 element type for pto.vmula"); + case VMIElementPurpose::VRelu: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "pto.vrelu direct lowering supports only f16/f32 VMI " + "floating-point element types"); + } + llvm_unreachable("unhandled VMI element purpose"); + } + + VMICapabilityResult supportsDirectMemory(Type type, StringRef role) const { + switch (classifyDirectMemoryRole(type)) { + case DirectMemoryRole::UB: + case DirectMemoryRole::Unknown: + return VMICapabilityResult::supported(); + case DirectMemoryRole::GM: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is GM-backed, but 
current direct VMI-to-VPTO memory lowering " + "emits pto.vlds/pto.vsts and requires UB-backed memory"); + case DirectMemoryRole::Other: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is not UB-backed memory supported by pto.vlds/pto.vsts"); + } + llvm_unreachable("unhandled direct memory role"); + } + /* Requires type to pass the dyn_cast below (cast target is stripped in this text; the failure reason names !pto.ptr) and its address space to be AddressSpace::VEC. physicalOp and ubReason are only used to build the failure reason. */ + VMICapabilityResult supportsUBPointerMemory(Type type, StringRef role, + StringRef physicalOp, + StringRef ubReason) const { + auto ptrType = dyn_cast(type); + if (!ptrType) + return VMICapabilityResult::missingCapability( + Twine("requires a !pto.ptr ") + role + " because " + physicalOp + + " is pointer-only"); + if (ptrType.getMemorySpace().getAddressSpace() != AddressSpace::VEC) + return VMICapabilityResult::missingCapability( + Twine("requires a UB ") + role + " because " + ubReason); + return VMICapabilityResult::supported(); + } + /* Only channel counts of 2 or 4 are accepted; opName prefixes the failure reason. */ + VMICapabilityResult supportsChannelCount(StringRef opName, + int64_t channels) const { + if (channels == 2 || channels == 4) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + Twine(opName) + " supports only 2 or 4 channels"); + } + /* Null layouts are rejected. Accepted pairs: identical layouts, contiguous to deinterleaved with factor 2 or 4, and deinterleaved with factor 2 or 4 to contiguous. elementType is deliberately unused. */ + VMICapabilityResult supportsLayoutConversion(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + Type elementType) const { + (void)elementType; + if (!sourceLayout || !resultLayout) + return VMICapabilityResult::missingCapability( + "requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMICapabilityResult::supported(); + if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "unsupported source/result layout pair"); + } + /* Mask granularity conversion is accepted when both names are concrete according to VMIMaskType::isConcreteGranularity. */ + VMICapabilityResult 
supportsMaskGranularityConversion( + StringRef sourceGranularity, StringRef resultGranularity) const { + if (!VMIMaskType::isConcreteGranularity(sourceGranularity) || + !VMIMaskType::isConcreteGranularity(resultGranularity)) + return VMICapabilityResult::missingCapability( + "requires concrete b8/b16/b32 source and result granularities"); + return VMICapabilityResult::supported(); + } + /* Always reports missingCapability; all three type parameters are unused. */ + VMICapabilityResult supportsTrueMaskedLoad(Type sourceType, Type resultType, + Type maskType) const { + (void)sourceType; + (void)resultType; + (void)maskType; + return VMICapabilityResult::missingCapability( + "target true masked/non-faulting load is unavailable because the " + "current VPTO pto.vlds surface has no mask operand"); + } + /* Both fallback resource kinds currently report missingCapability with a not-implemented reason. NOTE(review): the status is unsupported_missing_capability rather than unsupported_resource -- confirm that is intended. */ + VMICapabilityResult supportsFallbackResource( + VMIFallbackResourceKind kind) const { + switch (kind) { + case VMIFallbackResourceKind::ScratchMemory: + return VMICapabilityResult::missingCapability( + "scratch memory fallback resource allocation is not implemented"); + case VMIFallbackResourceKind::GuardedControlFlow: + return VMICapabilityResult::missingCapability( + "guarded memory fallback control-flow lowering is not implemented"); + } + llvm_unreachable("unhandled VMI fallback resource kind"); + } + /* AddI requires a 32-bit storage width plus the isa check below (target stripped in this text); AddF requires f32; MaxF and MinF require f16 or f32. */ + VMICapabilityResult supportsReductionElementType( + VMIReductionKind kind, Type elementType) const { + switch (kind) { + case VMIReductionKind::AddI: + if (pto::getPTOStorageElemBitWidth(elementType) == 32 && + isa(elementType)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only 32-bit integer elements because narrow " + "vcadd widens its result"); + case VMIReductionKind::AddF: + if (elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f32 elements; f16 requires an explicit " + "accumulator precision and rounding contract"); + case VMIReductionKind::MaxF: + case VMIReductionKind::MinF: + 
if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements because pto.vcmax/" + "pto.vcmin support only those floating-point element types"); + } + llvm_unreachable("unhandled VMI reduction kind"); + } + +private: + enum class DirectMemoryRole { Unknown, UB, GM, Other }; + /* Maps a memory type to a DirectMemoryRole. AddressSpace::GM and AddressSpace::Zero map to GM, AddressSpace::VEC maps to UB, any other space to Other. A type that passes the second dyn_cast but has a null memory-space attribute is Unknown; a memory-space attribute matching neither of the two attribute casts is Other. The cast targets are stripped in this text. */ + DirectMemoryRole classifyDirectMemoryRole(Type type) const { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + auto memrefType = dyn_cast(type); + if (!memrefType) + return DirectMemoryRole::Other; + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return DirectMemoryRole::Unknown; + + if (auto addressSpace = dyn_cast(memorySpace)) { + switch (addressSpace.getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + if (auto integerSpace = dyn_cast(memorySpace)) { + switch (integerSpace.getInt()) { + case static_cast(AddressSpace::GM): + case static_cast(AddressSpace::Zero): + return DirectMemoryRole::GM; + case static_cast(AddressSpace::VEC): + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + return DirectMemoryRole::Other; + } + /* True when the cast below succeeds, the type is not unsigned, and its width is 8, 16 or 32. */ + static bool isSignlessOrSignedI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType || intType.isUnsigned()) + return false; + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + } + /* Same width test as above but without the signedness restriction. */ + static bool isAnyI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType) + return false; + return intType.getWidth() == 8 || 
intType.getWidth() == 16 || + intType.getWidth() == 32; + } +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMITARGETCAPABILITIES_H diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index 74b9e0bd68..4f8d995796 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(PTOIR PTO.cpp VPTO.cpp + VMI.cpp PTOAttrs.cpp PTOSyncUtils.cpp PTOTypeDefs.cpp diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp new file mode 100644 index 0000000000..1f9a43f51a --- /dev/null +++ b/lib/PTO/IR/VMI.cpp @@ -0,0 +1,1407 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- VMI.cpp - PTO VMI type and attribute support -----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include + /* NOTE(review): the include directive directly above has no target in this text; angle-bracket contents appear to have been stripped here and in the cast, isa and container uses below -- confirm against the original file. */ +using namespace mlir; +using namespace mlir::pto; + +namespace { + /* Renders the textual form of a VMI vreg type (element count, element type, optional layout) into a std::string; the layout part is omitted when layout is null. */ +static std::string formatVMIVRegType(int64_t elementCount, Type elementType, + Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.vreg<" << elementCount << "x" << elementType; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + /* Mask counterpart of formatVMIVRegType: element count, granularity keyword, optional layout. */ +static std::string formatVMIMaskType(int64_t elementCount, + StringRef granularity, + Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.mask<" << elementCount << "x" << granularity; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + /* Element types accepted by VMI vreg types: whatever the isa check admits (target list stripped in this text) or a PTO low-precision type. */ +static bool isSupportedVMIElementType(Type type) { + return isa(type) || + pto::isPTOLowPrecisionType(type); +} + +static bool isVMIFloatLikeType(Type type) { + return isa(type) || pto::isPTOLowPrecisionType(type); +} + +static bool isVMIIntegerLikeType(Type type) { + return isa(type); +} + /* Accepts types passing the cast below with width 8, 16 or 32 (no signedness test), plus f16 and f32. */ +static bool isVMIIotaElementType(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return type.isF16() || type.isF32(); +} + /* Identical types are compatible. Otherwise both types must pass the cast below and have equal widths: a signed semantic type accepts signed or signless scalars, an unsigned one accepts unsigned or signless, and a signless one accepts only signless. */ +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || + semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || 
scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + /* Returns 64 for types matching the isa check below (target stripped in this text), otherwise defers to pto::getPTOStorageElemBitWidth. */ +static unsigned getVMIElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + /* Width of a type passing either cast below; std::nullopt for anything else. */ +static std::optional getVMIIntegerOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return std::nullopt; +} + /* Ceiling division that returns 0 for value 0. NOTE(review): nothing here checks value >= 0 or divisor > 0 -- presumably guaranteed by callers. */ +static int64_t divideCeilNonNegative(int64_t value, int64_t divisor) { + return value == 0 ? 0 : (value + divisor - 1) / divisor; +} + /* Parses an optional comma followed by an attribute. A missing comma is success and leaves layout untouched; a parsed attribute that fails the isa check is reported as an error at the current parser location. */ +static LogicalResult parseOptionalVMILayout(AsmParser &parser, + Attribute &layout) { + if (failed(parser.parseOptionalComma())) + return success(); + + if (failed(parser.parseAttribute(layout))) + return failure(); + if (!mlir::isa(layout)) + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.vmi.layout attribute"); + return success(); +} + /* Element count of a type passing either cast below (named vregType and maskType); failure otherwise. */ +static FailureOr getVMIElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + /* Returns the layout attribute carried by the type, or failure when the type passes neither cast or its layout is null or not of the expected attribute kind. */ +static FailureOr getAssignedVMILayout(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr; +} + /* Factor of the assigned layout: 1 when contiguous, otherwise getFactor; failure when no layout is assigned. */ +static FailureOr getLayoutFactor(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isContiguous() ? 
1 : (*layout).getFactor(); +} + /* Delegates to getDataLanesPerPart for data types and getMaskLanesPerPart for mask types; failure for any other type. */ +static FailureOr getPhysicalLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + /* Maps b8, b16 and b32 to 8, 16 and 32; every other string yields 0. */ +static int64_t getMaskGranularityBitWidth(StringRef granularity) { + if (granularity == "b8") + return 8; + if (granularity == "b16") + return 16; + if (granularity == "b32") + return 32; + return 0; +} + /* True when the type carries a non-null layout attribute. */ +static bool isLayoutAssigned(VMIVRegType type) { + return static_cast(type.getLayoutAttr()); +} + +static bool isLayoutAssigned(VMIMaskType type) { + return static_cast(type.getLayoutAttr()); +} + /* Compares every type against the first one: equal element count, equal element type when requireSameElement is set, and -- if any type carries a layout -- all must carry one and the layouts must be equal. An empty list succeeds. Each violation is reported through op->emitOpError. */ +static LogicalResult verifyAllSameVRegShapeAndLayout(Operation *op, + ArrayRef types, + bool requireSameElement) { + if (types.empty()) + return success(); + + VMIVRegType first = types.front(); + bool anyLayout = llvm::any_of(types, [](VMIVRegType type) { + return isLayoutAssigned(type); + }); + + for (VMIVRegType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError("requires all VMI data values to have the same logical lane count"); + if (requireSameElement && type.getElementType() != first.getElementType()) + return op->emitOpError("requires all VMI data values to have the same element type"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError("requires either all or no VMI data values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError("requires all layout-assigned VMI data values to have the same layout"); + } + return success(); +} + /* Binary elementwise check: lhs, rhs and result must agree on lane count, element type and layout. */ +static LogicalResult verifyElementwiseVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, + VMIVRegType result) { + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, result}, + /*requireSameElement=*/true); +} + /* Unary float check: source must be float-like, then source and result must agree on lane count, element type and layout. */ +static LogicalResult verifyFloatUnaryVRegOp(Operation *op, + VMIVRegType source, + VMIVRegType result) 
{ + if (!isVMIFloatLikeType(source.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {source, result}, + /*requireSameElement=*/true); +} + /* Only the element type of lhs is tested for float-likeness; the same-element check that follows forces rhs, acc and result to share it. */ +static LogicalResult verifyFloatTernaryVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, VMIVRegType acc, + VMIVRegType result) { + if (!isVMIFloatLikeType(lhs.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, acc, result}, + /*requireSameElement=*/true); +} + /* Compares every mask type against the first one: equal lane count, equal granularity, and -- if any mask carries a layout -- all must carry one and the layouts must be equal. An empty list succeeds. */ +static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( + Operation *op, ArrayRef types) { + if (types.empty()) + return success(); + + VMIMaskType first = types.front(); + bool anyLayout = llvm::any_of(types, [](VMIMaskType type) { + return isLayoutAssigned(type); + }); + + for (VMIMaskType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError( + "requires all VMI mask values to have the same logical lane count"); + if (type.getGranularity() != first.getGranularity()) + return op->emitOpError( + "requires all VMI mask values to have the same granularity"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError( + "requires either all or no VMI mask values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError( + "requires all layout-assigned VMI mask values to have the same " + "layout"); + } + return success(); +} + /* Mask and data must agree on lane count and, when either one carries a layout, both must carry the same layout. Further checks follow in the remainder of the function. */ +static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, + VMIVRegType dataType) { + if (maskType.getElementCount() != dataType.getElementCount()) + return op->emitOpError("requires mask logical lane count to match data lane count"); + + if (isLayoutAssigned(maskType) || isLayoutAssigned(dataType)) { + if (!isLayoutAssigned(maskType) || !isLayoutAssigned(dataType)) + return op->emitOpError("requires either both mask and data to 
carry layout or neither to carry layout"); + if (maskType.getLayout() != dataType.getLayout()) + return op->emitOpError("requires mask layout to match data layout"); + } + /* A pred mask skips the granularity-width comparison below. */ + if (maskType.isPred()) + return success(); + /* The width comparison is applied only when both widths are non-zero, i.e. known. */ + unsigned elementBitWidth = getVMIElementBitWidth(dataType.getElementType()); + int64_t maskBitWidth = getMaskGranularityBitWidth(maskType.getGranularity()); + if (elementBitWidth != 0 && maskBitWidth != 0 && + elementBitWidth != static_cast(maskBitWidth)) + return op->emitOpError("requires mask granularity to match data element width"); + + return success(); +} + /* Element type of a type passing either cast below (named ptrType and memrefType); a null Type otherwise. */ +static Type getMemoryElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + /* Succeeds without checking when the memory type exposes no element type; otherwise the memory element type must equal the data element type. role is only used in the diagnostic. */ +static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, + VMIVRegType dataType, + StringRef role) { + Type memoryElementType = getMemoryElementType(memoryType); + if (!memoryElementType) + return success(); + if (memoryElementType != dataType.getElementType()) + return op->emitOpError() + << "requires memory " << role + << " element type to match VMI data element type"; + return success(); +} + /* Checks that physicalTypes has the arity reported by getVMIPhysicalArity and that each part matches the VMI type: lane count and element type for data, granularity for masks. */ +static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, + TypeRange physicalTypes) { + FailureOr expectedArity = getVMIPhysicalArity(vmiType); + if (failed(expectedArity)) + return op->emitOpError("requires a layout-assigned VMI type with computable physical arity"); + if (static_cast(physicalTypes.size()) != *expectedArity) + return op->emitOpError() + << "requires " << *expectedArity << " physical parts, got " + << physicalTypes.size(); + + if (auto vregType = dyn_cast(vmiType)) { + FailureOr lanesPerPart = + getDataLanesPerPart(vregType.getElementType()); + if (failed(lanesPerPart)) + return op->emitOpError("requires data element type with known physical lane count"); + for (Type physicalType : physicalTypes) { + auto partType = 
dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical data parts to be !pto.vreg"); + if (partType.getElementCount() != *lanesPerPart || + partType.getElementType() != vregType.getElementType()) + return op->emitOpError("requires physical data part type to match VMI lane-map helper"); + } + return success(); + } + /* Mask path: pred masks are rejected, and every physical part must carry the same granularity as the VMI mask. */ + auto maskType = dyn_cast(vmiType); + if (!maskType) + return op->emitOpError("requires VMI data or mask type"); + if (maskType.isPred()) + return op->emitOpError("requires layout-assigned mask with concrete granularity"); + + for (Type physicalType : physicalTypes) { + auto partType = dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical mask parts to be !pto.mask"); + if (partType.getGranularity() != maskType.getGranularity()) + return op->emitOpError("requires physical mask part granularity to match VMI mask"); + } + return success(); +} + /* Counts the indices i in [0, elementCount) with i congruent to part modulo factor; returns 0 when part is negative, not below factor, or not below elementCount. */ +static int64_t getLogicalLanesInPart(int64_t elementCount, int64_t factor, + int64_t part) { + if (part < 0 || part >= factor || part >= elementCount) + return 0; + return ((elementCount - 1 - part) / factor) + 1; +} + +} // namespace + +VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context) { + return VMILayoutAttr::get(context, "contiguous", 1); +} + +VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, + int64_t factor) { + return VMILayoutAttr::get(context, "deinterleaved", factor); +} + /* Parses the layout attribute body: a less-than sign, then the keyword contiguous (factor fixed to 1) or the keyword deinterleaved followed by an equals sign and an integer factor. Any other keyword is reported as an error. */ +Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { + SMLoc loc = parser.getCurrentLocation(); + StringRef kind; + int64_t factor = 1; + + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) + return {}; + + if (kind == "contiguous") { + factor = 1; + } else if (kind == "deinterleaved") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; + } else { + parser.emitError(parser.getCurrentLocation(), + "expected VMI layout kind 'contiguous' or " + "'deinterleaved'"); + return {}; + } + 
+ + if (failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), kind, + factor); +} + /* Prints the kind keyword between angle brackets; the factor is printed only for deinterleaved layouts. */ +void VMILayoutAttr::print(AsmPrinter &printer) const { + printer << "<" << getKind(); + if (isDeinterleaved()) + printer << " = " << getFactor(); + printer << ">"; +} + /* contiguous requires factor 1; deinterleaved requires factor 2 or 4; any other kind is rejected. */ +LogicalResult +VMILayoutAttr::verify(function_ref emitError, + StringRef kind, int64_t factor) { + if (kind == "contiguous") { + if (factor != 1) + return emitError() + << "#pto.vmi.layout requires factor to be 1"; + return success(); + } + + if (kind == "deinterleaved") { + if (factor != 2 && factor != 4) + return emitError() + << "#pto.vmi.layout expected factor to be 2 or 4"; + return success(); + } + + return emitError() << "expected VMI layout kind to be 'contiguous' or " + "'deinterleaved'"; +} + /* Parses the vreg type body: one static element-count dimension with trailing x, the element type, and an optional layout attribute. Returns a null Type on failure; every failure path emits a diagnostic. */ +Type VMIVRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true))) + return {}; + + /* parseDimensionList succeeds on zero or on several dimensions, so a wrong rank must be reported here; returning a null Type without a diagnostic would make the parse fail silently. */ + if (shape.size() != 1) { + parser.emitError(loc, "expected exactly one element-count dimension"); + return {}; + } + + if (failed(parser.parseType(elementType)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), + shape.front(), elementType, layout); +} + /* Prints element count, element type and, when present, the layout attribute. */ +void VMIVRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + /* Requires a positive element count, a supported element type, and a layout that is either null or of the expected attribute kind. */ +LogicalResult VMIVRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected a positive element count"; + + if (!isSupportedVMIElementType(elementType)) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected an 
integer, index, floating-point, or " + "PTO low-precision element type"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected layout to be #pto.vmi.layout"; + + return success(); +} + /* pred plus the three concrete granularities are supported. */ +bool VMIMaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "pred" || isConcreteGranularity(granularity); +} + /* Concrete granularities are b8, b16 and b32. */ +bool VMIMaskType::isConcreteGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || granularity == "b32"; +} + /* Parses the mask type body: one static element-count dimension with trailing x, the granularity keyword, and an optional layout attribute. Returns a null Type on failure; every failure path emits a diagnostic. */ +Type VMIMaskType::parse(AsmParser &parser) { + SmallVector shape; + StringRef granularity; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true))) + return {}; + + /* parseDimensionList succeeds on zero or on several dimensions, so a wrong rank must be reported here; returning a null Type without a diagnostic would make the parse fail silently. */ + if (shape.size() != 1) { + parser.emitError(loc, "expected exactly one element-count dimension"); + return {}; + } + + if (failed(parser.parseKeyword(&granularity)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), + shape.front(), granularity, layout); +} + /* Prints element count, granularity and, when present, the layout attribute. */ +void VMIMaskType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x" << getGranularity(); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + /* Requires a positive element count and a supported granularity. A pred mask must not carry a layout, while a concrete granularity must carry one. */ +LogicalResult VMIMaskType::verify(function_ref emitError, + int64_t elementCount, StringRef granularity, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected a positive element count"; + + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected granularity to be one of pred, b8, b16, " + "b32"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected layout to be #pto.vmi.layout"; + + if (granularity == "pred" 
&& layout) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' pred mask must not carry layout"; + + if (granularity != "pred" && !layout) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' concrete mask granularity requires layout"; + + return success(); +} + /* The value attribute must pass the dyn_cast below (target stripped in this text) and match the result in element type and element count. */ +LogicalResult VMIConstantOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements constant attribute"); + if (denseAttr.getElementType() != resultType.getElementType()) + return emitOpError("requires dense constant element type to match result element type"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense constant element count to match result logical lane count"); + return success(); +} + /* The operand is either a scalar of the result element type or a one-lane VMI vector with that element type. */ +LogicalResult VMIBroadcastOp::verify() { + auto resultType = cast(getResult().getType()); + Type valueType = getValue().getType(); + if (valueType == resultType.getElementType()) + return success(); + if (auto vregType = dyn_cast(valueType)) { + if (vregType.getElementCount() != 1) + return emitOpError("requires VMI vector input to have one logical lane"); + if (vregType.getElementType() != resultType.getElementType()) + return emitOpError("requires VMI vector input element type to match " + "result element type"); + return success(); + } + return emitOpError("requires scalar or VMI vector input element type to " + "match result element type"); +} + /* The result element type must satisfy isVMIIotaElementType and the base operand type must be compatible with it per isCompatibleScalarForSemanticType. */ +LogicalResult VMIIotaOp::verify() { + auto resultType = cast(getResult().getType()); + Type elementType = resultType.getElementType(); + if (!isVMIIotaElementType(elementType)) + return emitOpError("requires result element type to be integer 8/16/32 " + "or f16/f32"); + if (!isCompatibleScalarForSemanticType(elementType, getBase().getType())) + return emitOpError("requires base type to match result element type"); + + if 
(std::optional order = getOrder()) { + if (*order != "ASC" && *order != "DESC") + return emitOpError("requires order to be ASC or DESC"); + } + return success(); +} + /* A non-pred mask result must carry a layout. */ +LogicalResult VMICreateMaskOp::verify() { + auto resultType = cast(getResult().getType()); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + /* The value attribute must pass the dyn_cast below (target stripped in this text), have i1 elements, and have as many elements as the result has lanes. */ +LogicalResult VMIConstantMaskOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements mask constant attribute"); + if (!denseAttr.getElementType().isInteger(1)) + return emitOpError("requires dense mask constant element type to be i1"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense mask constant element count to match result logical lane count"); + return success(); +} + /* The mask logic verifiers below all require their operands and result to agree on lane count, granularity and layout. */ +LogicalResult VMIMaskAndOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskXOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskNotOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return 
verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {sourceType, resultType}); +} + /* Binary arithmetic verifiers: the F variants require a float-like lhs element type and the I variants an integer-like one; verifyElementwiseVRegOp then forces lhs, rhs and result to agree on lane count, element type and layout. */ +LogicalResult VMIAddFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIAddIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMulFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + 
+LogicalResult VMIMulIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + /* Fused multiply-add: delegates to verifyFloatTernaryVRegOp, which tests lhs for float-likeness and forces all four types to agree. */ +LogicalResult VMIFmaOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto accType = cast(getAcc().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatTernaryVRegOp(getOperation(), lhsType, rhsType, accType, + resultType); +} + /* DivF, MinF and MaxF share one shape: a float-like lhs element type followed by the elementwise agreement check. */ +LogicalResult VMIDivFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMinFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMaxFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + /* Unary float verifier; see verifyFloatUnaryVRegOp. */ +LogicalResult VMINegFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return 
verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + /* AbsF, Sqrt, Exp, Ln and Relu below are unary float verifiers that delegate to verifyFloatUnaryVRegOp. */ +LogicalResult VMIAbsFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + /* Integer absolute value: the source element type must be integer-like, then source and result must agree on lane count, element type and layout. */ +LogicalResult VMIAbsIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true); +} + +LogicalResult VMISqrtOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIExpOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMILnOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIReluOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + /* Bitwise binary verifiers: an integer-like lhs element type followed by the elementwise agreement check. */ +LogicalResult VMIAndIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIOrIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = 
cast(getRhs().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(lhsType.getElementType()))
+    return emitOpError("requires integer-like VMI element type");
+  return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType);
+}
+// Requires an integer-like lhs element type, then runs the shared elementwise vreg checks.
+LogicalResult VMIXOrIOp::verify() {
+  auto lhsType = cast(getLhs().getType());
+  auto rhsType = cast(getRhs().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(lhsType.getElementType()))
+    return emitOpError("requires integer-like VMI element type");
+  return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType);
+}
+// Requires an integer-like lhs element type, then runs the shared elementwise vreg checks.
+LogicalResult VMIShLIOp::verify() {
+  auto lhsType = cast(getLhs().getType());
+  auto rhsType = cast(getRhs().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(lhsType.getElementType()))
+    return emitOpError("requires integer-like VMI element type");
+  return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType);
+}
+// Rejects lhs element types for which the dyn_cast fails or isSigned() is true, then runs the elementwise checks.
+LogicalResult VMIShRUIOp::verify() {
+  auto lhsType = cast(getLhs().getType());
+  auto rhsType = cast(getRhs().getType());
+  auto resultType = cast(getResult().getType());
+  auto integerType = dyn_cast(lhsType.getElementType());
+  if (!integerType || integerType.isSigned())
+    return emitOpError(
+        "requires signless or unsigned integer VMI element type");
+  return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType);
+}
+// Requires an integer-like source element type and a result matching the source (element type included).
+LogicalResult VMINotOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(sourceType.getElementType()))
+    return emitOpError("requires integer-like VMI element type");
+  return verifyAllSameVRegShapeAndLayout(getOperation(), {sourceType, resultType},
+                                         /*requireSameElement=*/true);
+}
+// Float compare: float-like lhs, lhs/rhs must agree (element type included), result mask is checked against lhs.
+LogicalResult VMICmpFOp::verify() {
+  auto lhsType = cast(getLhs().getType());
+  auto rhsType = cast(getRhs().getType());
+  auto resultType =
cast(getResult().getType());
+  if (!isVMIFloatLikeType(lhsType.getElementType()))
+    return emitOpError("requires floating-point-like VMI element type");
+  if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType},
+                                             /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), resultType, lhsType);
+}
+// Integer compare: integer-like lhs, lhs/rhs must agree (element type included), result mask is checked against lhs.
+LogicalResult VMICmpIOp::verify() {
+  auto lhsType = cast(getLhs().getType());
+  auto rhsType = cast(getRhs().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(lhsType.getElementType()))
+    return emitOpError("requires integer-like VMI element type");
+  if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType},
+                                             /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), resultType, lhsType);
+}
+// Select: true value, false value, and result must agree (element type included); the mask is checked against the result.
+LogicalResult VMISelectOp::verify() {
+  auto maskType = cast(getMask().getType());
+  auto trueType = cast(getTrueValue().getType());
+  auto falseType = cast(getFalseValue().getType());
+  auto resultType = cast(getResult().getType());
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {trueType, falseType, resultType},
+          /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, resultType);
+}
+// Result elements must be signless integers of width 8, 16, or 32; the mask is checked against the result.
+LogicalResult VMIActivePrefixIndexOp::verify() {
+  auto maskType = cast(getMask().getType());
+  auto resultType = cast(getResult().getType());
+  auto resultIntType = dyn_cast(resultType.getElementType());
+  if (!resultIntType || !resultIntType.isSignless())
+    return emitOpError("requires signless integer result element type");
+  unsigned resultWidth = resultIntType.getWidth();
+  if (resultWidth != 8 && resultWidth != 16 && resultWidth != 32)
+    return emitOpError("requires i8, i16, or i32 result element type");
+  return verifyMaskMatchesData(getOperation(), maskType, resultType);
+}
+// Compress: source and result must agree (element type included); the mask is checked against the source.
+LogicalResult VMICompressOp::verify() {
+  auto sourceType =
cast(getSource().getType());
+  auto maskType = cast(getMask().getType());
+  auto resultType = cast(getResult().getType());
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {sourceType, resultType},
+          /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, sourceType);
+}
+// Checks the stored value against the destination memory type, then the mask against the stored value.
+LogicalResult VMICompressStoreOp::verify() {
+  auto valueType = cast(getValue().getType());
+  auto maskType = cast(getMask().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(),
+                                        getDestination().getType(), valueType,
+                                        "destination")))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, valueType);
+}
+// Records a single memory Write effect on the destination operand.
+void VMICompressStoreOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable());
+}
+// Integer add-reduction: integer-like source; source/init/result share one element type; init and result hold one element.
+LogicalResult VMIReduceAddIOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto initType = cast(getInit().getType());
+  auto maskType = cast(getMask().getType());
+  auto resultType = cast(getResult().getType());
+  if (!isVMIIntegerLikeType(sourceType.getElementType()))
+    return emitOpError("requires integer-like VMI source element type");
+  if (sourceType.getElementType() != initType.getElementType() ||
+      sourceType.getElementType() != resultType.getElementType())
+    return emitOpError(
+        "requires source, init, and result element types to match");
+  if (initType.getElementCount() != 1 || resultType.getElementCount() != 1)
+    return emitOpError("requires init and result to be rank-0 VMI vectors");
+  if (failed(verifyAllSameVRegShapeAndLayout(getOperation(),
+                                             {initType, resultType},
+                                             /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, sourceType); // The mask covers the reduced source, not the one-element result.
+}
+// Float add-reduction: as the integer form, but float-like and additionally requiring a "reassoc" attribute.
+LogicalResult VMIReduceAddFOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto initType = cast(getInit().getType());
+  auto maskType = cast(getMask().getType());
+  auto resultType =
cast(getResult().getType());
+  if (!getOperation()->hasAttr("reassoc")) // Checked first, before any type constraint.
+    return emitOpError(
+        "requires reassoc attr because VPTO vcadd performs pair-wise "
+        "floating-point reduction");
+  if (!isVMIFloatLikeType(sourceType.getElementType()))
+    return emitOpError("requires floating-point-like VMI source element type");
+  if (sourceType.getElementType() != initType.getElementType() ||
+      sourceType.getElementType() != resultType.getElementType())
+    return emitOpError(
+        "requires source, init, and result element types to match");
+  if (initType.getElementCount() != 1 || resultType.getElementCount() != 1)
+    return emitOpError("requires init and result to be rank-0 VMI vectors");
+  if (failed(verifyAllSameVRegShapeAndLayout(getOperation(),
+                                             {initType, resultType},
+                                             /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, sourceType);
+}
+// Shared verifier for the float min/max reductions: same checks as the float add-reduction minus the "reassoc" requirement.
+template
+LogicalResult verifyReduceMinMaxFOp(OpTy op) {
+  auto sourceType = cast(op.getSource().getType());
+  auto initType = cast(op.getInit().getType());
+  auto maskType = cast(op.getMask().getType());
+  auto resultType = cast(op.getResult().getType());
+  if (!isVMIFloatLikeType(sourceType.getElementType()))
+    return op.emitOpError("requires floating-point-like VMI source element type");
+  if (sourceType.getElementType() != initType.getElementType() ||
+      sourceType.getElementType() != resultType.getElementType())
+    return op.emitOpError(
+        "requires source, init, and result element types to match");
+  if (initType.getElementCount() != 1 || resultType.getElementCount() != 1)
+    return op.emitOpError("requires init and result to be rank-0 VMI vectors");
+  if (failed(verifyAllSameVRegShapeAndLayout(op.getOperation(),
+                                             {initType, resultType},
+                                             /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(op.getOperation(), maskType, sourceType);
+}
+// Delegates to the shared min/max reduction verifier.
+LogicalResult VMIReduceMaxFOp::verify() {
+  return verifyReduceMinMaxFOp(*this);
+}
+// Delegates to the shared min/max reduction verifier.
+LogicalResult
VMIReduceMinFOp::verify() {
+  return verifyReduceMinMaxFOp(*this);
+}
+// Float extension: equal lane counts, float-like source and result, and a result element strictly wider than the source.
+LogicalResult VMIExtFOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementCount() != resultType.getElementCount())
+    return emitOpError("requires source and result logical lane counts to match");
+  if (!isVMIFloatLikeType(sourceType.getElementType()) ||
+      !isVMIFloatLikeType(resultType.getElementType()))
+    return emitOpError("requires floating-point-like source and result element types");
+  if (getVMIElementBitWidth(sourceType.getElementType()) >=
+      getVMIElementBitWidth(resultType.getElementType())) // Equal widths are rejected too.
+    return emitOpError("requires result element type to be wider than source element type");
+  return success();
+}
+// Float truncation: equal lane counts, float-like source and result, and a result element strictly narrower than the source.
+LogicalResult VMITruncFOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementCount() != resultType.getElementCount())
+    return emitOpError("requires source and result logical lane counts to match");
+  if (!isVMIFloatLikeType(sourceType.getElementType()) ||
+      !isVMIFloatLikeType(resultType.getElementType()))
+    return emitOpError("requires floating-point-like source and result element types");
+  if (getVMIElementBitWidth(sourceType.getElementType()) <=
+      getVMIElementBitWidth(resultType.getElementType())) // Equal widths are rejected too.
+    return emitOpError("requires result element type to be narrower than source element type");
+  return success();
+}
+// Bitcast: both element types need a known bit width, total bits (lanes x width) must match, and layouts must agree if present.
+LogicalResult VMIBitcastOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  std::optional sourceBits =
+      getVMIIntegerOrFloatBitWidth(sourceType.getElementType());
+  std::optional resultBits =
+      getVMIIntegerOrFloatBitWidth(resultType.getElementType());
+  if (!sourceBits || !resultBits)
+    return emitOpError(
+        "requires integer or floating-point source and result element types");
+  if (sourceType.getElementCount() *
static_cast(*sourceBits) !=
+      resultType.getElementCount() * static_cast(*resultBits))
+    return emitOpError(
+        "requires source and result to carry the same total number of bits");
+
+  if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { // Layout is all-or-nothing across source and result.
+    if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType))
+      return emitOpError(
+          "requires either both source and result to carry layout or neither "
+          "to carry layout");
+    if (sourceType.getLayout() != resultType.getLayout())
+      return emitOpError("requires source and result layouts to match");
+  }
+
+  return success();
+}
+// Checks the loaded result type against the source memory type.
+LogicalResult VMILoadOp::verify() {
+  return verifyMemoryElementMatches(getOperation(), getSource().getType(),
+                                    cast(getResult().getType()),
+                                    "source");
+}
+// Records a single memory Read effect on the source operand.
+void VMILoadOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable());
+}
+// Masked load: result vs. source memory, passthru vs. result (element type included), then mask vs. result.
+LogicalResult VMIMaskedLoadOp::verify() {
+  auto maskType = cast(getMask().getType());
+  auto passthruType = cast(getPassthru().getType());
+  auto resultType = cast(getResult().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(),
+                                        resultType, "source")))
+    return failure();
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {passthruType, resultType},
+          /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, resultType);
+}
+// Records a single memory Read effect on the source operand.
+void VMIMaskedLoadOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable());
+}
+// Gather: result vs. source memory, 32-bit non-signed indices, indices/passthru/result agreement, then mask vs. result.
+LogicalResult VMIGatherOp::verify() {
+  auto indicesType = cast(getIndices().getType());
+  auto maskType = cast(getMask().getType());
+  auto passthruType = cast(getPassthru().getType());
+  auto resultType = cast(getResult().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(),
+                                        resultType, "source")))
+    return failure();
+
+  auto indexElementType =
dyn_cast(indicesType.getElementType());
+  if (!indexElementType || indexElementType.getWidth() != 32 ||
+      indexElementType.isSigned())
+    return emitOpError("requires signless or unsigned 32-bit integer indices");
+
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {indicesType, passthruType, resultType},
+          /*requireSameElement=*/false))) // Indices may differ from the data in element type only.
+    return failure();
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {passthruType, resultType},
+          /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, resultType);
+}
+// Records a single memory Read effect on the source operand.
+void VMIGatherOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable());
+}
+// Expand load: result vs. source memory, passthru vs. result (element type included), then mask vs. result.
+LogicalResult VMIExpandLoadOp::verify() {
+  auto maskType = cast(getMask().getType());
+  auto passthruType = cast(getPassthru().getType());
+  auto resultType = cast(getResult().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(),
+                                        resultType, "source")))
+    return failure();
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {passthruType, resultType},
+          /*requireSameElement=*/true)))
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, resultType);
+}
+// Records a single memory Read effect on the source operand.
+void VMIExpandLoadOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable());
+}
+// Checks the stored value type against the destination memory type.
+LogicalResult VMIStoreOp::verify() {
+  return verifyMemoryElementMatches(getOperation(),
+                                    getDestination().getType(),
+                                    cast(getValue().getType()),
+                                    "destination");
+}
+// Records a single memory Write effect on the destination operand.
+void VMIStoreOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable());
+}
+// Masked store: stored value vs. destination memory, then mask vs. stored value.
+LogicalResult VMIMaskedStoreOp::verify() {
+  auto valueType = cast(getValue().getType());
+  auto maskType = cast(getMask().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(),
+                                        getDestination().getType(),
+                                        valueType, "destination")))
+    return
failure();
+  return verifyMaskMatchesData(getOperation(), maskType, valueType);
+}
+// Records a single memory Write effect on the destination operand.
+void VMIMaskedStoreOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable());
+}
+// Scatter: value vs. destination memory, 32-bit non-signed indices, value/indices agreement, then mask vs. value.
+LogicalResult VMIScatterOp::verify() {
+  auto valueType = cast(getValue().getType());
+  auto indicesType = cast(getIndices().getType());
+  auto maskType = cast(getMask().getType());
+  if (failed(verifyMemoryElementMatches(getOperation(),
+                                        getDestination().getType(),
+                                        valueType, "destination")))
+    return failure();
+
+  auto indexElementType = dyn_cast(indicesType.getElementType());
+  if (!indexElementType || indexElementType.getWidth() != 32 ||
+      indexElementType.isSigned())
+    return emitOpError("requires signless or unsigned 32-bit integer indices");
+
+  if (failed(verifyAllSameVRegShapeAndLayout(
+          getOperation(), {valueType, indicesType},
+          /*requireSameElement=*/false))) // Indices may differ from the data in element type only.
+    return failure();
+  return verifyMaskMatchesData(getOperation(), maskType, valueType);
+}
+// Records a single memory Write effect on the destination operand.
+void VMIScatterOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable());
+}
+// Checks the read result type against the source memory type.
+LogicalResult VMITileReadOp::verify() {
+  return verifyMemoryElementMatches(getOperation(), getSource().getType(),
+                                    cast(getResult().getType()),
+                                    "source");
+}
+// Records a single memory Read effect on the source operand.
+void VMITileReadOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable());
+}
+// Checks the written value type against the destination memory type.
+LogicalResult VMITileWriteOp::verify() {
+  return verifyMemoryElementMatches(getOperation(),
+                                    getDestination().getType(),
+                                    cast(getValue().getType()),
+                                    "destination");
+}
+// Records a single memory Write effect on the destination operand.
+void VMITileWriteOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable());
+}
+// Shuffle: same element type, one index per result lane, every index inside the source lane range, layout all-or-nothing.
+LogicalResult VMIShuffleOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementType() !=
resultType.getElementType())
+    return emitOpError("requires result element type to match source element type");
+  if (static_cast(getIndices().size()) != resultType.getElementCount())
+    return emitOpError("requires shuffle index count to match result logical lane count");
+  for (int64_t index : getIndices()) {
+    if (index < 0 || index >= sourceType.getElementCount())
+      return emitOpError("requires every shuffle index to select an existing source logical lane");
+  }
+  if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { // Only presence is compared; the two layouts are not required to be equal.
+    if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType))
+      return emitOpError("requires either both source and result to carry layout or neither to carry layout");
+  }
+  return success();
+}
+// Channel split: >= 2 results, equal per-channel lane counts summing to the source, source element type kept, plus layout rules.
+LogicalResult VMIChannelSplitOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  if (getResults().size() < 2)
+    return emitOpError("requires at least two channel results");
+  auto firstResultType = cast(getResults().front().getType());
+  if (sourceType.getElementCount() !=
+      static_cast(getResults().size()) * firstResultType.getElementCount())
+    return emitOpError("requires source lane count to equal result count times per-channel lane count");
+  for (Value result : getResults()) {
+    auto resultType = cast(result.getType());
+    if (resultType.getElementCount() != firstResultType.getElementCount() ||
+        resultType.getElementType() != sourceType.getElementType())
+      return emitOpError("requires every channel result to have equal lane count and source element type");
+  }
+  bool anyLayout = isLayoutAssigned(sourceType);
+  for (Value result : getResults())
+    anyLayout |= isLayoutAssigned(cast(result.getType()));
+  if (anyLayout) { // Once any participant carries a layout, the source and every result must carry one.
+    if (!isLayoutAssigned(sourceType))
+      return emitOpError("requires layout-assigned channel_split source when any channel result has layout");
+    for (Value result : getResults()) {
+      auto resultType = cast(result.getType());
+      if (!isLayoutAssigned(resultType))
+        return emitOpError("requires every
channel_split result to carry layout when source has layout");
+      if (!cast(resultType.getLayout()).isContiguous())
+        return emitOpError("requires layout-assigned channel_split results to be contiguous");
+    }
+    int64_t channels = getResults().size();
+    if (channels == 2 || channels == 4) { // The source layout is constrained only for 2 or 4 channels.
+      auto sourceLayout = cast(sourceType.getLayout());
+      auto expectedLayout =
+          VMILayoutAttr::getDeinterleaved(getContext(), channels);
+      if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout)
+        return emitOpError("requires layout-assigned channel_split source to be contiguous or deinterleaved by result count");
+    }
+  }
+  return success();
+}
+// Channel merge: >= 2 inputs of equal lane count and element type, result lanes = inputs x per-channel lanes, plus layout rules.
+LogicalResult VMIChannelMergeOp::verify() {
+  if (getInputs().size() < 2)
+    return emitOpError("requires at least two channel inputs");
+  auto firstInputType = cast(getInputs().front().getType());
+  auto resultType = cast(getResult().getType());
+  for (Value input : getInputs()) {
+    auto inputType = cast(input.getType());
+    if (inputType.getElementCount() != firstInputType.getElementCount() ||
+        inputType.getElementType() != firstInputType.getElementType())
+      return emitOpError("requires all channel inputs to have the same lane count and element type");
+  }
+  if (resultType.getElementCount() !=
+          static_cast(getInputs().size()) * firstInputType.getElementCount() ||
+      resultType.getElementType() != firstInputType.getElementType())
+    return emitOpError("requires result lane count and element type to match merged channels");
+  bool anyLayout = isLayoutAssigned(resultType);
+  for (Value input : getInputs())
+    anyLayout |= isLayoutAssigned(cast(input.getType()));
+  if (anyLayout) { // Once any participant carries a layout, the result and every input must carry one.
+    if (!isLayoutAssigned(resultType))
+      return emitOpError("requires layout-assigned channel_merge result when any channel input has layout");
+    for (Value input : getInputs()) {
+      auto inputType = cast(input.getType());
+      if (!isLayoutAssigned(inputType))
+        return emitOpError("requires every channel_merge input to carry layout when result has layout");
+
if (!cast(inputType.getLayout()).isContiguous())
+        return emitOpError("requires layout-assigned channel_merge inputs to be contiguous");
+    }
+    int64_t channels = getInputs().size();
+    if (channels == 2 || channels == 4) { // The result layout is constrained only for 2 or 4 channels.
+      auto resultLayout = cast(resultType.getLayout());
+      auto expectedLayout =
+          VMILayoutAttr::getDeinterleaved(getContext(), channels);
+      if (!resultLayout.isContiguous() && resultLayout != expectedLayout)
+        return emitOpError("requires layout-assigned channel_merge result to be contiguous or deinterleaved by input count");
+    }
+  }
+  return success();
+}
+// Layout conversion helper: lane count and element type are preserved and both sides must already carry a layout.
+LogicalResult VMIEnsureLayoutOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementCount() != resultType.getElementCount() ||
+      sourceType.getElementType() != resultType.getElementType())
+    return emitOpError("requires source and result to preserve VMI data shape and element type");
+  if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType))
+    return emitOpError("requires source and result to be layout-assigned");
+  return success();
+}
+// Mask layout conversion helper: lane count and granularity are preserved and both sides must already carry a layout.
+LogicalResult VMIEnsureMaskLayoutOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementCount() != resultType.getElementCount() ||
+      sourceType.getGranularity() != resultType.getGranularity())
+    return emitOpError("requires source and result to preserve VMI mask shape and granularity");
+  if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType))
+    return emitOpError("requires source and result to be layout-assigned");
+  return success();
+}
+// Mask granularity conversion helper: lane count and layout are preserved; neither side may report isPred().
+LogicalResult VMIEnsureMaskGranularityOp::verify() {
+  auto sourceType = cast(getSource().getType());
+  auto resultType = cast(getResult().getType());
+  if (sourceType.getElementCount() != resultType.getElementCount())
+    return emitOpError("requires source and result to preserve VMI mask lane count");
+  if
(!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType))
+    return emitOpError("requires source and result to be layout-assigned");
+  if (sourceType.getLayout() != resultType.getLayout())
+    return emitOpError("requires source and result mask layouts to match");
+  if (sourceType.isPred() || resultType.isPred())
+    return emitOpError("requires concrete source and result mask granularities");
+  return success();
+}
+// Checks the unpacked part types against the VMI source type via verifyPhysicalParts.
+LogicalResult VMIUnpackOp::verify() {
+  return verifyPhysicalParts(getOperation(), getSource().getType(),
+                             getParts().getTypes());
+}
+// Checks the packed part types against the VMI result type via verifyPhysicalParts.
+LogicalResult VMIPackOp::verify() {
+  return verifyPhysicalParts(getOperation(), getResult().getType(),
+                             getParts().getTypes());
+}
+// Returns 2048 / storage-bit-width of elementType; fails when that width is 0 or does not divide 2048.
+FailureOr mlir::pto::getDataLanesPerPart(Type elementType) {
+  unsigned elementBitWidth = pto::getPTOStorageElemBitWidth(elementType);
+  if (elementBitWidth == 0)
+    return failure();
+  constexpr int64_t kPhysicalVRegBits = 256 * 8; // 256 bytes expressed in bits.
+  if (kPhysicalVRegBits % elementBitWidth != 0)
+    return failure();
+  return kPhysicalVRegBits / elementBitWidth;
+}
+// Maps a mask granularity name to its lane count: "b8" -> 256, "b16" -> 128, "b32" -> 64; anything else fails.
+FailureOr mlir::pto::getMaskLanesPerPart(StringRef granularity) {
+  if (granularity == "b8")
+    return 256;
+  if (granularity == "b16")
+    return 128;
+  if (granularity == "b32")
+    return 64;
+  return failure();
+}
+// Sums, over every layout part, ceil(logical lanes in that part / physical lanes per part); fails if any query on `type` fails.
+FailureOr mlir::pto::getVMIPhysicalArity(Type type) {
+  FailureOr elementCount = getVMIElementCount(type);
+  FailureOr factor = getLayoutFactor(type);
+  FailureOr lanesPerPart = getPhysicalLanesPerPart(type);
+  if (failed(elementCount) || failed(factor) || failed(lanesPerPart))
+    return failure();
+
+  int64_t arity = 0;
+  for (int64_t part = 0; part < *factor; ++part) {
+    int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part);
+    arity += divideCeilNonNegative(lanesInPart, *lanesPerPart);
+  }
+  return arity;
+}
+// Maps a logical lane to {lane % factor, (lane / factor) / lanesPerPart, (lane / factor) % lanesPerPart}; fails when out of range.
+FailureOr
+mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) {
+  FailureOr elementCount = getVMIElementCount(type);
+  FailureOr factor = getLayoutFactor(type);
+  FailureOr
lanesPerPart = getPhysicalLanesPerPart(type);
+  if (failed(elementCount) || failed(factor) || failed(lanesPerPart))
+    return failure();
+  if (logicalLane < 0 || logicalLane >= *elementCount)
+    return failure();
+
+  int64_t part = logicalLane % *factor;
+  int64_t indexInPart = logicalLane / *factor;
+  return VMIPhysicalLane{part, indexInPart / *lanesPerPart,
+                         indexInPart % *lanesPerPart};
+}
+// Computes (chunk * lanesPerPart + lane) * factor + part; fails on out-of-range coordinates or a lane past the element count.
+FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part,
+                                              int64_t chunk,
+                                              int64_t lane) {
+  FailureOr elementCount = getVMIElementCount(type);
+  FailureOr factor = getLayoutFactor(type);
+  FailureOr lanesPerPart = getPhysicalLanesPerPart(type);
+  if (failed(elementCount) || failed(factor) || failed(lanesPerPart))
+    return failure();
+  if (part < 0 || part >= *factor || chunk < 0 || lane < 0 ||
+      lane >= *lanesPerPart) // chunk has no upper bound here; the final lane check below catches overflow.
+    return failure();
+
+  int64_t indexInPart = chunk * *lanesPerPart + lane;
+  int64_t logicalLane = indexInPart * *factor + part;
+  if (logicalLane >= *elementCount)
+    return failure();
+  return logicalLane;
+}
+// Returns true when the physical position lies at or beyond the logical lanes held by `part`; fails on out-of-range coordinates.
+FailureOr mlir::pto::isPaddingLane(Type type, int64_t part,
+                                   int64_t chunk, int64_t lane) {
+  FailureOr elementCount = getVMIElementCount(type);
+  FailureOr factor = getLayoutFactor(type);
+  FailureOr lanesPerPart = getPhysicalLanesPerPart(type);
+  if (failed(elementCount) || failed(factor) || failed(lanesPerPart))
+    return failure();
+  if (part < 0 || part >= *factor || chunk < 0 || lane < 0 ||
+      lane >= *lanesPerPart)
+    return failure();
+
+  int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part);
+  int64_t indexInPart = chunk * *lanesPerPart + lane;
+  return indexInPart >= lanesInPart;
+}
diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt
index e372c3d711..fef96ec2c8 100644
--- a/lib/PTO/Transforms/CMakeLists.txt
+++ b/lib/PTO/Transforms/CMakeLists.txt
@@ -34,6 +34,9 @@ add_mlir_dialect_library(PTOTransforms
   PTOVPTOPtrBoundary.cpp
   VPTOBufferMaterialization.cpp
   PTOValidateVPTOIR.cpp
+ PTOValidateVMIIR.cpp + VMILayoutAssignment.cpp + VMIToVPTO.cpp PTOInferVPTOVecScope.cpp InsertSync/PTOInsertSync.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp new file mode 100644 index 0000000000..6ce3e8eecd --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -0,0 +1,445 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOValidateVMIIR.cpp - VMI boundary verifier ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVALIDATEVMIIR +#define GEN_PASS_DEF_PTOVALIDATEVMILAYOUTIR +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +/* Predicate used throughout this file to recognise VMI-dialect types. */ bool isVMIType(Type type) { return isa(type); } + +/* Predicate used throughout this file to recognise physical VPTO register types. */ bool isPhysicalVPTOType(Type type) { + return isa(type); +} + +bool
containsVMIOrPhysicalType(Type type) { // True when `type`, or a type nested in a function signature or shaped element type, is VMI/physical.
+  if (isVMIType(type) || isPhysicalVPTOType(type))
+    return true;
+
+  if (auto functionType = dyn_cast(type)) {
+    return llvm::any_of(functionType.getInputs(), [](Type input) {
+             return containsVMIOrPhysicalType(input);
+           }) ||
+           llvm::any_of(functionType.getResults(), [](Type result) {
+             return containsVMIOrPhysicalType(result);
+           });
+  }
+
+  if (auto shapedType = dyn_cast(type))
+    return containsVMIOrPhysicalType(shapedType.getElementType());
+
+  return false;
+}
+// Attribute overload: inspects type attributes, typed attributes, and recurses through array and dictionary attributes.
+bool containsVMIOrPhysicalType(Attribute attr) {
+  if (!attr)
+    return false;
+
+  if (auto typeAttr = dyn_cast(attr))
+    if (containsVMIOrPhysicalType(typeAttr.getValue()))
+      return true;
+
+  if (auto typedAttr = dyn_cast(attr))
+    if (containsVMIOrPhysicalType(typedAttr.getType()))
+      return true;
+
+  if (auto arrayAttr = dyn_cast(attr))
+    return llvm::any_of(arrayAttr, [](Attribute element) {
+      return containsVMIOrPhysicalType(element);
+    });
+
+  if (auto dictAttr = dyn_cast(attr))
+    return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) {
+      return containsVMIOrPhysicalType(namedAttr.getValue());
+    });
+
+  return false;
+}
+// "Surface" form: a vreg without a layout, or a mask that reports isPred() and has no layout.
+bool isSurfaceVMIType(Type type) {
+  if (auto vregType = dyn_cast(type))
+    return !vregType.getLayout();
+  if (auto maskType = dyn_cast(type))
+    return maskType.isPred() && !maskType.getLayout();
+  return false;
+}
+// "Layout-assigned" form: a vreg with a layout attribute, or a mask with a layout attribute and a concrete granularity.
+bool isLayoutAssignedVMIType(Type type) {
+  if (auto vregType = dyn_cast(type))
+    return static_cast(vregType.getLayoutAttr());
+  if (auto maskType = dyn_cast(type))
+    return maskType.getLayoutAttr() &&
+           VMIMaskType::isConcreteGranularity(maskType.getGranularity());
+  return false;
+}
+// Matches the five helper ops by name: ensure_layout, ensure_mask_layout, ensure_mask_granularity, pack, unpack.
+bool isVMIHelperOp(Operation *op) {
+  StringRef name = op->getName().getStringRef();
+  return name == "pto.vmi.ensure_layout" ||
+         name == "pto.vmi.ensure_mask_layout" ||
+         name == "pto.vmi.ensure_mask_granularity" ||
+         name == "pto.vmi.pack" || name == "pto.vmi.unpack";
+}
+// Matches only the three ensure_* helper ops by name (pack and unpack are excluded).
+bool isVMILayoutHelperOp(Operation *op) {
+  StringRef name =
op->getName().getStringRef(); + return name == "pto.vmi.ensure_layout" || + name == "pto.vmi.ensure_mask_layout" || + name == "pto.vmi.ensure_mask_granularity"; +} + +bool isVMISemanticOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name.starts_with("pto.vmi.") && !isVMIHelperOp(op); +} + +bool isStructuralOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "builtin.module" || name.starts_with("func.") || + name.starts_with("scf.") || name.starts_with("cf."); +} + +bool hasVMIOrPhysicalType(Operation *op) { + auto hasInterestingType = [](Type type) { + return isVMIType(type) || isPhysicalVPTOType(type); + }; + if (llvm::any_of(op->getOperandTypes(), hasInterestingType) || + llvm::any_of(op->getResultTypes(), hasInterestingType)) + return true; + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + if (llvm::any_of(block.getArgumentTypes(), hasInterestingType)) + return true; + } + } + return false; +} + +void mirrorDiagnostic(llvm::raw_ostream *diagOS, Twine message) { + if (diagOS) + *diagOS << message << "\n"; +} + +LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagPassInvariantPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagPassInvariantPrefix) + message); + return failure(); +} + +LogicalResult verifyBoundaryType(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (isPhysicalVPTOType(type)) + return emitInvariant( + owner, diagOS, + "physical VPTO register type appears before VMI-to-VPTO"); + + if (isVMIType(type) && !isSurfaceVMIType(type)) + return emitInvariant( + owner, diagOS, + "VMI producer boundary requires surface !pto.vmi.vreg or " + "!pto.vmi.mask type"); + + return success(); +} + +LogicalResult verifyBoundaryTypeTree(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (failed(verifyBoundaryType(owner, type, diagOS))) + 
return failure();
+
+  if (auto functionType = dyn_cast(type)) {
+    for (Type input : functionType.getInputs())
+      if (failed(verifyBoundaryTypeTree(owner, input, diagOS)))
+        return failure();
+    for (Type result : functionType.getResults())
+      if (failed(verifyBoundaryTypeTree(owner, result, diagOS)))
+        return failure();
+  }
+
+  if (auto shapedType = dyn_cast(type))
+    return verifyBoundaryTypeTree(owner, shapedType.getElementType(), diagOS);
+
+  return success();
+}
+// Post-layout rule for one type: physical VPTO types are rejected, and a VMI type must be in layout-assigned form.
+LogicalResult verifyLayoutAssignedType(Operation *owner, Type type,
+                                       llvm::raw_ostream *diagOS) {
+  if (isPhysicalVPTOType(type))
+    return emitInvariant(
+        owner, diagOS,
+        "physical VPTO register type appears before VMI-to-VPTO");
+
+  if (isVMIType(type) && !isLayoutAssignedVMIType(type))
+    return emitInvariant(
+        owner, diagOS,
+        "layout-assigned VMI IR requires !pto.vmi.vreg with layout and "
+        "!pto.vmi.mask with b8/b16/b32 granularity plus layout");
+
+  return success();
+}
+// Applies verifyLayoutAssignedType to `type`, to function input/result types, and to a shaped type's element type.
+LogicalResult verifyLayoutAssignedTypeTree(Operation *owner, Type type,
+                                           llvm::raw_ostream *diagOS) {
+  if (failed(verifyLayoutAssignedType(owner, type, diagOS)))
+    return failure();
+
+  if (auto functionType = dyn_cast(type)) {
+    for (Type input : functionType.getInputs())
+      if (failed(verifyLayoutAssignedTypeTree(owner, input, diagOS)))
+        return failure();
+    for (Type result : functionType.getResults())
+      if (failed(verifyLayoutAssignedTypeTree(owner, result, diagOS)))
+        return failure();
+  }
+
+  if (auto shapedType = dyn_cast(type))
+    return verifyLayoutAssignedTypeTree(owner, shapedType.getElementType(),
+                                        diagOS);
+
+  return success();
+}
+// Runs the callback `verifyType` on every type reachable from `attr`; a null attribute passes.
+template
+LogicalResult verifyAttributeTypes(Operation *owner, Attribute attr,
+                                   llvm::raw_ostream *diagOS,
+                                   TypeVerifier verifyType) {
+  if (!attr)
+    return success();
+
+  if (auto typeAttr = dyn_cast(attr))
+    if (failed(verifyType(owner, typeAttr.getValue(), diagOS)))
+      return failure();
+
+  if (auto typedAttr = dyn_cast(attr))
+    if (failed(verifyType(owner, typedAttr.getType(),
diagOS))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyAttributeTypes(owner, element, diagOS, verifyType))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyAttributeTypes(owner, namedAttr.getValue(), diagOS, + verifyType))) + return failure(); + } + + return success(); +} + +bool isFunctionTypeAttr(Operation *op, NamedAttribute attr) { + return isa(op) && attr.getName() == "function_type"; +} + +LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, + NamedAttribute attr, + llvm::raw_ostream *diagOS) { + if (isFunctionTypeAttr(op, attr)) + return success(); + if (containsVMIOrPhysicalType(attr.getValue())) + return emitInvariant( + op, diagOS, + "VMI or physical VPTO type appears in a non-signature attribute"); + return success(); +} + +LogicalResult verifyOperationTypes(Operation *op, llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyBoundaryTypeTree))) + 
return failure(); + } + return success(); +} + +LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, + llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyLayoutAssignedTypeTree))) + return failure(); + } + return success(); +} + +LogicalResult verifyOperationBoundary(Operation *op, + llvm::raw_ostream *diagOS) { + if (failed(verifyOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) + return emitInvariant( + op, diagOS, + "VMI helper op appears before layout assignment or VMI-to-VPTO"); + + if (isVMISemanticOp(op) || isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +LogicalResult verifyLayoutAssignedOperation(Operation *op, + llvm::raw_ostream *diagOS) { + if (failed(verifyLayoutAssignedOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) { 
+ if (isVMILayoutHelperOp(op)) + return success(); + return emitInvariant( + op, diagOS, + "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); + } + + if (isVMISemanticOp(op) || isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +struct PTOValidateVMIIRPass + : public mlir::pto::impl::PTOValidateVMIIRBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMIIRPass) + + void runOnOperation() override { + if (failed(validateVMIProducerBoundaryIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +struct PTOValidateVMILayoutIRPass + : public mlir::pto::impl::PTOValidateVMILayoutIRBase< + PTOValidateVMILayoutIRPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMILayoutIRPass) + + void runOnOperation() override { + if (failed(validateVMILayoutAssignedIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult mlir::pto::validateVMIProducerBoundaryIR( + ModuleOp module, llvm::raw_ostream *diagOS) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyOperationBoundary(op, diagOS))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult mlir::pto::validateVMILayoutAssignedIR( + ModuleOp module, llvm::raw_ostream *diagOS) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyLayoutAssignedOperation(op, diagOS))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +std::unique_ptr mlir::pto::createPTOValidateVMIIRPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createPTOValidateVMILayoutIRPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp new file mode 100644 index 0000000000..e4d201d45c --- 
/dev/null +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -0,0 +1,1330 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMILayoutAssignment.cpp - Assign VMI layouts ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTASSIGNMENT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent = 0; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent = 0; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + 
OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; + +static unsigned getElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + +static StringRef getMaskGranularityForElement(Type elementType) { + switch (getElementBitWidth(elementType)) { + case 8: + return "b8"; + case 16: + return "b16"; + case 32: + return "b32"; + default: + return ""; + } +} + +static bool isLane0SplatShuffle(VMIShuffleOp op) { + auto sourceType = cast(op.getSource().getType()); + ArrayRef indices = op.getIndices(); + return sourceType.getElementCount() == 1 && !indices.empty() && + llvm::all_of(indices, [](int64_t index) { return index == 0; }); +} + +bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), [](Type input) { + return containsVMIType(input); + }) || + llvm::any_of(functionType.getResults(), [](Type result) { + return containsVMIType(result); + }); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +struct LayoutSolver { + explicit LayoutSolver(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities) + : module(module), ctx(module.getContext()), capabilities(capabilities) {} + + unsigned addDataValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = dataIds.try_emplace(value, dataNodes.size()); + if (inserted) + dataNodes.push_back( + DataNode{value, type, it->second, type.getLayoutAttr()}); + return it->second; + } + + unsigned addMaskValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = maskIds.try_emplace(value, maskNodes.size()); + if (inserted) { + std::string granularity; + if (VMIMaskType::isConcreteGranularity(type.getGranularity())) + granularity = type.getGranularity().str(); + maskNodes.push_back( 
+ MaskNode{value, type, it->second, type.getLayoutAttr(), granularity}); + } + return it->second; + } + + unsigned find(unsigned id) { + if (dataNodes[id].parent == id) + return id; + dataNodes[id].parent = find(dataNodes[id].parent); + return dataNodes[id].parent; + } + + unsigned findMask(unsigned id) { + if (maskNodes[id].parent == id) + return id; + maskNodes[id].parent = findMask(maskNodes[id].parent); + return maskNodes[id].parent; + } + + LogicalResult unite(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addDataValue(lhs); + unsigned rhsId = addDataValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = find(lhsId); + unsigned rhsRoot = find(rhsId); + if (lhsRoot == rhsRoot) + return success(); + dataNodes[rhsRoot].parent = lhsRoot; + VMILayoutAttr lhsNatural = dataNodes[lhsRoot].naturalLayout; + VMILayoutAttr rhsNatural = dataNodes[rhsRoot].naturalLayout; + if (lhsNatural && rhsNatural && lhsNatural != rhsNatural) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << lhsNatural << " and " << rhsNatural; + if (!lhsNatural) + dataNodes[lhsRoot].naturalLayout = rhsNatural; + return success(); + } + + LogicalResult uniteMask(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addMaskValue(lhs); + unsigned rhsId = addMaskValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = findMask(lhsId); + unsigned rhsRoot = findMask(rhsId); + if (lhsRoot == rhsRoot) + return success(); + + MaskNode &lhsNode = maskNodes[lhsRoot]; + MaskNode &rhsNode = maskNodes[rhsRoot]; + if (lhsNode.requestedLayout && rhsNode.requestedLayout && + lhsNode.requestedLayout != rhsNode.requestedLayout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << lhsNode.requestedLayout << " and " << rhsNode.requestedLayout; + if (!lhsNode.requestedGranularity.empty() && + !rhsNode.requestedGranularity.empty() && + 
lhsNode.requestedGranularity != rhsNode.requestedGranularity) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << lhsNode.requestedGranularity << " and " + << rhsNode.requestedGranularity; + + rhsNode.parent = lhsRoot; + if (!lhsNode.requestedLayout) + lhsNode.requestedLayout = rhsNode.requestedLayout; + if (lhsNode.requestedGranularity.empty()) + lhsNode.requestedGranularity = rhsNode.requestedGranularity; + return success(); + } + + LogicalResult setNaturalLayout(Value value, VMILayoutAttr layout, + Operation *op) { + unsigned id = addDataValue(value); + if (id == ~0u || !layout) + return success(); + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << existing << " and " << layout; + dataNodes[root].naturalLayout = layout; + return success(); + } + + VMILayoutAttr getContiguousLayout() { + return VMILayoutAttr::getContiguous(ctx); + } + + VMILayoutAttr getDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + unsigned root = find(id); + if (dataNodes[root].naturalLayout) + return dataNodes[root].naturalLayout; + return getContiguousLayout(); + } + + LogicalResult requestMask(Value mask, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + unsigned id = addMaskValue(mask); + if (id == ~0u) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask layout or granularity"; + MaskNode &node = maskNodes[findMask(id)]; + if (node.requestedLayout && node.requestedLayout != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << node.requestedLayout << " and " << layout; + if (!node.requestedGranularity.empty() && + node.requestedGranularity != granularity) + return 
op->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << node.requestedGranularity << " and " << granularity; + node.requestedLayout = layout; + node.requestedGranularity = granularity.str(); + return success(); + } + + void requestDataUse(OpOperand &operand, VMILayoutAttr layout) { + if (isa(operand.get().getType())) + dataUseRequests.push_back(DataUseRequest{&operand, layout}); + } + + bool canAdoptConsumerRequestedLayout(Value value) { + if (!value.hasOneUse()) + return false; + Operation *definingOp = value.getDefiningOp(); + return definingOp && isa(definingOp); + } + + LogicalResult applyConsumerDrivenDataLayouts() { + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + if (!canAdoptConsumerRequestedLayout(value)) + continue; + unsigned id = addDataValue(value); + if (id == ~0u) + continue; + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != request.layout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting natural layouts " + << existing << " and " << request.layout; + dataNodes[root].naturalLayout = request.layout; + } + return success(); + } + + LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + if (!isa(operand.get().getType())) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask use layout or granularity"; + maskUseRequests.push_back( + MaskUseRequest{&operand, layout, granularity.str()}); + return success(); + } + + LogicalResult collect() { + module.walk([&](Operation *op) { + for (Value result : op->getResults()) { + addDataValue(result); + addMaskValue(result); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (BlockArgument arg : block.getArguments()) { + addDataValue(arg); 
+ addMaskValue(arg); + } + }); + return success(); + } + + LogicalResult addConstraints() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto maskAnd = dyn_cast(op)) { + if (failed(uniteMask(maskAnd.getLhs(), maskAnd.getRhs(), op)) || + failed(uniteMask(maskAnd.getLhs(), maskAnd.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskOr = dyn_cast(op)) { + if (failed(uniteMask(maskOr.getLhs(), maskOr.getRhs(), op)) || + failed(uniteMask(maskOr.getLhs(), maskOr.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskXor = dyn_cast(op)) { + if (failed(uniteMask(maskXor.getLhs(), maskXor.getRhs(), op)) || + failed(uniteMask(maskXor.getLhs(), maskXor.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskNot = dyn_cast(op)) { + if (failed(uniteMask(maskNot.getSource(), maskNot.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addf = dyn_cast(op)) { + if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || + failed(unite(addf.getLhs(), addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addi = dyn_cast(op)) { + if (failed(unite(addi.getLhs(), addi.getRhs(), op)) || + failed(unite(addi.getLhs(), addi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subf = dyn_cast(op)) { + if (failed(unite(subf.getLhs(), subf.getRhs(), op)) || + failed(unite(subf.getLhs(), subf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subi = dyn_cast(op)) { + if (failed(unite(subi.getLhs(), subi.getRhs(), op)) || + failed(unite(subi.getLhs(), subi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto mulf = dyn_cast(op)) { + if (failed(unite(mulf.getLhs(), mulf.getRhs(), op)) || 
+ failed(unite(mulf.getLhs(), mulf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto muli = dyn_cast(op)) { + if (failed(unite(muli.getLhs(), muli.getRhs(), op)) || + failed(unite(muli.getLhs(), muli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto fma = dyn_cast(op)) { + if (failed(unite(fma.getLhs(), fma.getRhs(), op)) || + failed(unite(fma.getLhs(), fma.getAcc(), op)) || + failed(unite(fma.getLhs(), fma.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto divf = dyn_cast(op)) { + if (failed(unite(divf.getLhs(), divf.getRhs(), op)) || + failed(unite(divf.getLhs(), divf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto minf = dyn_cast(op)) { + if (failed(unite(minf.getLhs(), minf.getRhs(), op)) || + failed(unite(minf.getLhs(), minf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maxf = dyn_cast(op)) { + if (failed(unite(maxf.getLhs(), maxf.getRhs(), op)) || + failed(unite(maxf.getLhs(), maxf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto negf = dyn_cast(op)) { + if (failed(unite(negf.getSource(), negf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absf = dyn_cast(op)) { + if (failed(unite(absf.getSource(), absf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absi = dyn_cast(op)) { + if (failed(unite(absi.getSource(), absi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto sqrt = dyn_cast(op)) { + if (failed(unite(sqrt.getSource(), sqrt.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto exp = dyn_cast(op)) { + if (failed(unite(exp.getSource(), exp.getResult(), op))) + 
return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ln = dyn_cast(op)) { + if (failed(unite(ln.getSource(), ln.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto relu = dyn_cast(op)) { + if (failed(unite(relu.getSource(), relu.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto andi = dyn_cast(op)) { + if (failed(unite(andi.getLhs(), andi.getRhs(), op)) || + failed(unite(andi.getLhs(), andi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ori = dyn_cast(op)) { + if (failed(unite(ori.getLhs(), ori.getRhs(), op)) || + failed(unite(ori.getLhs(), ori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto xori = dyn_cast(op)) { + if (failed(unite(xori.getLhs(), xori.getRhs(), op)) || + failed(unite(xori.getLhs(), xori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shli = dyn_cast(op)) { + if (failed(unite(shli.getLhs(), shli.getRhs(), op)) || + failed(unite(shli.getLhs(), shli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shrui = dyn_cast(op)) { + if (failed(unite(shrui.getLhs(), shrui.getRhs(), op)) || + failed(unite(shrui.getLhs(), shrui.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto notOp = dyn_cast(op)) { + if (failed(unite(notOp.getSource(), notOp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpf = dyn_cast(op)) { + if (failed(unite(cmpf.getLhs(), cmpf.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + if (failed(unite(cmpi.getLhs(), cmpi.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + if 
(failed(unite(select.getTrueValue(), select.getFalseValue(), op)) || + failed(unite(select.getTrueValue(), select.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + if (failed(setNaturalLayout(activePrefix.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + requestDataUse(compress.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(compress.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto extf = dyn_cast(op)) { + auto sourceType 
= cast(extf.getSource().getType()); + auto resultType = cast(extf.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto truncf = dyn_cast(op)) { + auto sourceType = cast(truncf.getSource().getType()); + auto resultType = cast(truncf.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 32 && resultBits == 16) + requestDataUse(truncf.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 2)); + else if (sourceBits == 32 && resultBits == 8) + requestDataUse(truncf.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 4)); + if (failed(setNaturalLayout(truncf.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto bitcast = dyn_cast(op)) { + if (failed(unite(bitcast.getSource(), bitcast.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto gather = dyn_cast(op)) { + auto resultType = 
cast(gather.getResult().getType()); + requestDataUse(gather.getIndicesMutable(), getContiguousLayout()); + requestDataUse(gather.getPassthruMutable(), getContiguousLayout()); + if (failed(requestMaskUse(gather.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(gather.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse(store.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto scatter = dyn_cast(op)) { + auto valueType = cast(scatter.getValue().getType()); + requestDataUse(scatter.getValueMutable(), getContiguousLayout()); + requestDataUse(scatter.getIndicesMutable(), getContiguousLayout()); + if (failed(requestMaskUse(scatter.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse(store.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + 
valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto tileWrite = dyn_cast(op)) { + requestDataUse(tileWrite.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + VMICapabilityResult capability = + capabilities.supportsChannelCount("pto.vmi.channel_split", + channels); + if (!capability.isSupported()) { + split.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + requestDataUse( + split.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, channels)); + for (Value result : split.getResults()) + if (failed(setNaturalLayout(result, getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto merge = dyn_cast(op)) { + int64_t channels = merge.getInputs().size(); + VMICapabilityResult capability = + capabilities.supportsChannelCount("pto.vmi.channel_merge", + channels); + if (!capability.isSupported()) { + merge.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + for (OpOperand &input : merge.getInputsMutable()) + requestDataUse(input, getContiguousLayout()); + if (failed(setNaturalLayout( + merge.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, channels), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shuffle = dyn_cast(op)) { + auto sourceType = cast(shuffle.getSource().getType()); + auto resultType = cast(shuffle.getResult().getType()); + if (sourceType.hasLayout() || resultType.hasLayout()) + return WalkResult::advance(); + + requestDataUse(shuffle.getSourceMutable(), getContiguousLayout()); + if (isLane0SplatShuffle(shuffle)) + return WalkResult::advance(); + if (failed(setNaturalLayout(shuffle.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); 
+ } + if (auto ifOp = dyn_cast(op)) { + if (failed(addIfConstraints(ifOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto executeRegionOp = dyn_cast(op)) { + if (failed(addExecuteRegionConstraints(executeRegionOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto indexSwitchOp = dyn_cast(op)) { + if (failed(addIndexSwitchConstraints(indexSwitchOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto whileOp = dyn_cast(op)) { + if (failed(addWhileConstraints(whileOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto forOp = dyn_cast(op)) { + if (failed(addForConstraints(forOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto branchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(branchOp.getDest(), + branchOp.getDestOperands(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto condBranchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(condBranchOp.getTrueDest(), + condBranchOp.getTrueDestOperands(), + op)) || + failed(addBranchConstraints(condBranchOp.getFalseDest(), + condBranchOp.getFalseDestOperands(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto switchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(switchOp.getDefaultDestination(), + switchOp.getDefaultOperands(), op))) + return WalkResult::interrupt(); + for (auto [dest, operands] : + llvm::zip(switchOp.getCaseDestinations(), + switchOp.getCaseOperands())) { + if (failed(addBranchConstraints(dest, operands, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto returnOp = dyn_cast(op)) { + if (failed(addReturnConstraints(returnOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto callOp = dyn_cast(op)) { + if (failed(addCallConstraints(callOp))) + return WalkResult::interrupt(); + 
return WalkResult::advance(); + } + if (op->getName().getStringRef() == "func.call_indirect") { + if (hasVMIValueTypes(op)) { + op->emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto funcOp = dyn_cast(op)) { + if (funcOp.empty() && hasVMIFunctionType(funcOp)) { + funcOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed function declaration requires an explicit " + "external ABI materialization plan"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult uniteEquivalentValues(Value lhs, Value rhs, Operation *op) { + if (failed(unite(lhs, rhs, op))) + return failure(); + return uniteMask(lhs, rhs, op); + } + + LogicalResult addIfConstraints(scf::IfOp ifOp) { + for (OpResult result : ifOp->getResults()) { + unsigned resultNo = result.getResultNumber(); + for (Region *region : {&ifOp.getThenRegion(), &ifOp.getElseRegion()}) { + if (region->empty()) + continue; + auto yieldOp = + dyn_cast(region->front().getTerminator()); + if (!yieldOp || resultNo >= yieldOp.getNumOperands()) + continue; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(resultNo), + ifOp))) + return failure(); + } + } + return success(); + } + + LogicalResult addYieldConstraints(ResultRange results, scf::YieldOp yieldOp, + Operation *op) { + for (auto [index, result] : llvm::enumerate(results)) { + if (index >= yieldOp.getNumOperands()) + break; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addExecuteRegionConstraints(scf::ExecuteRegionOp executeOp) { + WalkResult result = executeOp.getRegion().walk([&](scf::YieldOp yieldOp) { + if (yieldOp->getParentOp() != executeOp.getOperation()) + return WalkResult::advance(); + 
if (failed(addYieldConstraints(executeOp->getResults(), yieldOp, + executeOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult addIndexSwitchConstraints(scf::IndexSwitchOp indexSwitchOp) { + auto addBlockTerminator = [&](Block &block) -> LogicalResult { + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp) + return success(); + return addYieldConstraints(indexSwitchOp->getResults(), yieldOp, + indexSwitchOp); + }; + + if (failed(addBlockTerminator(indexSwitchOp.getDefaultBlock()))) + return failure(); + for (unsigned idx = 0, e = indexSwitchOp.getNumCases(); idx < e; ++idx) + if (failed(addBlockTerminator(indexSwitchOp.getCaseBlock(idx)))) + return failure(); + return success(); + } + + LogicalResult addWhileConstraints(scf::WhileOp whileOp) { + auto inits = whileOp.getInits(); + auto beforeArgs = whileOp.getBeforeArguments(); + Block &afterBlock = whileOp.getAfter().front(); + auto conditionOp = + dyn_cast(whileOp.getBefore().front().getTerminator()); + auto yieldOp = dyn_cast(afterBlock.getTerminator()); + + for (auto [index, init] : llvm::enumerate(inits)) { + Value anchor = init; + if (index < beforeArgs.size() && + failed(uniteEquivalentValues(anchor, beforeArgs[index], whileOp))) + return failure(); + if (conditionOp && index < conditionOp.getArgs().size() && + failed(uniteEquivalentValues(anchor, conditionOp.getArgs()[index], + whileOp))) + return failure(); + if (index < afterBlock.getNumArguments() && + failed(uniteEquivalentValues(anchor, afterBlock.getArgument(index), + whileOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), + whileOp))) + return failure(); + if (index < whileOp.getNumResults() && + failed(uniteEquivalentValues(anchor, whileOp.getResult(index), + whileOp))) + return failure(); + } + return success(); + } + + LogicalResult 
addForConstraints(scf::ForOp forOp) { + auto initArgs = forOp.getInitArgs(); + auto regionIterArgs = forOp.getRegionIterArgs(); + auto results = forOp.getResults(); + scf::YieldOp yieldOp = nullptr; + if (Block *body = forOp.getBody()) + yieldOp = dyn_cast(body->getTerminator()); + + for (auto [index, initArg] : llvm::enumerate(initArgs)) { + Value anchor = initArg; + if (index < regionIterArgs.size() && + failed(uniteEquivalentValues(anchor, regionIterArgs[index], forOp))) + return failure(); + if (index < results.size() && + failed(uniteEquivalentValues(anchor, results[index], forOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), + forOp))) + return failure(); + } + return success(); + } + + LogicalResult addBranchConstraints(Block *dest, OperandRange operands, + Operation *op) { + if (!dest) + return success(); + for (auto [index, operand] : llvm::enumerate(operands)) { + if (index >= dest->getNumArguments()) + break; + if (failed(uniteEquivalentValues(operand, dest->getArgument(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addReturnConstraints(func::ReturnOp returnOp) { + auto func = returnOp->getParentOfType(); + if (!func) + return success(); + + auto it = firstReturnOperandsByFunc.find(func); + if (it == firstReturnOperandsByFunc.end()) { + SmallVector operands(returnOp.getOperands()); + firstReturnOperandsByFunc.try_emplace(func, std::move(operands)); + return success(); + } + + ArrayRef firstOperands = it->second; + for (auto [index, operand] : llvm::enumerate(returnOp.getOperands())) { + if (index >= firstOperands.size()) + break; + if (failed(uniteEquivalentValues(firstOperands[index], operand, returnOp))) + return failure(); + } + return success(); + } + + bool hasVMIValueTypes(Operation *op) { + return llvm::any_of(op->getOperandTypes(), containsVMIType) || + llvm::any_of(op->getResultTypes(), containsVMIType); + } + + bool 
hasVMIFunctionType(func::FuncOp func) { + FunctionType type = func.getFunctionType(); + return llvm::any_of(type.getInputs(), containsVMIType) || + llvm::any_of(type.getResults(), containsVMIType); + } + + LogicalResult addCallConstraints(func::CallOp callOp) { + if (!hasVMIValueTypes(callOp)) + return success(); + + auto callee = SymbolTable::lookupNearestSymbolFrom( + callOp, callOp.getCalleeAttr()); + if (!callee || callee.empty()) + return callOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + + for (auto [operand, argument] : + llvm::zip(callOp.getOperands(), callee.getArguments())) { + if (failed(uniteEquivalentValues(operand, argument, callOp))) + return failure(); + } + + SmallVector returns; + callee.walk([&](func::ReturnOp returnOp) { returns.push_back(returnOp); }); + for (func::ReturnOp returnOp : returns) { + for (auto [index, result] : llvm::enumerate(callOp.getResults())) { + if (index >= returnOp.getNumOperands()) + break; + if (failed(uniteEquivalentValues(result, returnOp.getOperand(index), + callOp))) + return failure(); + } + } + return success(); + } + + void rewriteDataTypes() { + for (DataNode &node : dataNodes) { + VMILayoutAttr layout = getDataLayout(node.value); + node.value.setType(VMIVRegType::get(ctx, node.type.getElementCount(), + node.type.getElementType(), layout)); + } + } + + std::optional rematerializeDataUse(Value value, VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + if (auto constant = value.getDefiningOp()) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && denseAttr.isSplat()) + return builder.create(loc, resultType, + constant.getValue()) + .getResult(); + } + if (auto broadcast = value.getDefiningOp()) + return builder + .create(loc, resultType, broadcast.getValue()) + .getResult(); + if (auto iota = value.getDefiningOp()) + return builder.create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) + .getResult(); + 
return std::nullopt; + } + + LogicalResult insertDataUseMaterializations() { + OpBuilder builder(ctx); + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "data use materialization requires layout-assigned source " + "type"; + if (sourceLayout == request.layout) + continue; + + auto resultType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), request.layout); + builder.setInsertionPoint(request.operand->getOwner()); + std::optional rematerialized = + rematerializeDataUse(value, resultType, + request.operand->getOwner()->getLoc(), builder); + if (rematerialized) { + request.operand->set(*rematerialized); + continue; + } + auto ensure = builder.create( + request.operand->getOwner()->getLoc(), resultType, value); + request.operand->set(ensure.getResult()); + } + return success(); + } + + LogicalResult inferMaskRequests() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto cmpf = dyn_cast(op)) { + auto lhsType = cast(cmpf.getLhs().getType()); + if (failed(requestMask(cmpf.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement( + lhsType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + auto lhsType = cast(cmpi.getLhs().getType()); + if (failed(requestMask(cmpi.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement( + lhsType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + auto resultType = cast(select.getResult().getType()); + if (failed(requestMaskUse(select.getMaskMutable(), + resultType.getLayoutAttr(), + 
getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + auto resultType = + cast(activePrefix.getResult().getType()); + if (failed(requestMaskUse(activePrefix.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + auto resultType = cast(compress.getResult().getType()); + if (failed(requestMaskUse(compress.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + 
return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse(load.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse(load.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + void rewriteMaskTypes() { + for (MaskNode &node : maskNodes) { + MaskNode &root = maskNodes[findMask(maskIds.lookup(node.value))]; + VMILayoutAttr layout = root.requestedLayout ? root.requestedLayout + : getContiguousLayout(); + StringRef granularity = root.requestedGranularity.empty() + ? 
StringRef("b32") + : StringRef(root.requestedGranularity); + node.value.setType(VMIMaskType::get(ctx, node.type.getElementCount(), + granularity, layout)); + } + } + + std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, + Location loc, + OpBuilder &builder) { + if (auto createMask = value.getDefiningOp()) + return builder.create(loc, resultType, + createMask.getActiveLanes()) + .getResult(); + if (auto constantMask = value.getDefiningOp()) + return builder + .create(loc, resultType, + constantMask.getValueAttr()) + .getResult(); + return std::nullopt; + } + + LogicalResult insertMaskUseMaterializations() { + OpBuilder builder(ctx); + for (MaskUseRequest request : maskUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "mask use materialization requires layout-assigned source " + "type"; + + builder.setInsertionPoint(request.operand->getOwner()); + Value current = value; + VMIMaskType currentType = sourceType; + auto requestedType = VMIMaskType::get(ctx, sourceType.getElementCount(), + request.granularity, + request.layout); + if (sourceType != requestedType) { + std::optional rematerialized = rematerializeMaskUse( + value, requestedType, request.operand->getOwner()->getLoc(), + builder); + if (rematerialized) { + request.operand->set(*rematerialized); + continue; + } + } + + if (sourceLayout != request.layout) { + auto layoutType = VMIMaskType::get(ctx, currentType.getElementCount(), + currentType.getGranularity(), + request.layout); + auto ensureLayout = builder.create( + request.operand->getOwner()->getLoc(), layoutType, current); + current = ensureLayout.getResult(); + currentType = layoutType; + } + + if (currentType.getGranularity() != request.granularity) { + auto granularityType = + 
VMIMaskType::get(ctx, currentType.getElementCount(), + request.granularity, request.layout); + auto ensureGranularity = + builder.create( + request.operand->getOwner()->getLoc(), granularityType, + current); + current = ensureGranularity.getResult(); + } + + if (current != value) + request.operand->set(current); + } + return success(); + } + + void rewriteFunctionType() { + module.walk([&](func::FuncOp func) { + if (func.empty()) + return; + + SmallVector inputs; + inputs.reserve(func.getNumArguments()); + for (BlockArgument arg : func.getArguments()) + inputs.push_back(arg.getType()); + + SmallVector results; + auto it = firstReturnOperandsByFunc.find(func); + if (it != firstReturnOperandsByFunc.end()) { + for (Value operand : it->second) + results.push_back(operand.getType()); + } else { + for (Type type : func.getFunctionType().getResults()) { + if (auto vregType = dyn_cast(type)) { + results.push_back(VMIVRegType::get(ctx, vregType.getElementCount(), + vregType.getElementType(), + getContiguousLayout())); + } else if (auto maskType = dyn_cast(type)) { + results.push_back(VMIMaskType::get(ctx, maskType.getElementCount(), + "b32", getContiguousLayout())); + } else { + results.push_back(type); + } + } + } + + func.setFunctionType(FunctionType::get(ctx, inputs, results)); + }); + } + + LogicalResult run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + if (failed(applyConsumerDrivenDataLayouts())) + return failure(); + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module); + } + + ModuleOp module; + MLIRContext *ctx; + const VMITargetCapabilityRegistry &capabilities; + DenseMap dataIds; + DenseMap maskIds; + DenseMap> firstReturnOperandsByFunc; + SmallVector dataNodes; + 
SmallVector maskNodes; + SmallVector dataUseRequests; + SmallVector maskUseRequests; +}; + +struct VMILayoutAssignmentPass + : public mlir::pto::impl::VMILayoutAssignmentBase< + VMILayoutAssignmentPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutAssignmentPass) + + void runOnOperation() override { + VMITargetCapabilityRegistry capabilities; + if (failed(LayoutSolver(getOperation(), capabilities).run())) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutAssignmentPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp new file mode 100644 index 0000000000..db19c2846b --- /dev/null +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -0,0 +1,6269 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- VMIToVPTO.cpp - Convert VMI to physical VPTO IR -------------------===// +//===----------------------------------------------------------------------===// + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMITOVPTO +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +bool isVMIType(Type type) { return isa(type); } + +bool containsVMIType(Type type) { + if (isVMIType(type)) + return true; + + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); + + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + + return false; +} + +bool hasVMIType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return 
containsVMIType(type); }); +} + +bool hasVMIType(FunctionType type) { + return hasVMIType(type.getInputs()) || hasVMIType(type.getResults()); +} + +bool hasVMIType(Attribute attr) { + if (!attr) + return false; + + if (auto typeAttr = dyn_cast(attr)) + if (containsVMIType(typeAttr.getValue())) + return true; + + if (auto typedAttr = dyn_cast(attr)) + if (containsVMIType(typedAttr.getType())) + return true; + + if (auto arrayAttr = dyn_cast(attr)) + return llvm::any_of(arrayAttr, [](Attribute element) { + return hasVMIType(element); + }); + + if (auto dictAttr = dyn_cast(attr)) + return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { + return hasVMIType(namedAttr.getValue()); + }); + + return false; +} + +bool hasVMIType(Operation *op) { + if (auto func = dyn_cast(op)) + if (hasVMIType(func.getFunctionType())) + return true; + if (hasVMIType(op->getOperandTypes()) || hasVMIType(op->getResultTypes())) + return true; + for (Region ®ion : op->getRegions()) + for (Block &block : region) + if (hasVMIType(block.getArgumentTypes())) + return true; + for (NamedAttribute attr : op->getAttrs()) + if (hasVMIType(attr.getValue())) + return true; + return false; +} + +bool isVMIOp(Operation *op) { + return op->getName().getStringRef().starts_with("pto.vmi."); +} + +bool isLayoutAssignedVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return static_cast(vregType.getLayoutAttr()); + if (auto maskType = dyn_cast(type)) + return maskType.getLayoutAttr() && + VMIMaskType::isConcreteGranularity(maskType.getGranularity()); + return true; +} + +LogicalResult verifyLayoutAssignedVMITypeTree(Operation *op, Type type) { + if (!isLayoutAssignedVMIType(type)) + return op->emitError() + << kVMIDiagPassInvariantPrefix + << "vmi-to-vpto requires layout-assigned VMI types"; + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, input))) + return failure(); + for (Type result : 
functionType.getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, result))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyLayoutAssignedVMITypeTree(op, shapedType.getElementType()); + + return success(); +} + +LogicalResult verifyVMIToVPTOInputAttribute(Operation *op, Attribute attr) { + if (!attr) + return success(); + + if (auto typeAttr = dyn_cast(attr)) + if (failed(verifyLayoutAssignedVMITypeTree(op, typeAttr.getValue()))) + return failure(); + + if (auto typedAttr = dyn_cast(attr)) + if (failed(verifyLayoutAssignedVMITypeTree(op, typedAttr.getType()))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, element))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, namedAttr.getValue()))) + return failure(); + } + + return success(); +} + +LogicalResult verifyVMIToVPTOInputTypes(Operation *op) { + for (Type type : op->getOperandTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + if (auto func = dyn_cast(op)) { + for (Type type : func.getFunctionType().getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : func.getFunctionType().getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Type type : block.getArgumentTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (NamedAttribute attr : op->getAttrs()) + if (failed(verifyVMIToVPTOInputAttribute(op, attr.getValue()))) + return failure(); + return success(); +} + +LogicalResult verifyVMIToVPTOInputIR(ModuleOp 
module) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyVMIToVPTOInputTypes(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +static std::optional materializeVPTOToVMI(OpBuilder &builder, + Type resultType, + ValueRange inputs, + Location loc) { + if (!isVMIType(resultType)) + return std::nullopt; + return builder.create(loc, resultType, inputs).getResult(); +} + +static std::optional> +materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + if (!isVMIType(input.getType())) + return std::nullopt; + auto unpackOp = builder.create(loc, resultTypes, input); + return SmallVector(unpackOp->getResults()); +} + +class VMIToVPTOTypeConverter final : public OneToNTypeConverter { +public: + VMIToVPTOTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](VMIVRegType type, SmallVectorImpl &results) + -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(arity) || failed(lanesPerPart)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(VRegType::get(type.getContext(), *lanesPerPart, + type.getElementType())); + return success(); + }); + addConversion([](VMIMaskType type, SmallVectorImpl &results) + -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(MaskType::get(type.getContext(), + type.getGranularity())); + return success(); + }); + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +FailureOr createAllTrueMaskForVReg(Location loc, VRegType vregType, + PatternRewriter &rewriter) { + MLIRContext *ctx = 
rewriter.getContext(); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return rewriter + .create(loc, MaskType::get(ctx, "b8"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 16) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 32) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + return failure(); +} + +FailureOr getMaskTypeForVReg(VRegType vregType, + MLIRContext *ctx) { + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return MaskType::get(ctx, "b8"); + if (elementBits == 16) + return MaskType::get(ctx, "b16"); + if (elementBits == 32) + return MaskType::get(ctx, "b32"); + return failure(); +} + +FailureOr createAllTrueMask(Location loc, MaskType maskType, + PatternRewriter &rewriter) { + StringAttr pattern = rewriter.getStringAttr("PAT_ALL"); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), pattern) + .getResult(); + if (maskType.isB16()) + return rewriter.create(loc, MaskType::get(ctx, "b16"), pattern) + .getResult(); + if (maskType.isB32()) + return rewriter.create(loc, MaskType::get(ctx, "b32"), pattern) + .getResult(); + return failure(); +} + +FailureOr createPatternMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + 
return failure(); +} + +FailureOr createPrefixMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter.create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter.create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + return failure(); +} + +FailureOr> +createRuntimePrefixMask(Location loc, MaskType maskType, Value activeLanes, + PatternRewriter &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + Type scalarType = activeLanes.getType(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b8"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b16"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b32"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + return failure(); +} + +LogicalResult checkSupportedMaskableVReg( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(type.getElementType(), + VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = 
getVMIPhysicalArity(type); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + return success(); +} + +LogicalResult checkSupportedTargetElementVReg( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + VMIElementPurpose purpose, StringRef elementContract, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(type.getElementType(), purpose); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +Value createI32Constant(Location loc, int64_t value, PatternRewriter &rewriter) { + return rewriter.create(loc, value, 32); +} + +Value clampDynamicActiveLanes(Location loc, Value activeLanes, + int64_t maxActiveLanes, + PatternRewriter &rewriter) { + Value activeI32 = + rewriter.create(loc, rewriter.getI32Type(), + activeLanes); + Value zeroI32 = createI32Constant(loc, 0, rewriter); + Value nonNegative = + rewriter.create(loc, activeI32, zeroI32); + Value maxI32 = createI32Constant(loc, maxActiveLanes, rewriter); + return rewriter.create(loc, nonNegative, maxI32); +} + +Value createPartitionActiveLanes(Location loc, Value activeLanesI32, + int64_t factor, int64_t part, + PatternRewriter &rewriter) { + if (factor == 1) + return activeLanesI32; + int64_t bias = factor - 1 - part; + Value biased = activeLanesI32; + if (bias != 0) + biased = + rewriter.create(loc, biased, + createI32Constant(loc, bias, rewriter)); + return rewriter.create( + loc, biased, createI32Constant(loc, factor, rewriter)); +} + +std::optional getPrefixPattern(int64_t activeLanes, + int64_t lanesPerPart) { + if (activeLanes <= 0) + return std::string("PAT_ALLF"); + if (activeLanes >= 
lanesPerPart) + return std::string("PAT_ALL"); + switch (activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return std::string("PAT_VL") + std::to_string(activeLanes); + default: + return std::nullopt; + } +} + +FailureOr getSingleValue(Operation *op, ValueRange values, + StringRef description, + PatternRewriter &rewriter) { + if (values.size() != 1) { + (void)rewriter.notifyMatchFailure(op, description); + return failure(); + } + return values.front(); +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + return (lhs + rhs - 1) / rhs; +} + +FailureOr getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isContiguous() ? 1 : layout.getFactor(); +} + +FailureOr getDataChunksInPart(VMIVRegType type, int64_t part) { + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(factor) || failed(lanesPerPart) || part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = + (type.getElementCount() + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +FailureOr getDataFlatPartIndex(VMIVRegType type, int64_t part, + int64_t chunk) { + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor) || part < 0 || part >= *factor || chunk < 0) + return failure(); + + int64_t flatIndex = 0; + for (int64_t currentPart = 0; currentPart < part; ++currentPart) { + FailureOr chunks = getDataChunksInPart(type, currentPart); + if (failed(chunks)) + return failure(); + flatIndex += *chunks; + } + + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks) || chunk >= *chunks) + return failure(); + return flatIndex + chunk; +} + +FailureOr checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if 
(reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("requires assigned layout"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return *lanesPerPart; +} + +FailureOr getVMITypeLayoutFactor(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr.isContiguous() ? 
1 : layoutAttr.getFactor(); +} + +FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +LogicalResult checkFullVMIPhysicalChunks(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + 
+FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason); + +LogicalResult checkSupportedLayoutMaterialization( + const VMITargetCapabilityRegistry &capabilities, Type sourceType, + Type resultType, VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMICapabilityResult layoutCapability = + capabilities.supportsLayoutConversion(sourceLayout, resultLayout, + Type{}); + if (!layoutCapability.isSupported()) + return fail(layoutCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if (succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + + "; source materialization " + 
sourceMaterializationReason); + return fail(Twine("result ") + resultReason + + "; result materialization " + resultMaterializationReason); +} + +FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + Attribute layoutAttr; + if (auto vregType = dyn_cast(type)) + layoutAttr = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layoutAttr = maskType.getLayout(); + else + return fail("requires VMI data or mask type"); + + auto layout = dyn_cast_or_null(layoutAttr); + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous()) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +LogicalResult checkCanMaterializeToContiguous(Type type, std::string *reason) { + return succeeded(getContiguousMaterializationPartCount(type, reason)) + ? 
success() + : failure(); +} + +std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + } + return std::nullopt; +} + +FailureOr getStaticMemRefElementCount(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType || !memrefType.hasStaticShape()) + return failure(); + + int64_t elements = 1; + for (int64_t dim : memrefType.getShape()) + elements *= dim; + return elements; +} + +enum class VMIMemoryValidMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryWriteMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryPermutationKind { + Identity, +}; + +enum class VMIMemoryFallbackDecisionKind { + NotRequired, + RequiredUnavailable, +}; + +struct VMIMemoryLogicalShape { + int64_t elementCount = 0; +}; + +struct VMIMemoryLaneAddressMap { + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + int64_t baseElementOffset = 0; + int64_t elementStride = 1; + int64_t physicalLaneFootprint = 0; + + int64_t getExclusiveEndElement() const { + return baseElementOffset + physicalLaneFootprint * elementStride; + } +}; + +struct VMIMemoryFallbackDecision { + VMIMemoryFallbackDecisionKind kind = + VMIMemoryFallbackDecisionKind::NotRequired; + std::string reason = "not required"; + + static VMIMemoryFallbackDecision notRequired() { return {}; } + + static VMIMemoryFallbackDecision requiredUnavailable(const Twine &reason) { + VMIMemoryFallbackDecision decision; + decision.kind = VMIMemoryFallbackDecisionKind::RequiredUnavailable; + decision.reason = reason.str(); + return decision; + } +}; + +struct VMIMemorySafeReadProof { + bool proven = false; + std::string reason; + std::optional constantOffset; + std::optional staticElementCount; + std::optional laneAddressMap; + int64_t physicalFootprint = 0; +}; + +struct VMIMemoryAccessPlan { 
+ Type baseType; + VMIVRegType valueType; + std::optional constantOffset; + VMIMemoryLogicalShape logicalShape; + VMIMemoryValidMaskKind validMask = VMIMemoryValidMaskKind::AllTrue; + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + std::optional laneAddressMap; + Attribute paddingValue; + VMIMemoryWriteMaskKind writeMask = VMIMemoryWriteMaskKind::AllTrue; + VMIMemorySafeReadProof safeReadProof; + VMICapabilityResult targetCapability; + VMICapabilityResult trueMaskedLoadCapability; + VMICapabilityResult scratchFallbackCapability; + VMICapabilityResult guardedFallbackCapability; + VMIMemoryFallbackDecision fallbackDecision; +}; + +FailureOr +buildContiguousIdentityLaneAddressMap(int64_t constantOffset, + VMIVRegType resultType, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + FailureOr arity = getVMIPhysicalArity(resultType); + if (failed(lanesPerPart) || failed(arity)) + return fail("requires computable physical read footprint"); + + VMIMemoryLaneAddressMap map; + map.baseElementOffset = constantOffset; + map.physicalLaneFootprint = *arity * *lanesPerPart; + return map; +} + +VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, + StringRef role, + Value memoryValue = {}) { + auto memrefType = dyn_cast(memoryType); + if (!memrefType || memrefType.getLayout().isIdentity()) + return VMICapabilityResult::supported(); + std::string reason = + (Twine(role) + + " memref layout is non-identity; current VMI memory access plan " + "supports only contiguous identity lane-to-address maps") + .str(); + if (memoryValue && memoryValue.getDefiningOp()) + reason += "; memref.subview requires normalized base/offset/stride " + "lane-to-address planning"; + return VMICapabilityResult::missingCapability(reason); +} + +VMIMemorySafeReadProof 
+computeSafeFullReadProof(Type sourceType, std::optional constantOffset, + VMIVRegType resultType) { + VMIMemorySafeReadProof proof; + proof.constantOffset = constantOffset; + + auto fail = [&](const Twine &message) { + proof.proven = false; + proof.reason = message.str(); + return proof; + }; + + if (!constantOffset) + return fail("requires constant index offset"); + + FailureOr elements = getStaticMemRefElementCount(sourceType); + if (failed(elements)) + return fail("requires statically shaped memref source"); + proof.staticElementCount = *elements; + + if (*constantOffset < 0) + return fail("requires non-negative offset"); + + std::string addressMapReason; + FailureOr addressMap = + buildContiguousIdentityLaneAddressMap(*constantOffset, resultType, + &addressMapReason); + if (failed(addressMap)) + return fail(addressMapReason); + proof.laneAddressMap = *addressMap; + + proof.physicalFootprint = addressMap->physicalLaneFootprint; + if (addressMap->getExclusiveEndElement() > *elements) + return fail(Twine("full physical read footprint [") + + Twine(addressMap->baseElementOffset) + ", " + + Twine(addressMap->getExclusiveEndElement()) + + ") exceeds static memref element count " + Twine(*elements)); + + proof.proven = true; + return proof; +} + +VMIMemoryAccessPlan +buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value source, Type sourceType, VMIVRegType resultType, + std::optional constantOffset, + VMIMemoryValidMaskKind validMask) { + VMIMemoryAccessPlan plan; + plan.baseType = sourceType; + plan.valueType = resultType; + plan.constantOffset = constantOffset; + plan.logicalShape.elementCount = resultType.getElementCount(); + plan.validMask = validMask; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; + plan.safeReadProof = + computeSafeFullReadProof(sourceType, constantOffset, resultType); + plan.laneAddressMap = plan.safeReadProof.laneAddressMap; + plan.targetCapability = 
capabilities.supportsDirectMemory(sourceType, + "source"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = + requireIdentityMemRefLayout(sourceType, "source", source); + if (validMask == VMIMemoryValidMaskKind::ExplicitMask) + plan.trueMaskedLoadCapability = + capabilities.supportsTrueMaskedLoad(sourceType, resultType, Type{}); + plan.scratchFallbackCapability = + capabilities.supportsFallbackResource(VMIFallbackResourceKind::ScratchMemory); + plan.guardedFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::GuardedControlFlow); + return plan; +} + +VMIMemoryAccessPlan +buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value destination, Type destinationType, + VMIVRegType valueType, + VMIMemoryWriteMaskKind writeMask) { + VMIMemoryAccessPlan plan; + plan.baseType = destinationType; + plan.valueType = valueType; + plan.logicalShape.elementCount = valueType.getElementCount(); + plan.validMask = VMIMemoryValidMaskKind::AllTrue; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = writeMask; + plan.targetCapability = + capabilities.supportsDirectMemory(destinationType, "destination"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = + requireIdentityMemRefLayout(destinationType, "destination", + destination); + return plan; +} + +void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { + std::string maskedLoadReason; + if (plan.validMask == VMIMemoryValidMaskKind::ExplicitMask && + !plan.trueMaskedLoadCapability.isSupported()) + maskedLoadReason = + (Twine("; ") + plan.trueMaskedLoadCapability.reason).str(); + std::string scratchReason; + if (!plan.scratchFallbackCapability.isSupported()) + scratchReason = (Twine("; ") + plan.scratchFallbackCapability.reason).str(); + std::string guardedReason; + if (!plan.guardedFallbackCapability.isSupported()) + guardedReason = (Twine("; ") + plan.guardedFallbackCapability.reason).str(); + 
plan.fallbackDecision = VMIMemoryFallbackDecision::requiredUnavailable( + Twine("partial/tail read needs a scratch, guarded, or true " + "masked/non-faulting load fallback, but no such fallback resource " + "plan is implemented") + + maskedLoadReason + scratchReason + guardedReason); +} + +FailureOr +verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, + Type sourceType, Value offset, + PatternRewriter &rewriter) { + std::string fullChunkReason; + FailureOr lanesPerPart = + checkFullDataPhysicalChunks(type, &fullChunkReason); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + + VMIMemorySafeReadProof safeReadProof = + computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); + if (safeReadProof.proven) { + lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + } + + (void)rewriter.notifyMatchFailure( + op, Twine("memory lowering ") + fullChunkReason + + "; safe full-read proof failed: " + safeReadProof.reason); + return failure(); +} + +LogicalResult checkSupportedLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + Value source, Type sourceType, std::optional constantOffset, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMIMemoryAccessPlan accessPlan = + buildReadAccessPlan(capabilities, source, sourceType, type, + constantOffset, VMIMemoryValidMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return success(); + + if (accessPlan.safeReadProof.proven) + return success(); + requireUnavailableReadFallback(accessPlan); + return fail(Twine(fullChunkReason) + + "; safe-read proof failed: " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + 
accessPlan.fallbackDecision.reason); +} + +LogicalResult checkSupportedStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + Value destination, Type destinationType, std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, type, + VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("requires assigned layout"); + if (failed(getDataLanesPerPart(type.getElementType()))) + return fail("requires known physical lanes per part"); + if (layout.isContiguous()) + return success(); + + std::string materializationReason; + if (succeeded(checkCanMaterializeToContiguous(type, &materializationReason))) + return success(); + return fail(Twine("partial/tail store requires contiguous layout or " + "deinterleaved layout that can materialize to contiguous; " + "value ") + + fullChunkReason + ", materialization " + + materializationReason); +} + +LogicalResult +checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIMaskedLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = 
passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous result, passthru, and mask layouts"); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (accessPlan.safeReadProof.proven) + return success(); + requireUnavailableReadFallback(accessPlan); + return fail(Twine("partial/tail masked_load requires statically safe " + "full-read footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason); +} + +LogicalResult checkSupportedGatherShape( + const VMITargetCapabilityRegistry &capabilities, VMIGatherOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!resultLayout || !indicesLayout || !passthruLayout || !maskLayout) + return fail("requires 
assigned result, indices, passthru, and mask " + "layouts"); + if (!resultLayout.isContiguous() || !indicesLayout.isContiguous() || + !passthruLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous result, indices, passthru, and mask " + "layouts"); + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) + return fail(sourceCapability.reason); + + if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("currently requires 32-bit result element type so physical " + "offset and result lane counts match pto.vgather2_bc"); + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return fail("requires signless or unsigned 32-bit indices"); + if (maskType.getGranularity() != "b32") + return fail("requires b32 mask granularity"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr passthruArity = getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(indicesArity) || failed(passthruArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*resultArity != *indicesArity || *resultArity != *passthruArity || + *resultArity != *maskArity) + return fail("requires result, indices, passthru, and mask to have the " + "same physical arity"); + + std::string resultReason; + std::string indicesReason; + std::string passthruReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) + return fail(Twine("result requires full physical chunks; ") + + resultReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return 
fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + + return success(); +} + +LogicalResult checkSupportedScatterShape( + const VMITargetCapabilityRegistry &capabilities, VMIScatterOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("indices_unique")) + return fail("requires indices_unique proof because pto.vscatter does not " + "define logical-lane-order duplicate-index semantics"); + + auto valueType = cast(op.getValue().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !indicesLayout || !maskLayout) + return fail("requires assigned value, indices, and mask layouts"); + if (!valueLayout.isContiguous() || !indicesLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous value, indices, and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory( + op.getDestination().getType(), "destination", "pto.vscatter", + "pto.vscatter writes only UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + if (pto::getPTOStorageElemBitWidth(valueType.getElementType()) != 32) + return fail("currently requires 32-bit value element type so physical " + "index and value lane counts match pto.vscatter"); + auto indexElementType = dyn_cast(indicesType.getElementType()); + 
if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return fail("requires signless or unsigned 32-bit indices"); + if (maskType.getGranularity() != "b32") + return fail("requires b32 mask granularity"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(indicesArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if (*valueArity != *indicesArity || *valueArity != *maskArity) + return fail("requires value, indices, and mask to have the same physical " + "arity"); + + std::string valueReason; + std::string indicesReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(valueType, &valueReason))) + return fail(Twine("value requires full physical chunks; ") + valueReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + + return success(); +} + +Value stripMaskMaterialization(Value value) { + while (true) { + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + return value; + } +} + +bool isStaticAllActiveMask(Value mask, int64_t expectedLanes, + std::string *reason = nullptr) { + mask = stripMaskMaterialization(mask); + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return false; + }; + + if (auto createMask = mask.getDefiningOp()) { + auto activeConstant = + createMask.getActiveLanes().getDefiningOp(); + if (!activeConstant) + return fail("create_mask active_lanes is dynamic"); + auto activeAttr = 
dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("create_mask active_lanes is not an integer constant"); + return activeAttr.getInt() >= expectedLanes + ? true + : fail("create_mask active_lanes is smaller than the logical " + "lane count"); + } + + if (auto constantMask = mask.getDefiningOp()) { + auto denseAttr = dyn_cast(constantMask.getValue()); + if (!denseAttr) + return fail("constant_mask is not a dense integer mask"); + if (denseAttr.getNumElements() != expectedLanes) + return fail("constant_mask element count does not match the logical " + "lane count"); + auto values = denseAttr.getValues(); + for (bool value : values) + if (!value) + return fail("constant_mask contains an inactive lane"); + return true; + } + + return fail("mask is not a static all-active create_mask or constant_mask"); +} + +LogicalResult +checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIExpandLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires 
contiguous result, passthru, and mask layouts"); + + std::string maskReason; + bool staticAllActive = + isStaticAllActiveMask(op.getMask(), resultType.getElementCount(), + &maskReason); + + std::string fullChunkReason; + if (staticAllActive && + succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (staticAllActive && accessPlan.safeReadProof.proven) + return success(); + + std::string allActivePathReason; + if (!staticAllActive) { + allActivePathReason = maskReason.empty() ? "requires static all-active mask" + : maskReason; + } else { + requireUnavailableReadFallback(accessPlan); + allActivePathReason = + (Twine("requires full physical chunks or statically safe full-read " + "footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + "; fallback decision: " + + accessPlan.fallbackDecision.reason) + .str(); + } + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) { + if (!isa(op.getSource().getType())) + return fail(Twine("runtime-mask path ") + sourceCapability.reason + + "; all-active path " + allActivePathReason); + return fail(Twine("runtime-mask path ") + sourceCapability.reason); + } + if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("runtime-mask path currently requires 32-bit result element " + "type so prefix indices and gather result lane counts match"); + if (maskType.getGranularity() != "b32") + return fail("runtime-mask path requires b32 mask granularity"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr passthruArity = getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(passthruArity) || failed(maskArity)) + return fail("runtime-mask path requires computable physical arity"); + 
if (*resultArity != 1 || *passthruArity != 1 || *maskArity != 1) + return fail("runtime-mask path currently supports only one physical " + "chunk because prefix indices must not reset across chunks"); + + std::string passthruReason; + std::string maskFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("runtime-mask result requires full physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("runtime-mask passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("runtime-mask mask requires full physical chunks; ") + + maskFullReason); + + return success(); +} + +LogicalResult checkSupportedMaskedStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType valueType, + VMIMaskType maskType, Value destination, Type destinationType, + std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, valueType, + VMIMemoryWriteMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + std::string valueReason; + std::string maskReason; + if (succeeded(checkFullDataPhysicalChunks(valueType, &valueReason)) && + succeeded(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || 
failed(maskArity) || *valueArity != *maskArity) + return fail("requires matching value/mask physical arity"); + + std::string valueMaterializationReason; + FailureOr valueParts = getContiguousMaterializationPartCount( + valueType, &valueMaterializationReason); + if (failed(valueParts)) + return fail(Twine("value cannot materialize to contiguous; value ") + + valueReason + ", materialization " + + valueMaterializationReason); + + std::string maskMaterializationReason; + FailureOr maskParts = getContiguousMaterializationPartCount( + maskType, &maskMaterializationReason); + if (failed(maskParts)) + return fail(Twine("mask cannot materialize to contiguous; mask ") + + maskReason + ", materialization " + + maskMaterializationReason); + if (*valueParts != *maskParts) + return fail("requires value/mask contiguous materialization arity to match"); + return success(); +} + +FailureOr getContiguousActiveDataLanes(VMIVRegType vmiType, + int64_t chunk) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + int64_t remaining = vmiType.getElementCount() - chunk * *lanesPerPart; + return std::clamp(remaining, 0, *lanesPerPart); +} + +FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, + int64_t chunk, VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = + getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return createAllTrueMaskForVReg(loc, vregType, rewriter); + + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return failure(); + FailureOr> maskAndRemaining = + createRuntimePrefixMask(loc, *maskType, + createI32Constant(loc, *activeLanes, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return 
failure(); + return maskAndRemaining->first; +} + +FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, + int64_t chunk, Value userMask, + VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = + getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return userMask; + + auto maskType = dyn_cast(userMask.getType()); + if (!maskType) + return failure(); + FailureOr tailMask = + createContiguousStoreMask(loc, vmiType, chunk, vregType, rewriter); + FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(tailMask) || failed(allTrue)) + return failure(); + return rewriter.create(loc, maskType, userMask, *tailMask, *allTrue) + .getResult(); +} + +FailureOr> +computeShuffleForwardingSourceParts(VMIShuffleOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector sourceFlatIndices; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + 
std::optional sourcePart; + std::optional sourceChunk; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding)) + return fail("failed to classify result padding lanes"); + if (*padding) + continue; + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, + lane); + if (failed(resultLogicalLane) || + *resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + if (sourcePhysical->lane != lane) + return fail("requires same-lane physical chunks"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + continue; + } + if (*sourcePart != sourcePhysical->part || + *sourceChunk != sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + } + + if (!sourcePart || !sourceChunk) + return fail("requires at least one logical lane per result chunk"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + sourceFlatIndices.push_back(*sourceFlatIndex); + } + } + + return sourceFlatIndices; +} + +struct ShuffleVselrPlan { + int64_t sourceFlatIndex = 0; + int64_t baseLane = 0; + bool descending = false; +}; + +FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + if (!llvm::all_of(indices, [](int64_t index) { return index == 0; })) + return fail("requires every result lane to 
select source lane 0"); + + auto sourceType = cast(op.getSource().getType()); + FailureOr sourceLane = + mapLogicalLaneToPhysical(sourceType, 0); + if (failed(sourceLane)) + return fail("failed to map source lane 0"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, sourceLane->part, sourceLane->chunk); + if (failed(sourceFlatIndex)) + return fail("source lane 0 part range is out of bounds"); + return *sourceFlatIndex; +} + +FailureOr> +computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector plans; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + std::optional sourcePart; + std::optional sourceChunk; + std::optional baseLane; + std::optional descending; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding) || *padding) + return fail("requires full physical result chunks"); + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, + lane); + if (failed(resultLogicalLane) || + 
*resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + baseLane = sourcePhysical->lane; + continue; + } + + if (*sourcePart != sourcePhysical->part || + *sourceChunk != sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + + int64_t ascExpected = *baseLane + lane; + int64_t descExpected = *baseLane - lane; + bool asc = sourcePhysical->lane == ascExpected; + bool desc = sourcePhysical->lane == descExpected; + if (!asc && !desc) + return fail("requires ASC or DESC affine source lane indices"); + + bool laneDescending = desc && !asc; + if (!descending) { + descending = laneDescending; + continue; + } + if (*descending != laneDescending) + return fail("requires one index order per result chunk"); + } + + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + plans.push_back(ShuffleVselrPlan{*sourceFlatIndex, *baseLane, + descending.value_or(false)}); + } + } + + return plans; +} + +struct ConstantMaskChunkMaterialization { + SmallVector activeLanes; +}; + +FailureOr> +computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr) + return fail("only dense integer mask constants are supported"); + + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return 
fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + auto boolValues = denseAttr.getValues(); + int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + materialization.activeLanes.push_back(boolValues[*logicalLane] ? 
1 : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + +std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { + bool seenInactive = false; + int64_t activeCount = 0; + for (int8_t active : activeLanes) { + if (active) { + if (seenInactive) + return std::nullopt; + ++activeCount; + continue; + } + seenInactive = true; + } + return activeCount; +} + +FailureOr materializePrefixMask(Location loc, MaskType maskType, + int64_t activeLanes, + int64_t lanesPerPart, + PatternRewriter &rewriter) { + std::optional pattern = + getPrefixPattern(activeLanes, lanesPerPart); + if (pattern) + return createPatternMask(loc, maskType, *pattern, rewriter); + + FailureOr> maskAndRemaining = + createRuntimePrefixMask(loc, maskType, + createI32Constant(loc, activeLanes, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return failure(); + return maskAndRemaining->first; +} + +FailureOr +materializeConstantMaskChunk(Location loc, MaskType maskType, + ArrayRef activeLanes, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || + static_cast(activeLanes.size()) != *lanesPerPart) + return failure(); + + if (std::optional prefixCount = + getPrefixActiveLaneCount(activeLanes)) + return materializePrefixMask(loc, maskType, *prefixCount, *lanesPerPart, + rewriter); + + FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(allTrue)) + return failure(); + + Value result; + int64_t lane = 0; + while (lane < *lanesPerPart) { + while (lane < *lanesPerPart && !activeLanes[lane]) + ++lane; + if (lane >= *lanesPerPart) + break; + + int64_t runBegin = lane; + while (lane < *lanesPerPart && activeLanes[lane]) + ++lane; + int64_t runEnd = lane; + + FailureOr prefixEnd = + materializePrefixMask(loc, maskType, runEnd, *lanesPerPart, rewriter); + if (failed(prefixEnd)) + return failure(); + + Value runMask = 
*prefixEnd; + if (runBegin != 0) { + FailureOr prefixBegin = materializePrefixMask( + loc, maskType, runBegin, *lanesPerPart, rewriter); + if (failed(prefixBegin)) + return failure(); + Value notPrefixBegin = + rewriter.create(loc, maskType, *prefixBegin, *allTrue) + .getResult(); + runMask = + rewriter.create(loc, maskType, *prefixEnd, notPrefixBegin, + *allTrue) + .getResult(); + } + + if (!result) { + result = runMask; + continue; + } + result = rewriter.create(loc, maskType, result, runMask, *allTrue) + .getResult(); + } + + if (result) + return result; + return materializePrefixMask(loc, maskType, 0, *lanesPerPart, rewriter); +} + +Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return baseOffset; + Value delta = rewriter.create(loc, laneOffset); + return rewriter.create(loc, baseOffset, delta).getResult(); +} + +std::optional getX2MemoryDistToken(Type elementType, + StringRef prefix) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return std::nullopt; + return (Twine(prefix) + "_B" + Twine(elementBits)).str(); +} + +std::optional getVPTOCmpMode(StringRef predicate) { + if (predicate == "eq" || predicate == "ne" || predicate == "lt" || + predicate == "le" || predicate == "gt" || predicate == "ge") + return predicate; + if (predicate == "oeq") + return StringRef("eq"); + if (predicate == "one") + return StringRef("ne"); + if (predicate == "olt") + return StringRef("lt"); + if (predicate == "ole") + return StringRef("le"); + if (predicate == "ogt") + return StringRef("gt"); + if (predicate == "oge") + return StringRef("ge"); + if (predicate == "slt") + return StringRef("lt"); + if (predicate == "sle") + return StringRef("le"); + if (predicate == "sgt") + return StringRef("gt"); + if (predicate == "sge") + return StringRef("ge"); + return std::nullopt; +} + +LogicalResult 
checkSupportedComparePredicate(Operation *op, + StringRef predicate) { + if (getVPTOCmpMode(predicate)) + return success(); + return op->emitError() + << kVMIDiagUnsupportedPrefix << "compare predicate " << predicate + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms oeq/one/olt/ole/ogt/oge, " + "and signed integer forms slt/sle/sgt/sge"; +} + +struct OneToNVMIUnpackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIUnpackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + if (sourceParts.size() != op->getNumResults()) + return rewriter.notifyMatchFailure( + op, "converted source part count must match unpack results"); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIPackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIPackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr arity = getVMIPhysicalArity(op.getResult().getType()); + if (failed(arity) || + static_cast(adaptor.getFlatOperands().size()) != *arity) + return rewriter.notifyMatchFailure( + op, "pack part count must match converted VMI result arity"); + rewriter.replaceOp(op, adaptor.getFlatOperands(), + adaptor.getResultMapping()); + return success(); + } +}; + +LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "source and result physical arity mismatch"); + for (auto [part, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + if (part.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "helper 
requires non-identity physical materialization"); + } + return success(); +} + +FailureOr> +materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "layout materialization requires assigned source/result layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + resultLayout.isContiguous(); + bool contiguousToDeint2 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, + "deinterleaved=2 layout materialization requires 2*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i]); + results.append({materialize.getLow(), materialize.getHigh()}); + } + } else { + SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1]); + 
part0.push_back(materialize.getLow()); + part1.push_back(materialize.getHigh()); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + resultLayout.isContiguous(); + bool contiguousToDeint4 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, + "deinterleaved=4 layout materialization requires 4*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + auto even = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2); + auto odd = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3); + auto low = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], even.getLow(), + odd.getLow()); + auto high = + rewriter.create(op->getLoc(), resultTypes[4 * i + 2], + resultTypes[4 * i + 3], even.getHigh(), + odd.getHigh()); + results.append( + {low.getLow(), low.getHigh(), high.getLow(), high.getHigh()}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto low = + rewriter.create(op->getLoc(), resultTypes[i], + resultTypes[groups + 
i], + sourceParts[4 * i], + sourceParts[4 * i + 1]); + auto high = rewriter.create( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3]); + auto even = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], + low.getLow(), high.getLow()); + auto odd = rewriter.create( + op->getLoc(), resultTypes[groups + i], + resultTypes[3 * groups + i], low.getHigh(), high.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI data layout materialization"); + return failure(); +} + +FailureOr> +createPredicateDintlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> +createPredicateIntlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); 
+ return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> +materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "mask layout materialization requires assigned source/result " + "layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + resultLayout.isContiguous(); + bool contiguousToDeint2 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=2 mask layout materialization requires 2*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = + createPredicateIntlv(op->getLoc(), resultTypes[2 * i], + resultTypes[2 * i + 1], sourceParts[i], + sourceParts[groups + i], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({materialize->first, materialize->second}); + } 
+ } else { + SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[groups + i], sourceParts[2 * i], + sourceParts[2 * i + 1], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(materialize->first); + part1.push_back(materialize->second); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + resultLayout.isContiguous(); + bool contiguousToDeint4 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=4 mask layout materialization requires 4*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + FailureOr> even = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2, rewriter); + FailureOr> odd = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3, rewriter); + if (failed(even) || failed(odd)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + FailureOr> low = + createPredicateIntlv(op->getLoc(), 
resultTypes[4 * i], + resultTypes[4 * i + 1], even->first, + odd->first, rewriter); + FailureOr> high = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i + 2], + resultTypes[4 * i + 3], even->second, + odd->second, rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({low->first, low->second, high->first, high->second}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> low = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1], + rewriter); + FailureOr> high = createPredicateDintlv( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3], rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + FailureOr> even = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[2 * groups + i], low->first, + high->first, rewriter); + FailureOr> odd = + createPredicateDintlv(op->getLoc(), resultTypes[groups + i], + resultTypes[3 * groups + i], low->second, + high->second, rewriter); + if (failed(even) || failed(odd)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(even->first); + part1.push_back(odd->first); + part2.push_back(even->second); + part3.push_back(odd->second); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI mask layout materialization"); + return failure(); +} + +int getMaskGranularityRank(StringRef granularity) { + if (granularity == "b8") 
+ return 0; + if (granularity == "b16") + return 1; + if (granularity == "b32") + return 2; + return -1; +} + +StringRef getMaskGranularityForRank(int rank) { + switch (rank) { + case 0: + return "b8"; + case 1: + return "b16"; + case 2: + return "b32"; + default: + return ""; + } +} + +LogicalResult checkSupportedMaskGranularityMaterialization( + const VMITargetCapabilityRegistry &capabilities, VMIMaskType sourceType, + VMIMaskType resultType, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source and result mask lane counts to match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("requires source and result mask layouts to match"); + + VMICapabilityResult granularityCapability = + capabilities.supportsMaskGranularityConversion( + sourceType.getGranularity(), resultType.getGranularity()); + if (!granularityCapability.isSupported()) + return fail(granularityCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity < 1 || *resultArity < 1) + return fail("requires non-empty source/result physical arity"); + + return success(); +} + +FailureOr> materializeAdjacentMaskGranularityConversion( + Operation *op, VMIMaskType sourceType, VMIMaskType resultType, + ValueRange sourceParts, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + int sourceRank = getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + if (std::abs(sourceRank - resultRank) != 1) + 
return fail("mask granularity conversion must be adjacent"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr factor = getVMITypeLayoutFactor(sourceType); + if (failed(sourceArity) || failed(factor) || + static_cast(sourceParts.size()) != *sourceArity) + return fail("source mask part count does not match source VMI type"); + + MLIRContext *ctx = op->getContext(); + auto partAttr = [&](StringRef part) { return StringAttr::get(ctx, part); }; + auto resultMaskType = MaskType::get(ctx, resultType.getGranularity()); + SmallVector results; + + int64_t sourceOffset = 0; + for (int64_t part = 0; part < *factor; ++part) { + FailureOr sourceChunks = getVMITypeChunksInPart(sourceType, part); + FailureOr resultChunks = getVMITypeChunksInPart(resultType, part); + if (failed(sourceChunks) || failed(resultChunks)) + return fail("requires computable source/result chunks per layout part"); + + if (resultRank > sourceRank) { + int64_t produced = 0; + for (int64_t chunk = 0; chunk < *sourceChunks && produced < *resultChunks; + ++chunk) { + Value source = sourceParts[sourceOffset + chunk]; + results.push_back( + rewriter + .create(op->getLoc(), resultMaskType, source, + partAttr("LOWER")) + .getResult()); + ++produced; + if (produced >= *resultChunks) + break; + results.push_back( + rewriter + .create(op->getLoc(), resultMaskType, source, + partAttr("HIGHER")) + .getResult()); + ++produced; + } + if (produced != *resultChunks) + return fail("widening mask granularity conversion produced the wrong " + "number of result chunks"); + } else { + Value allTrue; + int64_t consumed = 0; + for (int64_t chunk = 0; chunk < *resultChunks; ++chunk) { + if (consumed >= *sourceChunks) + return fail("narrowing mask granularity conversion ran out of " + "source chunks"); + Value lowerSource = sourceParts[sourceOffset + consumed++]; + Value packed = + rewriter + .create(op->getLoc(), resultMaskType, lowerSource, + partAttr("LOWER")) + .getResult(); + if (consumed < 
*sourceChunks) { + Value higherSource = sourceParts[sourceOffset + consumed++]; + Value higher = + rewriter + .create(op->getLoc(), resultMaskType, higherSource, + partAttr("HIGHER")) + .getResult(); + if (!allTrue) { + FailureOr mask = + createAllTrueMask(op->getLoc(), resultMaskType, rewriter); + if (failed(mask)) + return fail("failed to create all-true mask for ppack merge"); + allTrue = *mask; + } + packed = rewriter + .create(op->getLoc(), resultMaskType, packed, + higher, allTrue) + .getResult(); + } + results.push_back(packed); + } + if (consumed != *sourceChunks) + return fail("narrowing mask granularity conversion left unused source " + "chunks"); + } + + sourceOffset += *sourceChunks; + } + + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(resultArity) || + static_cast(results.size()) != *resultArity) + return fail("mask granularity conversion result count mismatch"); + return results; +} + +FailureOr> materializeMaskGranularityConversion( + Operation *op, const VMITargetCapabilityRegistry &capabilities, + VMIMaskType sourceType, VMIMaskType resultType, ValueRange sourceParts, + PatternRewriter &rewriter) { + std::string reason; + if (failed(checkSupportedMaskGranularityMaterialization(capabilities, + sourceType, + resultType, &reason))) { + (void)rewriter.notifyMatchFailure(op, reason); + return failure(); + } + + int currentRank = getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + VMIMaskType currentType = sourceType; + SmallVector currentParts(sourceParts.begin(), sourceParts.end()); + + while (currentRank != resultRank) { + currentRank += currentRank < resultRank ? 
1 : -1; + StringRef nextGranularity = getMaskGranularityForRank(currentRank); + if (nextGranularity.empty()) { + (void)rewriter.notifyMatchFailure(op, + "invalid target mask granularity rank"); + return failure(); + } + VMIMaskType nextType = + VMIMaskType::get(op->getContext(), currentType.getElementCount(), + nextGranularity, currentType.getLayoutAttr()); + FailureOr> nextParts = + materializeAdjacentMaskGranularityConversion( + op, currentType, nextType, currentParts, rewriter); + if (failed(nextParts)) + return failure(); + currentType = nextType; + currentParts = std::move(*nextParts); + } + + return currentParts; +} + +struct OneToNVMIEnsureLayoutOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + op, "ensure_layout requires assigned source/result layouts"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeDataLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskLayoutOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIEnsureMaskLayoutOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureMaskLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto 
resultType = cast(op.getResult().getType()); + if (sourceType.getGranularity() != resultType.getGranularity()) + return rewriter.notifyMatchFailure( + op, "mask layout helper cannot also change granularity"); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeMaskLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskGranularityOpPattern + : OneToNOpConversionPattern { + OneToNVMIEnsureMaskGranularityOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, + context), + capabilities(capabilities) {} + + LogicalResult + matchAndRewrite(VMIEnsureMaskGranularityOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getLayout() != resultType.getLayout()) + return rewriter.notifyMatchFailure( + op, "mask granularity helper cannot also change layout"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceType.getGranularity() != resultType.getGranularity()) { + FailureOr> results = + materializeMaskGranularityConversion(op, capabilities, sourceType, + resultType, sourceParts, + rewriter); + if (failed(results)) + return failure(); + if (results->size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "mask granularity result arity mismatch"); + for (auto [result, type] : llvm::zip_equal(*results, resultTypes)) + 
if (result.getType() != type) + return rewriter.notifyMatchFailure( + op, "mask granularity result type mismatch"); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } + +private: + const VMITargetCapabilityRegistry &capabilities; +}; + +struct OneToNVMIBroadcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange inputParts = adaptor.getValue(); + if (inputParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "broadcast input must convert to one value"); + bool inputIsVReg = isa(op.getValue().getType()); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "broadcast result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for broadcast mask"); + StringAttr position = + inputIsVReg ? 
rewriter.getStringAttr("LOWEST") : StringAttr{}; + results.push_back( + rewriter + .create(op.getLoc(), resultType, inputParts.front(), + *mask, position) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter) { + if (auto intType = dyn_cast(type)) { + return rewriter + .create(loc, IntegerAttr::get(intType, value)) + .getResult(); + } + if (auto floatType = dyn_cast(type)) { + return rewriter + .create( + loc, FloatAttr::get(floatType, + llvm::APFloat(static_cast(value)))) + .getResult(); + } + return failure(); +} + +FailureOr createIotaChunkBase(Location loc, Value base, + int64_t laneOffset, + StringRef order, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return base; + + FailureOr offset = + createScalarOffsetConstant(loc, base.getType(), laneOffset, rewriter); + if (failed(offset)) + return failure(); + + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + + return failure(); +} + +FailureOr createIotaContiguousChunk(Location loc, Type resultType, + Value base, int64_t laneOffset, + StringAttr orderAttr, + PatternRewriter &rewriter) { + StringRef order = orderAttr ? 
orderAttr.getValue() : StringRef("ASC"); + FailureOr chunkBase = + createIotaChunkBase(loc, base, laneOffset, order, rewriter); + if (failed(chunkBase)) + return failure(); + return rewriter.create(loc, resultType, *chunkBase, orderAttr) + .getResult(); +} + +FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, + Value base, int64_t factor, + int64_t part, int64_t chunk, + int64_t lanesPerPart, + StringAttr orderAttr, + PatternRewriter &rewriter) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return failure(); + + FailureOr mask = createAllTrueMaskForVReg(loc, vregType, rewriter); + FailureOr zero = createScalarOffsetConstant(loc, base.getType(), 0, + rewriter); + FailureOr factorScalar = + createScalarOffsetConstant(loc, base.getType(), factor, rewriter); + if (failed(mask) || failed(zero) || failed(factorScalar)) + return failure(); + + Value local = + rewriter.create(loc, resultType, *zero, StringAttr{}).getResult(); + Value scaled = + rewriter.create(loc, resultType, local, *factorScalar, *mask) + .getResult(); + + StringRef order = orderAttr ? 
orderAttr.getValue() : StringRef("ASC"); + int64_t partOffset = part + factor * chunk * lanesPerPart; + FailureOr biasedBase = + createIotaChunkBase(loc, base, partOffset, order, rewriter); + if (failed(biasedBase)) + return failure(); + + if (order == "DESC") { + Value baseVector = + rewriter + .create(loc, resultType, *biasedBase, *mask, + /*position=*/nullptr) + .getResult(); + return rewriter.create(loc, resultType, baseVector, scaled, *mask) + .getResult(); + } + + return rewriter.create(loc, resultType, scaled, *biasedBase, *mask) + .getResult(); +} + +struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIIotaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure(op, + "iota requires assigned layout"); + + FailureOr lanesPerPart = + getDataLanesPerPart(resultVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "iota requires known physical lanes per part"); + + FailureOr base = + getSingleValue(op, adaptor.getBase(), + "iota base must convert to one value", rewriter); + if (failed(base)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + + if (layout.isContiguous()) { + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, "iota result must be vreg"); + FailureOr result = createIotaContiguousChunk( + op.getLoc(), resultType, *base, + static_cast(index) * *lanesPerPart, op.getOrderAttr(), + rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize contiguous iota chunk"); + results.push_back(*result); + } 
+ rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t factor = layout.getFactor(); + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "deinterleaved iota physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + FailureOr result = createIotaDeinterleavedChunk( + op.getLoc(), resultType, *base, factor, part, chunk, + *lanesPerPart, op.getOrderAttr(), rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize deinterleaved iota chunk"); + results.push_back(*result); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr || !denseAttr.isSplat()) + return rewriter.notifyMatchFailure( + op, "only splat dense data constants are supported"); + auto splatAttr = dyn_cast(denseAttr.getSplatValue()); + if (!splatAttr) + return rewriter.notifyMatchFailure(op, + "splat constant must be typed"); + + Value scalar = + rewriter.create(op.getLoc(), splatAttr).getResult(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "constant result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return 
rewriter.notifyMatchFailure( + op, "unsupported element type for constant mask"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, scalar, *mask, + /*position=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIConstantMaskOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> materializations = + computeConstantMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure( + op, Twine("constant_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "constant_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize constant_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICreateMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + 
matchAndRewrite(VMICreateMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto activeConstant = + op.getActiveLanes().getDefiningOp(); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !VMIMaskType::isConcreteGranularity( + resultVMIType.getGranularity())) + return rewriter.notifyMatchFailure( + op, "create_mask requires concrete layout and granularity"); + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "create_mask requires known physical mask lanes per part"); + + if (!activeConstant) { + FailureOr active = getSingleValue( + op, adaptor.getActiveLanes(), + "create_mask active_lanes must convert to one value", rewriter); + if (failed(active)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isContiguous() ? 
1 : layout.getFactor(); + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "dynamic create_mask physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + Value activeI32 = clampDynamicActiveLanes( + op.getLoc(), *active, resultVMIType.getElementCount(), rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t part = 0; part < factor; ++part) { + Value remaining = createPartitionActiveLanes( + op.getLoc(), activeI32, factor, part, rewriter); + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + auto maskType = dyn_cast(resultType); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_mask result must be mask"); + FailureOr> maskAndRemaining = + createRuntimePrefixMask(op.getLoc(), maskType, remaining, + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for dynamic create_mask"); + results.push_back(maskAndRemaining->first); + remaining = maskAndRemaining->second; + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return rewriter.notifyMatchFailure( + op, "create_mask active_lanes must be an integer constant"); + + int64_t activeLanes = activeAttr.getInt(); + if (activeLanes < 0) + activeLanes = 0; + if (activeLanes > resultVMIType.getElementCount()) + activeLanes = resultVMIType.getElementCount(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isContiguous() ? 
1 : layout.getFactor(); + SmallVector results; + results.reserve(resultTypes.size()); + + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + int64_t activeInChunk = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical padding lane"); + if (*padding) + continue; + anyLane = true; + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical lane"); + if (*logicalLane < activeLanes) + ++activeInChunk; + } + if (!anyLane) + break; + + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "create_mask result must be mask"); + std::optional pattern = + getPrefixPattern(activeInChunk, *lanesPerPart); + if (pattern) { + FailureOr mask = + createPrefixMask(op.getLoc(), maskType, *pattern, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask"); + results.push_back(*mask); + continue; + } + + FailureOr> maskAndRemaining = + createRuntimePrefixMask( + op.getLoc(), maskType, + createI32Constant(op.getLoc(), activeInChunk, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask plt fallback"); + results.push_back(maskAndRemaining->first); + } + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + 
+struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMILoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "load source must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "load offset must convert to one value", rewriter); + if (failed(source) || failed(offset)) + return failure(); + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 2 == 0) { + int64_t groups = resultTypes.size() / 2; + SmallVector lows; + SmallVector highs; + lows.reserve(groups); + highs.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type lowType = resultTypes[group]; + Type highType = resultTypes[groups + group]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "vldsx2 requires matching low/high result types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + SmallVector results; + results.reserve(resultTypes.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, 
adaptor.getResultMapping()); + return success(); + } + } + + SmallVector contiguousParts; + contiguousParts.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "load result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + contiguousParts.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + FailureOr> results = materializeDataLayoutConversion( + op, contiguousParts, resultTypes, + VMILayoutAttr::getContiguous(rewriter.getContext()), + resultVMIType.getLayoutAttr(), rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIMaskedLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIMaskedLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "masked_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "masked_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != passthruParts.size() || + passthruParts.size() != 
resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "masked_load physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, maskPassthruAndType] : + llvm::enumerate(llvm::zip_equal(maskParts, passthruParts, + resultTypes))) { + auto [mask, passthru, resultType] = maskPassthruAndType; + if (!isa(mask.getType()) || passthru.getType() != resultType || + !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "masked_load physical part type mismatch"); + + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value loaded = + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, loaded, passthru, mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGatherOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "gather source must convert to one value", rewriter); + if (failed(source)) + return failure(); + + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (indicesParts.size() != maskParts.size() || + indicesParts.size() != passthruParts.size() || + indicesParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "gather physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [indices, mask, passthru, resultType] : + llvm::zip_equal(indicesParts, maskParts, 
passthruParts, + resultTypes)) { + if (!isa(indices.getType()) || !isa(mask.getType()) || + passthru.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure(op, + "gather physical part type mismatch"); + + Value gathered = + rewriter + .create(op.getLoc(), resultType, *source, indices, + mask) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, gathered, passthru, + mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExpandLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIExpandLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExpandLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "expand_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "expand_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (isStaticAllActiveMask(op.getMask(), resultVMIType.getElementCount())) { + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure( + op, "expand_load result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + 
.getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + if (resultTypes.size() != 1 || maskParts.size() != 1 || + passthruParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "runtime expand_load supports only one physical chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || passthruParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "runtime expand_load requires physical result/passthru/mask"); + + auto baseType = dyn_cast((*source).getType()); + if (!baseType) + return rewriter.notifyMatchFailure(op, + "runtime expand_load requires ptr"); + Value gatherBase = + rewriter + .create(op.getLoc(), (*source).getType(), *source, + *offset) + .getResult(); + auto indexType = + VRegType::get(rewriter.getContext(), resultType.getElementCount(), + rewriter.getI32Type()); + FailureOr indexSeedMask = + createAllTrueMaskForVReg(op.getLoc(), indexType, rewriter); + if (failed(indexSeedMask)) + return rewriter.notifyMatchFailure( + op, "failed to create runtime expand_load index seed mask"); + Value zero = rewriter.create(op.getLoc(), 0, 32); + Value carrier = + rewriter + .create(op.getLoc(), indexType, zero, *indexSeedMask, + /*position=*/nullptr) + .getResult(); + Value indices = + rewriter + .create(op.getLoc(), indexType, carrier, + maskParts.front()) + .getResult(); + Value gathered = + rewriter + .create(op.getLoc(), resultType, gatherBase, indices, + maskParts.front()) + .getResult(); + Value result = + rewriter + .create(op.getLoc(), resultType, gathered, + passthruParts.front(), maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { + 
using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "store requires known physical lanes per part"); + bool fullPhysicalChunks = + succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "store destination must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "store offset must convert to one value", rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && + valueLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); + if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { + int64_t groups = valueParts.size() / 2; + for (int64_t group = 0; group < groups; ++group) { + Value low = valueParts[group]; + Value high = valueParts[groups + group]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "vstsx2 requires matching low/high value types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), low, high, 
*destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousTypes; + contiguousTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousTypes.push_back(value.getType()); + + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + for (auto [index, value] : llvm::enumerate(*storeParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + if (!fullPhysicalChunks) { + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute store active lanes"); + if (*activeLanes == 0) + continue; + } + FailureOr mask = fullPhysicalChunks + ? 
createAllTrueMaskForVReg(op.getLoc(), + vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), + valueVMIType, + index, vregType, + rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIMaskedStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIMaskedStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "masked_store requires known physical lanes per part"); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "masked_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "masked_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "masked_store value/mask physical arity mismatch"); + + SmallVector contiguousValueTypes; + contiguousValueTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousValueTypes.push_back(value.getType()); + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousValueTypes, valueVMIType.getLayoutAttr(), + 
VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + auto maskVMIType = cast(op.getMask().getType()); + SmallVector contiguousMaskTypes; + contiguousMaskTypes.reserve(maskParts.size()); + for (Value mask : maskParts) + contiguousMaskTypes.push_back(mask.getType()); + FailureOr> storeMasks = materializeMaskLayoutConversion( + op, maskParts, contiguousMaskTypes, maskVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeMasks)) + return failure(); + + if (storeParts->size() != storeMasks->size()) + return rewriter.notifyMatchFailure( + op, "masked_store converted value/mask arity mismatch"); + + for (auto [index, valueAndMask] : + llvm::enumerate(llvm::zip_equal(*storeParts, *storeMasks))) { + auto [value, mask] = valueAndMask; + auto vregType = dyn_cast(value.getType()); + if (!vregType || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "masked_store converted parts must be vreg/mask"); + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute masked_store active lanes"); + if (*activeLanes == 0) + continue; + FailureOr storeMask = createMaskedStorePredicate( + op.getLoc(), valueVMIType, index, mask, vregType, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize masked_store predicate"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *storeMask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, + 
OneToNPatternRewriter &rewriter) const override { + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "scatter destination must convert to one value", + rewriter); + if (failed(destination)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != indicesParts.size() || + valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure(op, + "scatter physical arity mismatch"); + + for (auto [value, indices, mask] : + llvm::zip_equal(valueParts, indicesParts, maskParts)) { + if (!isa(value.getType()) || + !isa(indices.getType()) || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "scatter physical part type mismatch"); + rewriter.create(op.getLoc(), value, *destination, indices, + mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMITileReadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITileReadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "tile_read source must convert to one value", rewriter); + if (failed(source)) + return failure(); + + Value zero = rewriter.create(op.getLoc(), 0); + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), zero, rewriter); + if (failed(lanesPerPart)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && 
resultTypes.size() % 2 == 0) { + int64_t groups = resultTypes.size() / 2; + SmallVector lows; + SmallVector highs; + lows.reserve(groups); + highs.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type lowType = resultTypes[group]; + Type highType = resultTypes[groups + group]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "vldsx2 requires matching low/high result types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + SmallVector results; + results.reserve(resultTypes.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + + SmallVector contiguousParts; + contiguousParts.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "tile_read result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, index * *lanesPerPart, rewriter); + contiguousParts.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + FailureOr> results = materializeDataLayoutConversion( + op, contiguousParts, resultTypes, + VMILayoutAttr::getContiguous(rewriter.getContext()), + resultVMIType.getLayoutAttr(), rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITileWriteOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITileWriteOp op, OpAdaptor adaptor, + 
OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "tile_write requires known physical lanes per part"); + bool fullPhysicalChunks = + succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); + + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "tile_write destination must convert to one value", rewriter); + if (failed(destination)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + Value zero = rewriter.create(op.getLoc(), 0); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && + valueLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); + if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { + int64_t groups = valueParts.size() / 2; + for (int64_t group = 0; group < groups; ++group) { + Value low = valueParts[group]; + Value high = valueParts[groups + group]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "vstsx2 requires matching low/high value types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "tile_write value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for tile_write mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousTypes; + contiguousTypes.reserve(valueParts.size()); + for 
(Value value : valueParts) + contiguousTypes.push_back(value.getType()); + + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + for (auto [index, value] : llvm::enumerate(*storeParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "tile_write value must be vreg"); + if (!fullPhysicalChunks) { + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute tile_write active lanes"); + if (*activeLanes == 0) + continue; + } + FailureOr mask = fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), + vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), + valueVMIType, + index, vregType, + rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for tile_write mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +template +struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical binary arity mismatch"); + + SmallVector 
results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "physical binary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true binary mask"); + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs, rhs, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIFmaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIFmaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + ValueRange accParts = adaptor.getAcc(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != accParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "fma physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, acc, resultType] : + llvm::zip_equal(lhsParts, rhsParts, accParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType || acc.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "fma requires matching physical vreg parts"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "unsupported element type for fma"); + results.push_back( + 
rewriter.create(op.getLoc(), resultType, acc, lhs, rhs, + *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || source.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "physical unary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true unary mask"); + results.push_back( + rewriter.create(op.getLoc(), resultType, source, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return 
rewriter.notifyMatchFailure(op, + "physical mask binary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask binary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask binary seed"); + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs, rhs, + *seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIMaskUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical mask unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || source.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask unary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask unary seed"); + results.push_back( + rewriter.create(op.getLoc(), resultType, source, 
*seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + std::optional cmpMode = getVPTOCmpMode(op.getPredicate()); + if (!cmpMode) + return op.emitOpError() + << kVMIDiagUnsupportedPrefix << "compare predicate " + << op.getPredicate() + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms " + "oeq/one/olt/ole/ogt/oge, and signed integer forms " + "slt/sle/sgt/sge"; + + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical cmp arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || lhs.getType() != rhs.getType() || + !isa(lhs.getType())) + return rewriter.notifyMatchFailure(op, + "physical cmp part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true cmp seed"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, lhs, rhs, *seedMask, + rewriter.getStringAttr(*cmpMode)) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMISelectOpPattern : OneToNOpConversionPattern { + using 
OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMISelectOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + ValueRange trueParts = adaptor.getTrueValue(); + ValueRange falseParts = adaptor.getFalseValue(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != trueParts.size() || + trueParts.size() != falseParts.size() || + trueParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical select arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [mask, trueValue, falseValue, resultType] : + llvm::zip_equal(maskParts, trueParts, falseParts, resultTypes)) { + if (!isa(mask.getType()) || trueValue.getType() != resultType || + falseValue.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical select part type mismatch"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, trueValue, falseValue, + mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIActivePrefixIndexOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIActivePrefixIndexOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIActivePrefixIndexOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "active_prefix_index supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "active_prefix_index 
requires physical vreg/mask parts"); + + auto intType = dyn_cast(resultType.getElementType()); + if (!intType || !intType.isSignless()) + return rewriter.notifyMatchFailure( + op, "active_prefix_index requires signless integer result part"); + + FailureOr seedMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for active_prefix_index seed mask"); + + Value zero = rewriter.create( + op.getLoc(), 0, intType.getWidth()); + Value carrier = + rewriter + .create(op.getLoc(), resultType, zero, *seedMask, + /*position=*/nullptr) + .getResult(); + Value result = + rewriter + .create(op.getLoc(), resultType, carrier, + maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICompressOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != 1 || maskParts.size() != 1 || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType || sourceParts.front().getType() != resultType || + !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "compress requires physical source/mask/result parts"); + + Value result = + rewriter + .create(op.getLoc(), resultType, sourceParts.front(), + maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICompressStoreOpPattern + : 
OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICompressStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "compress_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "compress_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress_store supports only one physical part"); + + auto valueType = dyn_cast(valueParts.front().getType()); + if (!valueType || !isa(maskParts.front().getType()) || + !isa((*destination).getType())) + return rewriter.notifyMatchFailure( + op, "compress_store requires physical value/mask and ptr " + "destination"); + + Value storeBase = + rewriter + .create(op.getLoc(), (*destination).getType(), + *destination, *offset) + .getResult(); + Value squeezed = + rewriter + .create(op.getLoc(), valueType, valueParts.front(), + maskParts.front()) + .getResult(); + auto align = + rewriter.create(op.getLoc(), + AlignType::get(rewriter.getContext())); + auto store = rewriter.create( + op.getLoc(), align.getResult().getType(), align.getResult(), squeezed, + storeBase, rewriter.getStringAttr("POST_UPDATE")); + rewriter.create(op.getLoc(), store.getAlignOut(), storeBase); + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIReduceAddIOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIReduceAddIOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIReduceAddIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange 
sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create reduce_addi first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIReduceAddFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIReduceAddFOp>::OneToNOpConversionPattern; + + 
LogicalResult + matchAndRewrite(VMIReduceAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create reduce_addf first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +template +struct 
OneToNVMIReduceMinMaxFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires matching source/mask chunks " + "and one init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires matching physical source/" + "init/result vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every source chunk to " + "match result vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every mask chunk to have " + "the same predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create floating min/max reduction first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + 
rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExtFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "extf requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected physical extf source"); + for (Value sourcePart : sourceParts) { + auto currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "extf source physical parts must have matching type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + (resultVRegTypes.empty() + ? 
!resultVRegType.getElementType().isF32() + : resultVRegType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical extf result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + ArrayRef parts; + int64_t factor = 0; + if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (sourceBits == 8 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical extf source/result width relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to build extf seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if ((sourceParts.size() != 2 
&& sourceParts.size() != 4) || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is supported"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + auto resultType = dyn_cast(resultTypes.front()); + if (!sourceType0 || !sourceType0.getElementType().isF32() || !resultType) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "truncf source physical parts must have matching f32 type"); + } + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + ArrayRef parts; + if (sourceParts.size() == 2 && resultBits == 16) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + } else if (sourceParts.size() == 4 && resultBits == 8) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result width relation"); + } + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(sourceMask) || failed(resultMask)) + return rewriter.notifyMatchFailure(op, + "failed to build truncf masks"); + + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector partials; + partials.reserve(parts.size()); + for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *sourceMask, + rnd, sat, rewriter.getStringAttr(part)) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : 
llvm::drop_begin(partials)) + merged = + rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + + rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIBitcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBitcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical bitcast arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + if (!isa(sourcePart.getType()) || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical bitcast part type mismatch"); + results.push_back( + rewriter.create(op.getLoc(), resultType, sourcePart) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelSplitOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIChannelSplitOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelSplitOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getNumResults(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_split only supports 2 or 4 channels"); + + auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!sourceLayout || + (!sourceLayout.isContiguous() && sourceLayout != channelLayout)) + return rewriter.notifyMatchFailure( + 
op, + "channel_split requires contiguous or matching deinterleaved source " + "layout"); + for (Value result : op.getResults()) { + auto resultType = cast(result.getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_split requires contiguous result layouts"); + } + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(); + FailureOr> results = materializeDataLayoutConversion( + op, adaptor.getSource(), resultTypes, sourceLayout, channelLayout, + rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelMergeOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIChannelMergeOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelMergeOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getInputs().size(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_merge only supports 2 or 4 channels"); + + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_merge requires contiguous input layouts"); + } + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!resultLayout || + (!resultLayout.isContiguous() && resultLayout != channelLayout)) + return rewriter.notifyMatchFailure( + op, + "channel_merge requires contiguous or matching deinterleaved result " + "layout"); + + FailureOr> results = materializeDataLayoutConversion( + op, 
adaptor.getFlatOperands(), + adaptor.getResultMapping().getConvertedTypes(0), channelLayout, + resultLayout, rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIShuffleOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> sourceFlatIndices = + computeShuffleForwardingSourceParts(op, &reason); + if (succeeded(sourceFlatIndices)) { + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t sourceFlatIndex : *sourceFlatIndices) { + if (sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle forwarding source part range is out of bounds"); + results.push_back(sourceParts[sourceFlatIndex]); + } + + if (failed(verifyIdentityPartForwarding(op, results, resultTypes, + rewriter))) + return failure(); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string splatReason; + FailureOr splatSource = + computeShuffleLane0SplatSourcePart(op, &splatReason); + if (succeeded(splatSource)) { + if (*splatSource >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 splat source part range is out of bounds"); + + SmallVector results; + results.reserve(resultTypes.size()); + Value sourcePart = sourceParts[*splatSource]; + for (Type resultType : resultTypes) { + auto sourceVRegType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType != resultVRegType) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 
splat requires matching physical vreg type"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), resultVRegType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create shuffle lane0 splat mask"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + rewriter.getStringAttr("LOWEST")) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string vselrReason; + FailureOr> vselrPlans = + computeShuffleVselrPlans(op, &vselrReason); + if (failed(vselrPlans)) + return rewriter.notifyMatchFailure( + op, Twine("shuffle vselr ") + vselrReason); + + if (vselrPlans->size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "shuffle vselr arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [plan, resultType] : llvm::zip_equal(*vselrPlans, resultTypes)) { + if (plan.sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source part range is out of bounds"); + + auto sourceVRegType = + dyn_cast(sourceParts[plan.sourceFlatIndex].getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType.getElementCount() != + resultVRegType.getElementCount() || + sourceVRegType.getElementType() != resultVRegType.getElementType()) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source/result type mismatch"); + + unsigned indexBits = + pto::getPTOStorageElemBitWidth(sourceVRegType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "shuffle vselr requires 8/16/32-bit index elements"); + + auto indexElementType = + IntegerType::get(rewriter.getContext(), indexBits); + Type indexType = + VRegType::get(rewriter.getContext(), + sourceVRegType.getElementCount(), indexElementType); + FailureOr base = 
createScalarOffsetConstant( + op.getLoc(), indexElementType, plan.baseLane, rewriter); + if (failed(base)) + return rewriter.notifyMatchFailure( + op, "failed to materialize shuffle vselr index base"); + StringAttr orderAttr = + plan.descending ? rewriter.getStringAttr("DESC") : StringAttr{}; + Value indexVector = + rewriter.create(op.getLoc(), indexType, *base, orderAttr) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + sourceParts[plan.sourceFlatIndex], indexVector) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +Block *convertBranchDestBlock(Block *block, OneToNPatternRewriter &rewriter, + OneToNTypeConverter &typeConverter, + llvm::DenseMap &converted) { + auto [it, inserted] = converted.try_emplace(block, nullptr); + if (!inserted) + return it->second; + + OneToNTypeMapping argMapping(block->getArgumentTypes()); + if (failed(typeConverter.computeTypeMapping(block->getArgumentTypes(), + argMapping)) || + !argMapping.hasNonIdentityConversion()) { + it->second = block; + return block; + } + + Block *newBlock = rewriter.applySignatureConversion(block, argMapping); + it->second = newBlock; + return newBlock; +} + +struct OneToNCFBranchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *dest = convertBranchDestBlock(op.getDest(), rewriter, *converter, + convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + dest == op.getDest()) + return failure(); + + rewriter.replaceOpWithNewOp(op, dest, + adaptor.getFlatOperands()); + return success(); + } +}; + +struct OneToNCFCondBranchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + 
LogicalResult + matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *trueDest = convertBranchDestBlock(op.getTrueDest(), rewriter, + *converter, convertedBlocks); + Block *falseDest = convertBranchDestBlock(op.getFalseDest(), rewriter, + *converter, convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + trueDest == op.getTrueDest() && falseDest == op.getFalseDest()) + return failure(); + + ValueRange condition = adaptor.getCondition(); + if (condition.size() != 1) + return rewriter.notifyMatchFailure( + op, "condition converted to multiple values"); + + SmallVector trueOperands; + SmallVector falseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) + llvm::append_range( + trueOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) + llvm::append_range( + falseOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + + rewriter.replaceOpWithNewOp( + op, condition.front(), trueDest, trueOperands, falseDest, + falseOperands); + return success(); + } +}; + +struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *defaultDest = + convertBranchDestBlock(op.getDefaultDestination(), rewriter, + *converter, convertedBlocks); + + SmallVector caseDests; + caseDests.reserve(op.getCaseDestinations().size()); + for (Block *dest : op.getCaseDestinations()) + 
caseDests.push_back( + convertBranchDestBlock(dest, rewriter, *converter, convertedBlocks)); + + bool changed = defaultDest != op.getDefaultDestination(); + for (auto [oldDest, newDest] : + llvm::zip(op.getCaseDestinations(), caseDests)) + changed |= oldDest != newDest; + changed |= adaptor.getOperandMapping().hasNonIdentityConversion(); + if (!changed) + return failure(); + + ValueRange flag = adaptor.getFlag(); + if (flag.size() != 1) + return rewriter.notifyMatchFailure(op, "flag converted to multiple values"); + + SmallVector defaultOperands; + SmallVector> caseOperandStorage; + SmallVector caseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getDefaultOperands().size(); i < e; ++i) + llvm::append_range( + defaultOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + + caseOperandStorage.reserve(op.getCaseOperandSegments().size()); + caseOperands.reserve(op.getCaseOperandSegments().size()); + for (int32_t segmentSize : op.getCaseOperandSegments()) { + SmallVector operands; + for (int32_t i = 0; i < segmentSize; ++i) + llvm::append_range( + operands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + caseOperandStorage.push_back(std::move(operands)); + } + for (SmallVector &operands : caseOperandStorage) + caseOperands.push_back(operands); + + rewriter.replaceOpWithNewOp( + op, flag.front(), defaultDest, defaultOperands, op.getCaseValuesAttr(), + caseDests, caseOperands); + return success(); + } +}; + +struct OneToNSCFExecuteRegionOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + scf::ExecuteRegionOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ExecuteRegionOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = 
adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = + rewriter.create(op.getLoc(), resultTypes); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +struct OneToNSCFIndexSwitchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IndexSwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange arg = adaptor.getArg(); + if (arg.size() != 1) + return rewriter.notifyMatchFailure( + op, "index_switch selector converted to multiple values"); + + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), resultTypes, arg.front(), op.getCases(), + op.getNumCases()); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getDefaultRegion(), + newOp.getDefaultRegion(), + newOp.getDefaultRegion().end()); + for (auto [srcRegion, dstRegion] : + llvm::zip(op.getCaseRegions(), newOp.getCaseRegions())) + rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +void populateVMIOneToNConversionPatterns( + VMIToVPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + const VMITargetCapabilityRegistry &capabilities) { + populateFuncTypeConversionPatterns(typeConverter, patterns); + 
scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + patterns + .add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); + patterns.add, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskUnaryOpPattern, + OneToNVMILoadOpPattern, + OneToNVMIMaskedLoadOpPattern, + OneToNVMIGatherOpPattern, + OneToNVMIExpandLoadOpPattern, + OneToNVMIStoreOpPattern, + OneToNVMIMaskedStoreOpPattern, + OneToNVMIScatterOpPattern, + OneToNVMITileReadOpPattern, + OneToNVMITileWriteOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIFmaOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMICmpOpPattern, + OneToNVMICmpOpPattern, + OneToNVMISelectOpPattern, + OneToNVMIActivePrefixIndexOpPattern, + OneToNVMICompressOpPattern, + OneToNVMICompressStoreOpPattern, + OneToNVMIReduceAddIOpPattern, + OneToNVMIReduceAddFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIExtFOpPattern, + OneToNVMITruncFOpPattern, + OneToNVMIBitcastOpPattern, + OneToNVMIChannelSplitOpPattern, + OneToNVMIChannelMergeOpPattern, + OneToNVMIShuffleOpPattern>(typeConverter, + patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), capabilities); +} + +LogicalResult verifyNoResidualVMIIR(ModuleOp module) { + WalkResult result = module.walk([&](Operation *op) { + if (isa(op)) { + 
op->emitError() + << kVMIDiagResidualOpPrefix + << "unrealized conversion cast remains after vmi-to-vpto"; + return WalkResult::interrupt(); + } + if (auto createMask = dyn_cast(op)) { + if (!createMask.getActiveLanes().getDefiningOp()) { + createMask.emitError() + << kVMIDiagUnsupportedPrefix + << "dynamic pto.vmi.create_mask active_lanes could not be lowered " + "by the current runtime predicate generation plan"; + return WalkResult::interrupt(); + } + } + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + } + if (isVMIOp(op) || hasVMIType(op)) { + op->emitError() + << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult checkSupportedExtFShape(VMIExtFOp op) { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity) || !sourceLayout.isContiguous() || + !resultLayout.isDeinterleaved() || + !resultType.getElementType().isF32()) + return failure(); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + if (sourceBits == 16 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return success(); + if (sourceBits == 8 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return success(); + return failure(); 
+} + +LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity) || !sourceLayout.isDeinterleaved() || + !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || + *resultArity != 1) + return failure(); + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return success(); + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) + return success(); + return failure(); +} + +FailureOr> +getPhysicalLogicalBitFootprint(VMIVRegType type) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (elementBits == 0) + return failure(); + + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(factor) || failed(lanesPerPart)) + return failure(); + + SmallVector bits; + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks)) + return failure(); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + int64_t activeLanes = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++activeLanes; + } + bits.push_back(activeLanes * static_cast(elementBits)); + } + } + return bits; +} + +LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> 
LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source and result layouts"); + if (sourceLayout != resultLayout) + return fail("requires matching source and result layouts"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source and result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + FailureOr> sourceBits = + getPhysicalLogicalBitFootprint(sourceType); + FailureOr> resultBits = + getPhysicalLogicalBitFootprint(resultType); + if (failed(sourceBits) || failed(resultBits)) + return fail("requires computable physical logical bit footprints"); + if (sourceBits->size() != resultBits->size()) + return fail("requires source and result physical footprint counts to " + "match"); + for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { + if (source != result) + return fail("requires matching logical bit footprint in every physical " + "chunk"); + } + + return success(); +} + +LogicalResult checkSupportedChannelSplitShape( + const VMITargetCapabilityRegistry &capabilities, VMIChannelSplitOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getNumResults(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_split", channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + 
auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return fail("requires assigned source layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) + return fail("requires source layout to be contiguous or matching " + "deinterleaved channel layout"); + + for (Value result : op.getResults()) { + VMILayoutAttr resultLayout = + cast(result.getType()).getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return fail("requires every result layout to be contiguous"); + } + + auto channelType = + VMIVRegType::get(op.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), expectedLayout); + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, sourceType, channelType, sourceLayout, expectedLayout, + &materializationReason))) + return fail(Twine("cannot materialize source to channel layout; ") + + materializationReason); + + FailureOr channelArity = getVMIPhysicalArity(channelType); + int64_t resultArity = 0; + for (Value result : op.getResults()) { + FailureOr arity = + getVMIPhysicalArity(cast(result.getType())); + if (failed(arity)) + return fail("requires computable result physical arity"); + resultArity += *arity; + } + if (failed(channelArity) || *channelArity != resultArity) + return fail("requires channel physical arity to match all result parts"); + + return success(); +} + +LogicalResult checkSupportedChannelMergeShape( + const VMITargetCapabilityRegistry &capabilities, VMIChannelMergeOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getInputs().size(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_merge", 
channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + int64_t inputArity = 0; + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return fail("requires every input layout to be contiguous"); + FailureOr arity = getVMIPhysicalArity(inputType); + if (failed(arity)) + return fail("requires computable input physical arity"); + inputArity += *arity; + } + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!resultLayout.isContiguous() && resultLayout != expectedLayout) + return fail("requires result layout to be contiguous or matching " + "deinterleaved channel layout"); + + auto channelType = + VMIVRegType::get(op.getContext(), resultType.getElementCount(), + resultType.getElementType(), expectedLayout); + FailureOr channelArity = getVMIPhysicalArity(channelType); + if (failed(channelArity) || *channelArity != inputArity) + return fail("requires channel physical arity to match all input parts"); + + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, channelType, resultType, expectedLayout, resultLayout, + &materializationReason))) + return fail(Twine("cannot materialize channel layout to result; ") + + materializationReason); + + return success(); +} + +LogicalResult +checkSupportedActivePrefixIndexShape(VMIActivePrefixIndexOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr maskLayout = 
maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!maskLayout || !resultLayout) + return fail("requires assigned mask and result layouts"); + if (!maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous mask and result layouts"); + + std::string resultFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultFullReason))) + return fail(Twine("requires full result physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + resultFullReason); + + std::string maskFullReason; + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("requires full mask physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + maskFullReason); + + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(maskArity) || failed(resultArity)) + return fail("requires computable mask and result physical arity"); + if (*maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk prefix needs " + "cross-chunk carry"); + + return success(); +} + +LogicalResult checkSupportedCompressShape(VMICompressOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous() || + !resultLayout.isContiguous()) + return 
fail("requires contiguous source, mask, and result layouts"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks so padding mask " + "lanes cannot be squeezed into the result; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source, mask, and result physical arity"); + if (*sourceArity != 1 || *maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk compress needs " + "cross-chunk compaction"); + + return success(); +} + +LogicalResult checkSupportedCompressStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMICompressStoreOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + if (!valueLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous value and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory( + op.getDestination().getType(), "destination", "pto.vstur", + "pto.vstur stores only to UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(valueType, &fullChunkReason))) + return fail(Twine("requires full physical chunks so padding mask 
lanes " + "cannot be squeezed into memory; ") + + fullChunkReason); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity)) + return fail("requires computable value and mask physical arity"); + if (*valueArity != 1 || *maskArity != 1) + return fail("requires a single physical chunk; multi-chunk " + "compress_store needs cross-chunk compaction and SQZN " + "state planning"); + + return success(); +} + +template +LogicalResult +checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, VMIReductionKind kind, bool requiresReassoc, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (requiresReassoc && !op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point vcadd"); + + auto sourceType = cast(op.getSource().getType()); + auto initType = cast(op.getInit().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr initLayout = initType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !initLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, init, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !initLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous source, init, mask, and result layouts"); + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(kind, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + std::string fullChunkReason; + if 
(failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks so padding lanes " + "do not participate in the reduction; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr initArity = getVMIPhysicalArity(initType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(initArity) || failed(maskArity) || + failed(resultArity)) + return fail("requires computable physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires source and mask physical arity to match and be " + "non-empty"); + if (*initArity != 1 || *resultArity != 1) + return fail("requires one init and result physical chunk"); + + return success(); +} + +LogicalResult +checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, + VMIFmaOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lhsType = cast(op.getLhs().getType()); + VMICapabilityResult elementCapability = + capabilities.supportsElementType(lhsType.getElementType(), + VMIElementPurpose::VMula); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(lhsType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical arity"); + + return success(); +} + +LogicalResult +checkSupportedReluShape(const VMITargetCapabilityRegistry &capabilities, + VMIReluOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + if (failed(checkSupportedMaskableVReg(capabilities, resultType, reason))) + return failure(); + + VMICapabilityResult 
elementCapability = + capabilities.supportsElementType(resultType.getElementType(), + VMIElementPurpose::VRelu); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +void emitEnsureLayoutMaterializationError(VMIEnsureLayoutOp ensure, + VMIVRegType sourceType, + VMIVRegType resultType, + StringRef reason) { + if (ensure.getResult().hasOneUse()) { + OpOperand &use = *ensure.getResult().use_begin(); + Operation *requester = use.getOwner(); + InFlightDiagnostic diag = + requester->emitError() + << kVMIDiagUnsupportedPrefix << requester->getName() << " operand #" + << use.getOperandNumber() << " has type " << sourceType + << " but requires " << resultType + << "; pto.vmi.ensure_layout cannot materialize this conversion"; + diag.attachNote(ensure.getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; + return; + } + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_layout cannot materialize the requested data " + "layout conversion (" + << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; +} + +LogicalResult verifySupportedVMIToVPTOOps( + ModuleOp module, const VMITargetCapabilityRegistry &capabilities, + bool enableStableGatherMaskedLoad) { + auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, + VMIVRegType type, Value source, + std::optional constantOffset) + -> WalkResult { + std::string reason; + if (succeeded(checkSupportedLoadShape(capabilities, type, source, + source.getType(), constantOffset, + &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " requires full physical chunks without padding lanes or a " + "statically safe full-read footprint (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitMaskableUnsupported = 
[&](Operation *op, StringRef opName, + VMIVRegType type) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedMaskableVReg(capabilities, type, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires physical vreg parts with b8/b16/b32 " + "predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitTargetElementUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, + VMIElementPurpose purpose, StringRef elementContract) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedTargetElementVReg( + capabilities, type, purpose, elementContract, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires " << elementContract + << " and physical vreg parts with b8/b16/b32 predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + WalkResult result = module.walk([&](Operation *op) { + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (!denseAttr || !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + return emitMaskableUnsupported( + op, "pto.vmi.constant", + cast(constant.getResult().getType())); + } + + if (auto broadcast = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.broadcast", + cast(broadcast.getResult().getType())); + + if (auto load = dyn_cast(op)) + return emitMemoryUnsupported( + op, "pto.vmi.load", cast(load.getResult().getType()), + load.getSource(), getConstantIndexValue(load.getOffset())); + if (auto load = dyn_cast(op)) { + if (enableStableGatherMaskedLoad) { + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load stable VGATHER-based lowering is reserved " + "for strict 
masked/tail loads but is not implemented yet"; + return WalkResult::interrupt(); + } + std::string reason; + if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load direct lowering requires a supported memory " + "source, contiguous result/passthru/mask layouts, and either " + "full physical chunks or a statically safe full-read footprint (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto gather = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGatherShape(capabilities, gather, &reason))) + return WalkResult::advance(); + gather.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only " + "for UB pointer sources, contiguous full physical chunks, " + "32-bit result elements, i32 indices, and b32 masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExpandLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.expand_load direct lowering is currently supported for " + "either a static all-active mask lowered as pto.vlds, or a " + "one-full-chunk 32-bit UB runtime mask lowered through pto.vusqz " + "+ pto.vgather2_bc + pto.vsel (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStoreShape( + capabilities, cast(store.getValue().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.store requires an 8/16/32-bit predicate-maskable " + "element type and either full physical chunks or contiguous " + "tail-store layout, with UB-backed destination (" + << 
reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedMaskedStoreShape( + capabilities, cast(store.getValue().getType()), + cast(store.getMask().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_store requires either full physical chunks or " + "contiguous tail-store value/mask layout, with UB-backed " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto scatter = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedScatterShape(capabilities, scatter, + &reason))) + return WalkResult::advance(); + scatter.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.scatter lowers through pto.vscatter only with an " + "indices_unique proof, UB pointer destination, contiguous full " + "physical chunks, 32-bit value elements, i32 indices, and b32 " + "masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto tileRead = dyn_cast(op)) + return emitMemoryUnsupported( + op, "pto.vmi.tile_read", + cast(tileRead.getResult().getType()), + tileRead.getSource(), 0); + if (auto tileWrite = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStoreShape( + capabilities, + cast(tileWrite.getValue().getType()), + tileWrite.getDestination(), tileWrite.getDestination().getType(), + &reason))) + return WalkResult::advance(); + tileWrite.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable " + "element type and either full physical chunks or contiguous " + "tail-store layout, with UB-backed destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string 
reason; + if (succeeded(checkSupportedLayoutMaterialization( + capabilities, sourceType, resultType, sourceType.getLayoutAttr(), + resultType.getLayoutAttr(), &reason))) + return WalkResult::advance(); + + emitEnsureLayoutMaterializationError(ensure, sourceType, resultType, + reason); + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (succeeded(checkSupportedLayoutMaterialization( + capabilities, sourceType, resultType, sourceType.getLayoutAttr(), + resultType.getLayoutAttr(), &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_mask_layout cannot materialize the requested " + "mask layout conversion (" + << reason + << "); partial/tail predicate layout materialization requires an " + "explicit packing plan"; + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + if (sourceType.getGranularity() == resultType.getGranularity()) + return WalkResult::advance(); + + std::string reason; + if (succeeded(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "non-identity mask granularity materialization requires concrete " + "b8/b16/b32 masks with matching lane count and layout (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto addf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.addf", + cast(addf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto addi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.addi", + cast( + addi.getResult().getType())); + if (auto subf = dyn_cast(op)) 
+ return emitTargetElementUnsupported( + op, "pto.vmi.subf", + cast(subf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto subi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.subi", + cast( + subi.getResult().getType())); + if (auto mulf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.mulf", + cast(mulf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto muli = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.muli", + cast( + muli.getResult().getType())); + if (auto divf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.divf", + cast(divf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto minf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.minf", + cast(minf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto maxf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.maxf", + cast(maxf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto negf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.negf", + cast(negf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto absf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.absf", + cast(absf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto absi = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.absi", + cast(absi.getResult().getType()), + VMIElementPurpose::SignlessOrSignedI8I16I32, + "signless/signed i8/i16/i32 element type"); + if (auto sqrt = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.sqrt", + cast(sqrt.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto exp = dyn_cast(op)) + return 
emitTargetElementUnsupported( + op, "pto.vmi.exp", + cast(exp.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto ln = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.ln", + cast(ln.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto relu = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReluShape(capabilities, relu, &reason))) + return WalkResult::advance(); + relu.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.relu direct lowering requires physical vreg parts with " + "b8/b16/b32 predicate masks and f16/f32 element type (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto andi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.andi", + cast( + andi.getResult().getType())); + if (auto ori = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.ori", + cast(ori.getResult().getType())); + if (auto xori = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.xori", + cast( + xori.getResult().getType())); + if (auto shli = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.shli", + cast( + shli.getResult().getType())); + if (auto shrui = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.shrui", + cast( + shrui.getResult().getType())); + if (auto notOp = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.not", + cast( + notOp.getResult().getType())); + if (auto select = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.select", + cast( + select.getResult().getType())); + + if (auto cmpf = dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpf", cast(cmpf.getLhs().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpf.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if 
(auto cmpi = dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpi", cast(cmpi.getLhs().getType()), + VMIElementPurpose::AnyI8I16I32, + "signless/signed/unsigned i8/i16/i32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpi.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if (auto activePrefix = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedActivePrefixIndexShape(activePrefix, + &reason))) + return WalkResult::advance(); + activePrefix.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.active_prefix_index lowers through pto.vusqz only for " + "one contiguous physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compress = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressShape(compress, &reason))) + return WalkResult::advance(); + compress.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress lowers through pto.vsqz only for one " + "contiguous full physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compressStore = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressStoreShape(capabilities, + compressStore, &reason))) + return WalkResult::advance(); + compressStore.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress_store lowers through pto.vsqz + pto.vstur " + "only for one contiguous full physical chunk with a UB pointer " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddI, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addi lowers through pto.vcadd only for " + "contiguous full 
32-bit integer source chunks with matching " + "mask chunks and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddF, + /*requiresReassoc=*/true, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addf lowers through pto.vcadd only with " + "reassoc, f32 contiguous full source chunks, matching mask " + "chunks, and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MaxF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_maxf lowers through pto.vcmax only for f16/f32 " + "contiguous full source chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MinF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_minf lowers through pto.vcmin only for f16/f32 " + "contiguous full source chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto fma = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedFmaShape(capabilities, fma, &reason))) + return WalkResult::advance(); + fma.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 " + "element types (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if 
(auto extf = dyn_cast(op)) { + if (succeeded(checkSupportedExtFShape(extf))) + return WalkResult::advance(); + + extf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extf supports contiguous 16-bit float-like or fp8-like " + "physical source chunks to f32 deinterleaved=2/4 results; " + "partial/tail is allowed only when source padding maps to result " + "padding"; + return WalkResult::interrupt(); + } + + if (auto truncf = dyn_cast(op)) { + if (succeeded(checkSupportedTruncFShape(truncf))) + return WalkResult::advance(); + + truncf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " + "to one contiguous f16 result chunk or f32 deinterleaved=4 " + "source parts to one contiguous fp8-like result chunk"; + return WalkResult::interrupt(); + } + + if (auto bitcast = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedBitcastShape(bitcast, &reason))) + return WalkResult::advance(); + + bitcast.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.bitcast requires matching source/result layouts with " + "identical physical arity and matching per-chunk logical bit " + "footprints (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + std::string reason; + if (succeeded(checkSupportedChannelSplitShape(capabilities, split, + &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split supports only 2 or 4 channels"; + else + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split requires source layout to be contiguous " + "or matching deinterleaved channel layout, every result layout " + "to be contiguous, and complete physical channel groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto merge = dyn_cast(op)) { + int64_t channels = 
merge.getInputs().size(); + std::string reason; + if (succeeded(checkSupportedChannelMergeShape(capabilities, merge, + &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge supports only 2 or 4 channels"; + else + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge requires every input layout to be " + "contiguous and result layout to be contiguous or matching " + "deinterleaved channel layout, with complete physical channel " + "groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto shuffle = dyn_cast(op)) { + std::string reason; + if (succeeded(computeShuffleForwardingSourceParts(shuffle, &reason))) + return WalkResult::advance(); + std::string splatReason; + if (succeeded(computeShuffleLane0SplatSourcePart(shuffle, &splatReason))) + return WalkResult::advance(); + std::string vselrReason; + if (succeeded(computeShuffleVselrPlans(shuffle, &vselrReason))) + return WalkResult::advance(); + + shuffle.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.shuffle requires physical chunk forwarding or " + "lane0 splat or vci-materializable vselr indices (forwarding: " + << reason << "; lane0 splat: " << splatReason + << "; vselr: " << vselrReason << ")"; + return WalkResult::interrupt(); + } + + if (auto constantMask = dyn_cast(op)) { + std::string reason; + if (succeeded(computeConstantMaskMaterialization(constantMask, &reason))) + return WalkResult::advance(); + + constantMask.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.constant_mask requires a dense bool constant with " + "concrete layout and b8/b16/b32 granularity (" + << reason << ")"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +struct VMIToVPTOPass + : public mlir::pto::impl::VMIToVPTOBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMIToVPTOPass) + + 
void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module))) { + signalPassFailure(); + return; + } + VMITargetCapabilityRegistry capabilities; + if (failed(verifySupportedVMIToVPTOOps( + module, capabilities, enableStableGatherMaskedLoad))) { + signalPassFailure(); + return; + } + + MLIRContext *context = module.getContext(); + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(context); + + populateVMIOneToNConversionPatterns(typeConverter, patterns, + capabilities); + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns)))) { + module.emitError() + << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + signalPassFailure(); + return; + } + if (failed(verifyNoResidualVMIIR(module))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMIToVPTOPass() { + return std::make_unique(); +} diff --git a/test/lit/CMakeLists.txt b/test/lit/CMakeLists.txt index 684ae0bf50..9fb0e63c5b 100644 --- a/test/lit/CMakeLists.txt +++ b/test/lit/CMakeLists.txt @@ -27,6 +27,7 @@ configure_lit_site_cfg( set(PTOIR_TEST_DEPENDS FileCheck count not pto-opt + pto-test-opt ) add_lit_testsuite(check-pto "Running the pto regression tests" diff --git a/test/lit/lit.cfg.py b/test/lit/lit.cfg.py index 9a81959f47..43cb6724e0 100644 --- a/test/lit/lit.cfg.py +++ b/test/lit/lit.cfg.py @@ -40,6 +40,8 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.ptoir_obj_root, 'test/lit') config.ptoir_tools_dir = os.path.join(config.ptoir_obj_root, 'tools/ptoas') +config.ptoir_test_tools_dir = os.path.join(config.ptoir_obj_root, + 'tools/pto-test-opt') config.substitutions.append(('%PATH%', config.environment['PATH'])) config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) @@ -57,9 +59,11 @@ # Tweak the PATH to include the tools dir. 
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -tool_dirs = [config.ptoir_tools_dir, config.llvm_tools_dir] +tool_dirs = [config.ptoir_tools_dir, config.ptoir_test_tools_dir, + config.llvm_tools_dir] tools = [ 'ptoas', + 'pto-test-opt', ] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/test/lit/vmi/vmi_absf_integer_invalid.pto b/test/lit/vmi/vmi_absf_integer_invalid.pto new file mode 100644 index 0000000000..2a3900e4e5 --- /dev/null +++ b/test/lit/vmi/vmi_absf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absf_integer_invalid(%value: !pto.vmi.vreg<128xi16>) { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.absf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_absi_float_invalid.pto b/test/lit/vmi/vmi_absi_float_invalid.pto new file mode 100644 index 0000000000..0f2d556c1a --- /dev/null +++ b/test/lit/vmi/vmi_absi_float_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absi_float_invalid(%value: !pto.vmi.vreg<64xf32>) { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.absi' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto new file mode 100644 index 0000000000..c675b2e6e9 --- /dev/null +++ b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_active_prefix_index_result_type_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.active_prefix_index' op requires signless integer result element type diff --git a/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..bd6ed94bac --- /dev/null +++ b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_lane_mismatch_invalid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<64xf32>) { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires all VMI data values to have the same logical lane count diff --git a/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto new file mode 100644 index 0000000000..937889d014 --- /dev/null +++ b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_bitcast_total_bits_invalid(%value: !pto.vmi.vreg<128xf32>) { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.bitcast' op requires source and result to carry the same total number of bits diff --git a/test/lit/vmi/vmi_bitwise_float_invalid.pto b/test/lit/vmi/vmi_bitwise_float_invalid.pto new file mode 100644 index 0000000000..60f260d444 --- /dev/null +++ b/test/lit/vmi/vmi_bitwise_float_invalid.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_andi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.andi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.andi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_ori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.ori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.ori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_xori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.xori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.xori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_not_float_invalid(%source: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.not %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.not' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto new file mode 100644 index 0000000000..9ecdc9469f --- /dev/null +++ b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_type_mismatch_invalid(%value: f16) { + %result = pto.vmi.broadcast %value : f16 -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires scalar or VMI vector input element type to match result element type diff --git a/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto new file mode 100644 index 0000000000..1dbc569c4c --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_input_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<3xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<3xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires all channel inputs to have the same lane count and element type diff --git a/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto new file mode 100644 index 0000000000..f5c7ad94b9 --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_result_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<2xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires result lane count and element type to match merged channels diff --git a/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto new file mode 100644 index 0000000000..bbf923b079 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_lane_count_invalid( + %src: !pto.vmi.vreg<5xf32>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<5xf32>) -> (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + return + } +} + +// CHECK: requires source lane count to equal result count times per-channel lane count diff --git a/test/lit/vmi/vmi_channel_split_result_count_invalid.pto b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto new file mode 100644 index 0000000000..bbe2b434d6 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_result_count_invalid( + %src: !pto.vmi.vreg<4xf32>) { + %ch0 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32>) -> !pto.vmi.vreg<4xf32> + return + } +} + +// CHECK: requires at least two channel results diff --git a/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto new file mode 100644 index 0000000000..7e7e6bb66f --- /dev/null +++ b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_compress_result_mismatch_invalid( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.compress' op requires all VMI data values to have the same element type diff --git a/test/lit/vmi/vmi_constant_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto new file mode 100644 index 0000000000..c1ff60fe3b --- /dev/null +++ b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_attr_kind_invalid() { + %value = "pto.vmi.constant"() { + value = 1.000000e+00 : f32 + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense elements constant attribute diff --git a/test/lit/vmi/vmi_constant_element_count_invalid.pto b/test/lit/vmi/vmi_constant_element_count_invalid.pto new file mode 100644 index 0000000000..b5e80ce364 --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_count_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<64xf32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_element_type_invalid.pto b/test/lit/vmi/vmi_constant_element_type_invalid.pto new file mode 100644 index 0000000000..29a5f2d22a --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_type_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element type to match result element type diff --git a/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto new file mode 100644 index 0000000000..537d007f03 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_attr_kind_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = true + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense elements mask constant attribute diff --git a/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto new file mode 100644 index 0000000000..f39f4ab00a --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_element_count_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<64xi1> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto new file mode 100644 index 0000000000..7f97a4afd6 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_element_type_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element type to be i1 diff --git a/test/lit/vmi/vmi_divf_integer_invalid.pto b/test/lit/vmi/vmi_divf_integer_invalid.pto new file mode 100644 index 0000000000..0c26d668b3 --- /dev/null +++ b/test/lit/vmi/vmi_divf_integer_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_divf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.divf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_elementwise_kind_invalid.pto b/test/lit/vmi/vmi_elementwise_kind_invalid.pto new file mode 100644 index 0000000000..46e8255de8 --- /dev/null +++ b/test/lit/vmi/vmi_elementwise_kind_invalid.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_subf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.subf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.subf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_subi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.subi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.subi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_mulf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.mulf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.mulf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_muli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.muli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.muli' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto new file mode 100644 index 0000000000..09a92692de --- /dev/null +++ b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_ensure_layout_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_surface_invalid( + %a: !pto.vmi.mask<128xpred>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_layout_mismatch_invalid( + %a: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result mask layouts to match diff --git a/test/lit/vmi/vmi_extf_direction_invalid.pto b/test/lit/vmi/vmi_extf_direction_invalid.pto new file mode 100644 index 0000000000..e00280a69d --- /dev/null +++ b/test/lit/vmi/vmi_extf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_extf_direction_invalid(%source: !pto.vmi.vreg<128xf32>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires result element type to be wider than source element type diff --git a/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..d1b64fc15d --- /dev/null +++ b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_extf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf16>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_fma_integer_invalid.pto b/test/lit/vmi/vmi_fma_integer_invalid.pto new file mode 100644 index 0000000000..e44d8879b3 --- /dev/null +++ b/test/lit/vmi/vmi_fma_integer_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_fma_integer_invalid( + %lhs: !pto.vmi.vreg<64xi32>, + %rhs: !pto.vmi.vreg<64xi32>, + %acc: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<64xi32>, + !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %out : !pto.vmi.vreg<64xi32> + } +} + +// CHECK: 'pto.vmi.fma' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_gather_indices_invalid.pto b/test/lit/vmi/vmi_gather_indices_invalid.pto new file mode 100644 index 0000000000..057e3d1244 --- /dev/null +++ b/test/lit/vmi/vmi_gather_indices_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_gather_indices_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xf32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.gather' op requires signless or unsigned 32-bit integer indices diff --git a/test/lit/vmi/vmi_iota_element_type_invalid.pto b/test/lit/vmi/vmi_iota_element_type_invalid.pto new file mode 100644 index 0000000000..448fba485f --- /dev/null +++ b/test/lit/vmi/vmi_iota_element_type_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_iota_element_type_invalid(%base: i64) { + %value = pto.vmi.iota %base + : i64 -> !pto.vmi.vreg<64xi64> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires result element type to be integer 8/16/32 or f16/f32 diff --git a/test/lit/vmi/vmi_iota_order_invalid.pto b/test/lit/vmi/vmi_iota_order_invalid.pto new file mode 100644 index 0000000000..93df56591c --- /dev/null +++ b/test/lit/vmi/vmi_iota_order_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_iota_order_invalid(%base: i32) { + %value = pto.vmi.iota %base {order = "DOWN"} + : i32 -> !pto.vmi.vreg<64xi32> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires order to be ASC or DESC diff --git a/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto new file mode 100644 index 0000000000..5dabf59203 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_active_prefix_index(%mask: !pto.vmi.mask<64xpred>) + -> !pto.vmi.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xi32> + return %idx : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_active_prefix_index( +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: %[[IDX:.*]] = pto.vmi.active_prefix_index %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: return %[[IDX]] diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto new file mode 100644 index 0000000000..6e165de8a0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_remat( + %scalar: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %broadcast, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %broadcast, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// ASSIGN-SAME: %[[SCALAR:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[BCAST_DEINT:.*]] = pto.vmi.broadcast %[[SCALAR]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[BCAST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[BCAST_DEINT]] +// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.broadcast %[[SCALAR]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[BCAST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, 
#pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// LOWER-COUNT-4: pto.vdup %arg0 +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_call_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto new file mode 100644 index 0000000000..b7245ad00b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-LABEL: func.func @caller( +// CHECK: %[[EA:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[R:.*]] = call @callee(%[[EA]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[R]], %[[R]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_branch.pto b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto new file mode 100644 index 0000000000..f96962a580 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %then_mask = pto.vmi.cmpf "olt", %then_value, %then_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%then_value, %then_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_mask = pto.vmi.cmpf "olt", %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%else_value, %else_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^join(%value: !pto.vmi.vreg<128xf32>, %mask: !pto.vmi.mask<128xpred>): + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_cf_branch( +// CHECK: cf.cond_br +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK: ^{{.*}}(%{{.*}}: !pto.vmi.vreg<128xf16, #pto.vmi.layout>): +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: cf.br +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> 
+// CHECK: ^{{.*}}(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout>): +// CHECK: pto.vmi.select %[[MASK]], %[[VALUE]], %[[VALUE]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_switch.pto b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto new file mode 100644 index 0000000000..6376a5502c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_cf_switch( + %flag: i32, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.switch %flag : i32, [ + default: ^join(%a : !pto.vmi.vreg<128xf32>), + 0: ^join(%b : !pto.vmi.vreg<128xf32>) + ] + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_cf_switch( +// ASSIGN-SAME: %[[FLAG:[^:]+]]: i32 +// ASSIGN-SAME: %[[A:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[B:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: cf.switch %[[FLAG]] : i32, [ +// ASSIGN: default: ^bb1(%[[A]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>), +// ASSIGN: 0: ^bb1(%[[B]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN: ^bb1(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: return %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_cf_switch( +// LOWER-SAME: %[[FLAG:[^:]+]]: i32 +// LOWER-SAME: %[[A0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[A1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: cf.switch %[[FLAG]] : i32, [ +// LOWER: default: ^bb1(%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), +// LOWER: 0: ^bb1(%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: ^bb1(%[[VALUE0:.*]]: !pto.vreg<64xf32>, %[[VALUE1:.*]]: !pto.vreg<64xf32>): +// LOWER: return %[[VALUE0]], %[[VALUE1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..351b3f62f8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32>, + %ch1: !pto.vmi.vreg<64xf32>, + %ch2: !pto.vmi.vreg<64xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<192xf32> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..572845c1a4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_compress.pto b/test/lit/vmi/vmi_layout_assignment_compress.pto new file mode 100644 index 0000000000..dee109ce28 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress( + %src: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress( +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.compress %[[SRC]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_compress_store.pto b/test/lit/vmi/vmi_layout_assignment_compress_store.pto new file mode 100644 index 0000000000..93266bdf42 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress_store.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress_store( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress_store( +// CHECK-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[DST:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.compress_store %[[VALUE]], %[[DST]][%[[OFFSET]]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto new file mode 100644 index 0000000000..e387aa077d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_constant_remat( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %constant, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %constant, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_remat( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[CONST_DEINT:.*]] = "pto.vmi.constant"() +// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[CONST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[CONST_DEINT]] +// ASSIGN: %[[CONST_CONTIG:.*]] = "pto.vmi.constant"() +// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[CONST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_constant_remat( +// LOWER: arith.constant 1.000000e+00 : f32 +// LOWER-COUNT-4: pto.vdup +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: 
pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_expand_load.pto b/test/lit/vmi/vmi_layout_assignment_expand_load.pto new file mode 100644 index 0000000000..501b26b369 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_expand_load.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_expand_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_expand_load( +// CHECK-SAME: %[[SRC:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %[[PASSTHRU:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.expand_load %[[SRC]][%[[OFFSET]]], %[[MASK]], %[[PASSTHRU]] +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, 
#pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto new file mode 100644 index 0000000000..101f0f9254 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> + + func.func @caller(%x: !pto.vmi.vreg<128xf32>) { + %r = call @external(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto new file mode 100644 index 0000000000..ffb994287a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto @@ -0,0 +1,15 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external_vmi(!pto.vmi.vreg<128xf32>) +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto new file mode 100644 index 0000000000..384d0d1171 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @external_i32(i32) -> i32 + + func.func @vmi_layout_assignment_external_decl_preserve( + %input: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + return %input : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: module +// CHECK: func.func private @external_i32(i32) -> i32 +// CHECK-LABEL: func.func @vmi_layout_assignment_external_decl_preserve( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_fma.pto b/test/lit/vmi/vmi_layout_assignment_fma.pto new file mode 100644 index 0000000000..c40b09c471 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_fma.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_fma( + %lhs: !pto.vmi.vreg<64xf32>, + %rhs: !pto.vmi.vreg<64xf32>, + %acc: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_fma( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_gather.pto b/test/lit/vmi/vmi_layout_assignment_gather.pto new file mode 100644 index 0000000000..a63919bf6f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_gather.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xi32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_gather( +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.gather %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto new file mode 100644 index 0000000000..4186b78dfa --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @caller( + %fn: (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32>, + %x: !pto.vmi.vreg<128xf32>) { + %r = func.call_indirect %fn(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed call requires a direct internal callee with a body diff --git a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto new file mode 100644 index 0000000000..773fd4187c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_iota_remat( + %base: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %iota, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %iota, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_iota_remat( +// ASSIGN-SAME: %[[BASE:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[IOTA_DEINT:.*]] = pto.vmi.iota %[[BASE]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[IOTA_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[IOTA_DEINT]] +// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.iota %[[BASE]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[IOTA_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_iota_remat( +// LOWER: pto.vci +// LOWER: pto.vcvt +// LOWER: pto.vadd +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto new file mode 100644 index 0000000000..6b2d588e04 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_load_truncf( + %src: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_tile_read_truncf( + %src: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_load_truncf_multi_use( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.store %wide, %dst[%offset] + : 
!pto.vmi.vreg<128xf32>, !pto.ptr + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_tile_read_truncf_multi_use( + %src: memref<128xf32>, + %dst: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %wide, %dst + : !pto.vmi.vreg<128xf32>, memref<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf( +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+ +// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read +// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( +// LOWER: %[[ZERO:.*]] = arith.constant 0 : index +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. 
+// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read +// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.tile_write %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto new file mode 100644 index 0000000000..f3942119a7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_granularity_conflict_invalid( + %cond: i1, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %mask = scf.if %cond -> !pto.vmi.mask<128xpred> { + %m16 = pto.vmi.cmpf "olt", %a16, %b16 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.mask<128xpred> + scf.yield %m16 : !pto.vmi.mask<128xpred> + } else { + %m32 = pto.vmi.cmpf "olt", %a32, %b32 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %m32 : !pto.vmi.mask<128xpred> + } + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting mask granularities diff --git a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto new file mode 100644 index 0000000000..b114643836 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_create_mask_remat( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_layout_assignment_constant_mask_remat( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// CHECK-SAME: %[[ACTIVE:.*]]: index +// CHECK: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// CHECK-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// CHECK-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M32]] +// 
CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity + +// CHECK-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// CHECK: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"() +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[CM16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[CM32]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto new file mode 100644 index 0000000000..fd487d017a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_use_ensure( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_mask_use_ensure( +// CHECK-SAME: %[[M:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load.pto b/test/lit/vmi/vmi_layout_assignment_masked_load.pto new file mode 100644 index 0000000000..286c92f6da --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_masked_load( +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.masked_load %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return.pto b/test/lit/vmi/vmi_layout_assignment_multi_return.pto new file mode 100644 index 0000000000..380b0d0ef9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @multi_return( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @multi_return( +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto new file mode 100644 index 0000000000..4e9b2885fd --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @multi_return_conflict( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf8E4M3FN>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf8E4M3FN> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout diff --git a/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto new file mode 100644 index 0000000000..968aeb1e05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto new file mode 100644 index 0000000000..de71e01d6a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addf %arg0, %arg1, %arg2 +// CHECK-SAME: reassoc +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto new file mode 100644 index 0000000000..82a516b114 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addi( + %source: !pto.vmi.vreg<64xi32>, + %init: !pto.vmi.vreg<1xi32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> + return %out : !pto.vmi.vreg<1xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addi( +// CHECK-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %[[INIT:.*]]: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addi %[[SOURCE]], %[[INIT]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto new file mode 100644 index 0000000000..51f8180ef0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_maxf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } + + func.func @vmi_layout_assignment_reduce_minf( + %source: !pto.vmi.vreg<128xf16>, + %init: !pto.vmi.vreg<1xf16>, + %mask: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<1xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> + return %out : !pto.vmi.vreg<1xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_maxf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[MAX:.*]] = pto.vmi.reduce_maxf %arg0, %arg1, %arg2 +// CHECK: return %[[MAX]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_minf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK: %[[MASK:.*]] = pto.vmi.ensure_mask_granularity %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MIN:.*]] = pto.vmi.reduce_minf %arg0, %arg1, %[[MASK]] +// CHECK: return %[[MIN]] diff --git a/test/lit/vmi/vmi_layout_assignment_scatter.pto b/test/lit/vmi/vmi_layout_assignment_scatter.pto new file mode 100644 
index 0000000000..9560cfa981 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scatter.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scatter( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scatter( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.scatter %arg0, %arg1[%arg2], %arg3 +// CHECK-SAME: indices_unique +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto new file mode 100644 index 0000000000..3bd81dca8c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_execute_region( + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.execute_region -> !pto.vmi.vreg<128xf32> { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.execute_region -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// LOWER: %[[RESULT:.*]]:2 = scf.execute_region -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return 
%[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_scf_for.pto new file mode 100644 index 0000000000..b63563216b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_for.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_for(%a: !pto.vmi.vreg<128xf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + %sum = pto.vmi.addf %result, %result + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_for( +// CHECK: %[[INIT:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[RESULT:.*]] = scf.for +// CHECK-SAME: iter_args(%[[ACC:.*]] = %[[INIT]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[ACC]], %[[ACC]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[RESULT]], %[[RESULT]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_if.pto b/test/lit/vmi/vmi_layout_assignment_scf_if.pto new file mode 100644 index 0000000000..f86107920a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_if.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_if( +// CHECK: scf.if +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto new file mode 100644 
index 0000000000..24ea65503e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_index_switch( + %selector: index, + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.index_switch %selector -> !pto.vmi.vreg<128xf32> + case 0 { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + default { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// ASSIGN-SAME: %[[SELECTOR:.*]]: index +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.index_switch %[[SELECTOR]] -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: 
default +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// LOWER: %[[RESULT:.*]]:2 = scf.index_switch {{.*}} -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: default +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_while.pto b/test/lit/vmi/vmi_layout_assignment_scf_while.pto new file mode 100644 index 0000000000..917bf1762f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_while.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_while( + %input: !pto.vmi.vreg<128xf16>, + %keep_going: i1) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.while (%value = %wide) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + scf.condition(%keep_going) %value : !pto.vmi.vreg<128xf32> + } do { + ^bb0(%value: !pto.vmi.vreg<128xf32>): + scf.yield %value : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_while( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.while (%[[VALUE:.*]] = %[[WIDE]]) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.condition(%arg1) %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: ^bb0(%[[AFTER:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: scf.yield %[[AFTER]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_while( +// LOWER: %[[RESULT:.*]]:2 = scf.while +// LOWER-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.condition(%arg1) {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_layout_assignment_store_ensure.pto b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto new file mode 100644 index 0000000000..430fff7fda --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_store_ensure( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_store_ensure( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[WIDE]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_store_ensure( +// LOWER: %[[EVEN:.*]] = pto.vcvt +// LOWER: %[[ODD:.*]] = pto.vcvt +// LOWER: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]] +// LOWER: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]] +// LOWER: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[SUM0]], %[[SUM1]] +// LOWER: pto.vsts %[[D0]] +// LOWER: pto.vsts %[[D1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto new file mode 100644 index 0000000000..141e85772b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_truncf_ensure( + %wide: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// ASSIGN-SAME: %[[WIDE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// LOWER-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]]{{.*}}part = "EVEN" +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]]{{.*}}part = "ODD" +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_widen.pto b/test/lit/vmi/vmi_layout_assignment_widen.pto new file mode 100644 index 0000000000..eceedcb711 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_widen( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %eb = pto.vmi.extf %b : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %cmp, %sum, %ea + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_widen( +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_factor_invalid.pto b/test/lit/vmi/vmi_layout_factor_invalid.pto new file mode 100644 index 0000000000..b908700333 --- /dev/null +++ b/test/lit/vmi/vmi_layout_factor_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_factor_invalid( + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout expected factor to be 2 or 4 diff --git a/test/lit/vmi/vmi_layout_gate_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto new file mode 100644 index 0000000000..1b1bfdfb52 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout diff --git a/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto new file mode 100644 index 0000000000..79425740d8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_mask_invalid( + %m: !pto.vmi.mask<128xpred>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout +// CHECK-SAME: !pto.vmi.mask with b8/b16/b32 granularity plus layout diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto new file mode 100644 index 0000000000..7494367606 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_physical_state = [{nested = !pto.vreg<64xf32>}] +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..78549ed3e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.mask<128xpred> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_valid.pto b/test/lit/vmi/vmi_layout_gate_valid.pto new file mode 100644 index 0000000000..ebc5778f34 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_valid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir + +module { + func.func @vmi_layout_gate_valid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %sel = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto new file mode 100644 index 0000000000..43aca3fd30 --- /dev/null +++ b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_concrete_without_layout_invalid( + %arg0: !pto.vmi.mask<128xb32>) { + return + } +} + +// CHECK: concrete mask granularity requires layout diff --git a/test/lit/vmi/vmi_mask_granularity_invalid.pto b/test/lit/vmi/vmi_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4d85cc9aa0 --- /dev/null +++ b/test/lit/vmi/vmi_mask_granularity_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_granularity_invalid( + %arg0: !pto.vmi.mask<128xb64, #pto.vmi.layout>) { + return + } +} + +// CHECK: expected granularity to be one of pred, b8, b16, b32 diff --git a/test/lit/vmi/vmi_mask_logic_invalid.pto b/test/lit/vmi/vmi_mask_logic_invalid.pto new file mode 100644 index 0000000000..49798b742b --- /dev/null +++ b/test/lit/vmi/vmi_mask_logic_invalid.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_and_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_and' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_or_granularity_mismatch( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_or' op requires all VMI mask values to have the same granularity + +// ----- + +module { + func.func @vmi_mask_xor_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_xor' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_not_granularity_mismatch( + %src: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %not = pto.vmi.mask_not %src + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_not' op requires all VMI mask values to have the same granularity diff --git a/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto new file mode 100644 index 0000000000..e7d949242e --- /dev/null +++ b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_pred_with_layout_invalid( + %arg0: !pto.vmi.mask<128xpred, #pto.vmi.layout>) { + return + } +} + +// CHECK: pred mask must not carry layout diff --git a/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4b3a672049 --- /dev/null +++ b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_masked_store_mask_granularity_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.masked_store' op requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_memory_element_type_invalid.pto b/test/lit/vmi/vmi_memory_element_type_invalid.pto new file mode 100644 index 0000000000..4d6a199e11 --- /dev/null +++ b/test/lit/vmi/vmi_memory_element_type_invalid.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_load_element_type_invalid(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: 'pto.vmi.load' op requires memory source element type to match VMI data element type + +// ----- + +module { + func.func @vmi_store_element_type_invalid( + %value: !pto.vmi.vreg<128xf16>, %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// CHECK: 'pto.vmi.store' op requires memory destination element type to match VMI data element type + +// ----- + +module { + func.func @vmi_tile_read_element_type_invalid(%src: memref<128xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: 'pto.vmi.tile_read' op requires memory source element type to match VMI data element type + +// ----- + +module { + func.func @vmi_tile_write_element_type_invalid( + %value: !pto.vmi.vreg<128xf16>, %dst: memref<128xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf16>, memref<128xf32> + return + } +} + +// CHECK: 'pto.vmi.tile_write' op requires memory destination element type to match VMI data element type diff --git a/test/lit/vmi/vmi_min_max_integer_invalid.pto b/test/lit/vmi/vmi_min_max_integer_invalid.pto new file mode 100644 index 0000000000..71d0861e82 --- /dev/null +++ b/test/lit/vmi/vmi_min_max_integer_invalid.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_minf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.minf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_maxf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.maxf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_negf_integer_invalid.pto b/test/lit/vmi/vmi_negf_integer_invalid.pto new file mode 100644 index 0000000000..6b28584b64 --- /dev/null +++ b/test/lit/vmi/vmi_negf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_negf_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.negf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto new file mode 100644 index 0000000000..bff24c6e07 --- /dev/null +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_op_verifier_basic( + %ptr: !pto.ptr, + %tile: memref<128xf32>, + %layouted: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %f32 = arith.constant 1.000000e+00 : f32 + %f16 = arith.constant 1.000000e+00 : f16 + %active = arith.constant 64 : index + + %const = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %broadcast = pto.vmi.broadcast %f32 : f32 -> !pto.vmi.vreg<128xf32> + %broadcast16 = pto.vmi.broadcast %f16 : f16 -> !pto.vmi.vreg<128xf16> + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %mask_const = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + + %add = pto.vmi.addf %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %mask, %broadcast, %const + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ext = pto.vmi.extf %broadcast16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %trunc = pto.vmi.truncf %ext : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + + %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr + %tile_read = pto.vmi.tile_read %tile : memref<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %tile_read, %tile : !pto.vmi.vreg<128xf32>, memref<128xf32> + + %small = "pto.vmi.shuffle"(%broadcast) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<4xf32> + %split0, %split1 = "pto.vmi.channel_split"(%small) + : (!pto.vmi.vreg<4xf32>) -> 
(!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + %merged = "pto.vmi.channel_merge"(%split0, %split1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<4xf32> + + %ensure = pto.vmi.ensure_layout %layouted + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %layouted_ext = pto.vmi.extf %ensure + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf64, #pto.vmi.layout> + %layouted_trunc = pto.vmi.truncf %layouted_ext + : !pto.vmi.vreg<128xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_layout = pto.vmi.ensure_mask_layout %mask_b32 + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_granularity = pto.vmi.ensure_mask_granularity %mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %part0, %part1 = "pto.vmi.unpack"(%layouted) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %packed = "pto.vmi.pack"(%part0, %part1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %i0 = arith.constant 0 : i32 + %iv0 = pto.vmi.broadcast %i0 : i32 -> !pto.vmi.vreg<128xi32> + %iadd = pto.vmi.addi %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + %icmp = pto.vmi.cmpi "slt", %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.mask<128xpred> + + return + } +} + +// CHECK-LABEL: func.func @vmi_op_verifier_basic +// CHECK: pto.vmi.broadcast +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.cmpf +// CHECK: pto.vmi.select +// CHECK: pto.vmi.extf +// CHECK: pto.vmi.truncf +// CHECK: pto.vmi.load +// CHECK: pto.vmi.store +// CHECK: pto.vmi.tile_read +// CHECK: pto.vmi.tile_write +// CHECK: pto.vmi.ensure_layout +// CHECK: pto.vmi.ensure_mask_layout +// CHECK: pto.vmi.ensure_mask_granularity +// CHECK: "pto.vmi.unpack" +// CHECK: "pto.vmi.pack" +// CHECK: pto.vmi.addi +// CHECK: 
pto.vmi.cmpi diff --git a/test/lit/vmi/vmi_pack_arity_invalid.pto b/test/lit/vmi/vmi_pack_arity_invalid.pto new file mode 100644 index 0000000000..4ba4eaa180 --- /dev/null +++ b/test/lit/vmi/vmi_pack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_pack_arity_invalid(%p0: !pto.vreg<64xf32>) { + %a = "pto.vmi.pack"(%p0) + : (!pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto new file mode 100644 index 0000000000..81805f2a28 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_helper_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto new file mode 100644 index 0000000000..be6a6414f9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_layout_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto new file mode 100644 index 0000000000..3d3727bdaa --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_mask_layout_invalid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface !pto.vmi.vreg or !pto.vmi.mask type diff --git a/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto new file mode 100644 index 0000000000..c5aa0676f0 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_non_vmi_op_invalid( + %a: !pto.vmi.vreg<128xf32>) { + %0 = builtin.unrealized_conversion_cast %a + : !pto.vmi.vreg<128xf32> to !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI typed value is used by a non-VMI semantic op diff --git a/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto new file mode 100644 index 0000000000..c2a3996eb9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_physical_invalid(%a: !pto.vreg<64xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO + +// ----- + +module { + func.func @vmi_producer_boundary_physical_op_invalid() { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto new file mode 100644 index 0000000000..8deed1cecb --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto new file mode 100644 index 0000000000..4163dcfb16 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_state = {nested = [!pto.vmi.vreg<128xf32>]} +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..8cd353ca13 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_valid.pto b/test/lit/vmi/vmi_producer_boundary_valid.pto new file mode 100644 index 0000000000..dee731bd1f --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_valid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -pto-validate-vmi-ir | FileCheck %s + +module { + func.func @vmi_producer_boundary_valid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %m: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<128xf32> { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %s = pto.vmi.select %m, %r, %a + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %s : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_producer_boundary_valid +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.select diff --git a/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto new file mode 100644 index 0000000000..7379984b50 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --enable-vmi %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vmi_ptoas_backend_required_invalid() { + return + } +} + +// CHECK: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". 
diff --git a/test/lit/vmi/vmi_ptoas_cli_control_flow.pto b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto new file mode 100644 index 0000000000..cd29782d10 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_control_flow( + %cond: i1, + %lhs: f32, + %rhs: f32, + %dst: !pto.ptr, + %offset: index) { + %lhs_v = pto.vmi.broadcast %lhs + : f32 -> !pto.vmi.vreg<128xf32> + %rhs_v = pto.vmi.broadcast %rhs + : f32 -> !pto.vmi.vreg<128xf32> + %chosen = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %lhs_v : !pto.vmi.vreg<128xf32> + } else { + scf.yield %rhs_v : !pto.vmi.vreg<128xf32> + } + pto.vmi.store %chosen, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_control_flow +// CHECK: %[[LHS:.*]] = pto.vdup +// CHECK: %[[RHS:.*]] = pto.vdup +// CHECK: %[[CHOSEN:.*]] = arith.select {{.*}}, %[[LHS]], %[[RHS]] : !pto.vreg<64xf32> +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto new file mode 100644 index 0000000000..8957bb1f40 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s +// RUN: ptoas --pto-arch=a5 --enable-vmi --emit-vpto %s -o - | FileCheck %s --check-prefix=ATTR +// RUN: not ptoas --pto-backend=emitc --enable-vmi %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_pipeline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_pipeline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// ATTR-LABEL: func.func @vmi_ptoas_cli_pipeline +// ATTR: pto.vecscope +// ATTR: pto.vdup +// ATTR: pto.vsts +// ATTR-NOT: pto.vmi. +// ATTR-NOT: !pto.vmi. 
+// ATTR-NOT: unrealized_conversion_cast + +// EMITC: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". diff --git a/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto new file mode 100644 index 0000000000..79b146acd8 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_abi_invalid( + %value: !pto.vmi.vreg<128xf32>) { + return + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto new file mode 100644 index 0000000000..a27067e62c --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_result_abi_invalid( + %scalar: f32) -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto new file mode 100644 index 0000000000..47dc112c04 --- /dev/null +++ b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_reduce_addf_missing_reassoc_invalid( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) { + %out = pto.vmi.reduce_addf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return + } +} + +// CHECK: 'pto.vmi.reduce_addf' op requires reassoc attr because VPTO vcadd performs pair-wise floating-point reduction diff --git a/test/lit/vmi/vmi_scatter_indices_invalid.pto b/test/lit/vmi/vmi_scatter_indices_invalid.pto new file mode 100644 index 0000000000..bd59b81b04 --- /dev/null +++ b/test/lit/vmi/vmi_scatter_indices_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_scatter_indices_invalid( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK: 'pto.vmi.scatter' op requires signless or unsigned 32-bit integer indices diff --git a/test/lit/vmi/vmi_select_mask_granularity_invalid.pto b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto new file mode 100644 index 0000000000..2e6b9d10f9 --- /dev/null +++ b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_select_mask_granularity_invalid( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_shli_float_invalid.pto b/test/lit/vmi/vmi_shli_float_invalid.pto new file mode 100644 index 0000000000..e73ee9c232 --- /dev/null +++ b/test/lit/vmi/vmi_shli_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shli' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_shrui_float_invalid.pto b/test/lit/vmi/vmi_shrui_float_invalid.pto new file mode 100644 index 0000000000..5de50dfff1 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_shrui_signed_invalid.pto b/test/lit/vmi/vmi_shrui_signed_invalid.pto new file mode 100644 index 0000000000..c3c57a52e9 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_signed_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_signed_invalid( + %lhs: !pto.vmi.vreg<128xsi16>, + %rhs: !pto.vmi.vreg<128xsi16>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xsi16>, !pto.vmi.vreg<128xsi16> + -> !pto.vmi.vreg<128xsi16> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_to_vpto_abs.pto b/test/lit/vmi/vmi_to_vpto_abs.pto new file mode 100644 index 0000000000..247a239f66 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_abs.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + %value: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %abs : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[AF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[AF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[AF0]], %[[AF1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[I1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AI0:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AI1:.*]] = pto.vabs %[[I1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AI0]], %[[AI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto new file mode 100644 index 0000000000..7d64e0ec0f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<64xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_active_prefix_index( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[M:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[M]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg0 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[IDX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto new file mode 100644 index 0000000000..cb655b0e4f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_multichunk_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: multi-chunk prefix needs cross-chunk carry diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto new file mode 100644 index 0000000000..07fd5307e0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_tail_invalid( + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: padding mask lanes cannot affect the observable prefix diff --git a/test/lit/vmi/vmi_to_vpto_add.pto b/test/lit/vmi/vmi_to_vpto_add.pto new file mode 100644 index 0000000000..49b5fdeca3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_add.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_addf( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_addi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_addf( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_addi( +// CHECK: %[[IM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM0]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM1]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-NOT: pto.vmi.add +// 
CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bf16_arith.pto b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto new file mode 100644 index 0000000000..c7357b5abd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bf16_arith( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>) { + %sum = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %sum_part = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %min_part = "pto.vmi.unpack"(%min) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %max_part = "pto.vmi.unpack"(%max) + : (!pto.vmi.vreg<128xbf16, 
#pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %sum_part, %min_part, %max_part + : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bf16_arith( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[ADD:.*]] = pto.vadd %arg0, %arg1, %[[MASK]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MIN:.*]] = pto.vmin %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MAX:.*]] = pto.vmax %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[ADD]], %[[MIN]], %[[MAX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast.pto b/test/lit/vmi/vmi_to_vpto_bitcast.pto new file mode 100644 index 0000000000..f73ffbe68a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_f32_to_i16( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<256xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<256xi16> + return %cast : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_f32_to_i16( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto new file mode 100644 index 0000000000..e2a1b3c789 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_partial( + %value: !pto.vmi.vreg<65xf32>) -> !pto.vmi.vreg<130xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<65xf32> -> !pto.vmi.vreg<130xi16> + return %cast : !pto.vmi.vreg<130xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_partial( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[S0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[S1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitwise.pto b/test/lit/vmi/vmi_to_vpto_bitwise.pto new file mode 100644 index 0000000000..80a665ccd9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitwise.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitwise( + %a: !pto.vmi.vreg<256xi16>, + %b: !pto.vmi.vreg<256xi16>) + -> (!pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>) { + %and = pto.vmi.andi %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %or = pto.vmi.ori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %xor = pto.vmi.xori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %not = pto.vmi.not %a + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %and, %or, %xor, %not + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitwise( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AND0:.*]] = pto.vand %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AND1:.*]] = pto.vand %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR0:.*]] = pto.vor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR1:.*]] = pto.vor %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR0:.*]] = pto.vxor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR1:.*]] = pto.vxor %[[A1]], 
%[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT0:.*]] = pto.vnot %[[A0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT1:.*]] = pto.vnot %[[A1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_broadcast.pto b/test/lit/vmi/vmi_to_vpto_broadcast.pto new file mode 100644 index 0000000000..9cdbf92e1e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_broadcast.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_broadcast_contiguous(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_deint4(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_rank0( + %scalar: !pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : !pto.vmi.vreg<1xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_contiguous( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %arg0, %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %arg0, %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_deint4( +// CHECK-COUNT-4: pto.vdup %arg0 +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_rank0( +// CHECK-COUNT-4: pto.vdup %arg0{{.*}}{position = "LOWEST"} +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_call_boundary.pto b/test/lit/vmi/vmi_to_vpto_call_boundary.pto new file mode 100644 index 0000000000..0a34ebe197 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_call_boundary.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: %[[C0:[^:]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[C1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[CM0:.*]] = pto.vadd %[[C0]], %[[C0]] +// CHECK-DAG: %[[CM1:.*]] = pto.vadd %[[C1]], %[[C1]] +// CHECK: return %[[CM0]], %[[CM1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @caller( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-DAG: %[[EA0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[EA1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[R:.*]]:2 = call @callee(%[[EA0]], %[[EA1]]) +// CHECK-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[S0:.*]] = pto.vadd %[[R]]#0, %[[R]]#0 +// CHECK-DAG: %[[S1:.*]] = pto.vadd %[[R]]#1, %[[R]]#1 +// CHECK: return %[[S0]], %[[S1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cf_branch.pto b/test/lit/vmi/vmi_to_vpto_cf_branch.pto new file mode 100644 index 0000000000..0a4cf70e1d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cf_branch.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + cf.br ^join(%then_value : !pto.vmi.vreg<128xf32>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_sum = pto.vmi.addf %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + cf.br ^join(%else_sum : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + %sum = pto.vmi.addf %value, %value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_cf_cond_branch_operands( + %cond: i1, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^join(%a : !pto.vmi.vreg<128xf32>), + ^join(%b : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func 
@vmi_to_vpto_cf_branch( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: cf.cond_br %[[COND]], ^[[THEN:.*]], ^[[ELSE:.*]] +// CHECK: ^[[THEN]]: +// CHECK-DAG: %[[THEN_P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[THEN_P1:.*]] = pto.vcvt %[[A]] +// CHECK: cf.br ^[[JOIN:.*]](%[[THEN_P0]], %[[THEN_P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[ELSE]]: +// CHECK-DAG: %[[ELSE_P0:.*]] = pto.vcvt %[[B]] +// CHECK-DAG: %[[ELSE_P1:.*]] = pto.vcvt %[[B]] +// CHECK: cf.br ^[[JOIN]]({{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: pto.vadd +// CHECK-LABEL: func.func @vmi_to_vpto_cf_cond_branch_operands( +// CHECK-SAME: %[[COND2:[^,]+]]: i1 +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: cf.cond_br %[[COND2]], ^[[CB_JOIN:.*]](%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), ^[[CB_JOIN]](%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[CB_JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto new file mode 100644 index 0000000000..4ffb8e384d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge4_contiguous( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch3: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2, %ch3) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_channel_merge4_contiguous( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[E0:.*]], %[[E1:.*]] = pto.vintlv %[[P0]], %[[P2]] +// CHECK: %[[O0:.*]], %[[O1:.*]] = pto.vintlv %[[P1]], %[[P3]] +// CHECK: %[[L0:.*]], %[[L1:.*]] = pto.vintlv %[[E0]], %[[O0]] +// CHECK: %[[H0:.*]], %[[H1:.*]] = pto.vintlv %[[E1]], %[[O1]] +// CHECK: return %[[L0]], %[[L1]], %[[H0]], %[[H1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..8bdc2beb6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto new file mode 100644 index 0000000000..867cbdce65 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_layout_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.channel_merge' op requires layout-assigned channel_merge inputs to be contiguous diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto new file mode 100644 index 0000000000..443a5fedae --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_partial_group_invalid( + %ch0: !pto.vmi.vreg<2xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<2xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge requires every input layout to be contiguous +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..1bc963d400 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto new file mode 100644 index 0000000000..55c9ea862e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_layout_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: 'pto.vmi.channel_split' op requires layout-assigned channel_split source to be contiguous or deinterleaved by result count diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto new file mode 100644 index 0000000000..10d90d2869 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32> + return %merged : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_channel_split4( + %src: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2, %ch3 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1, %ch2, %ch3 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_split_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_merge_deint2_identity( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + 
!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// CHECK: return %[[CH0]], %[[CH1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_channel_split4( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vdintlv %[[S2]], %[[S3]] +// CHECK: %[[C0:.*]], %[[C2:.*]] = pto.vdintlv %[[A0]], %[[B0]] +// CHECK: %[[C1:.*]], %[[C3:.*]] = pto.vdintlv %[[A1]], %[[B1]] +// CHECK: return %[[C0]], %[[C1]], %[[C2]], %[[C3]] +// CHECK-LABEL: func.func @vmi_channel_split_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vdintlv +// CHECK: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_channel_merge_deint2_identity( +// CHECK-SAME: %[[M0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vintlv +// CHECK: return %[[M0]], %[[M1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto new file mode 100644 index 0000000000..25afa0d016 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2_tail( + %src: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2_tail( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[CH0]], %[[CH1]] +// CHECK: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto new file mode 100644 index 0000000000..f45b7cdfda --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_partial_group_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel layout +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto new file mode 100644 index 0000000000..100f4b7378 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %mask = pto.vmi.cmpf "lt", %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.cmpf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..8689bc8312 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmp_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "uno", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate uno cannot be lowered to pto.vcmp +// CHECK-SAME: supported predicates are eq/ne/lt/le/gt/ge, ordered FP forms diff --git a/test/lit/vmi/vmi_to_vpto_cmp_select.pto b/test/lit/vmi/vmi_to_vpto_cmp_select.pto new file mode 100644 index 0000000000..816913c8b2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_select.pto @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_select( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.vmi.cmpf "lt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %selected = pto.vmi.select %mask, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + %p0, %p1 = "pto.vmi.unpack"(%selected) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %m0, %m1, %p0, %p1 + : !pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_cmpi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ge", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_ordered_predicate( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func 
@vmi_to_vpto_cmpi_signed_predicate( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "slt", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_bf16( + %a: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpf "oge", %a, %b + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } + + func.func @vmi_to_vpto_cmpi_ui16( + %a: !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpi "eq", %a, %b + : !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_select( +// CHECK: %[[FM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM0:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM0]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: %[[FM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM1:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM1]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[CM0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vsel {{.*}}, 
{{.*}}, %[[CM1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_ordered_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_signed_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_bf16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_ui16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "eq" +// CHECK-SAME: !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..23b1e7f88f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ult", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate ult cannot be lowered to pto.vcmp +// CHECK-SAME: signed integer forms slt/sle/sgt/sge diff --git a/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto new file mode 100644 index 0000000000..b4b2af9879 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_deint_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: requires contiguous mask and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %source, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous source, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_store_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous value and mask layouts diff --git 
a/test/lit/vmi/vmi_to_vpto_compress.pto b/test/lit/vmi/vmi_to_vpto_compress.pto new file mode 100644 index 0000000000..aba4da0228 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress( +// CHECK: %[[OUT:.*]] = pto.vsqz %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto new file mode 100644 index 0000000000..3122bbb0ee --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_multichunk_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_store.pto b/test/lit/vmi/vmi_to_vpto_compress_store.pto new file mode 100644 index 0000000000..edf8565c5f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress_store( +// CHECK: %[[BASE:.*]] = pto.addptr %arg1, %arg2 +// CHECK: %[[SQZ:.*]] = pto.vsqz %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ALIGN0:.*]] = pto.init_align : !pto.align +// CHECK: %[[ALIGN1:.*]] = pto.vstur %[[ALIGN0]], %[[SQZ]], %[[BASE]], "POST_UPDATE" : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +// CHECK: pto.vstar %[[ALIGN1]], %[[BASE]] : !pto.align, !pto.ptr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto new file mode 100644 index 0000000000..e4fc4738cc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store_multichunk_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress_store needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto new file mode 100644 index 0000000000..4d97cf831d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_tail_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: padding mask lanes cannot be squeezed into the result diff --git a/test/lit/vmi/vmi_to_vpto_constant.pto b/test/lit/vmi/vmi_to_vpto_constant.pto new file mode 100644 index 0000000000..c5c93bf2db --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_splat() + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_splat +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %[[CST]], %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %[[CST]], %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask.pto b/test/lit/vmi/vmi_to_vpto_constant_mask.pto new file mode 100644 index 0000000000..9c38a62148 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask.pto @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b8_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<512xi1> + } : () -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b16_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<256xi1> + } : () -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_plt_fallback() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, true, true, true, true, false, false, false]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_deinterleaved() + -> (!pto.mask, !pto.mask) { + %mask = 
"pto.vmi.constant_mask"() { + value = dense<[true, false, true, false, false, true, false, true]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_true +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_false +// CHECK: %[[F0:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: %[[F1:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: return %[[F0]], %[[F1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b8_all_true +// CHECK: %[[B8_0:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b16_all_false +// CHECK: %[[B16_0:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_plt_fallback +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[P0:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_deinterleaved +// CHECK: %[[PART0:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P4:.*]] = pto.pset_b32 "PAT_VL4" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[PART1:.*]] = pto.pand %[[P4]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[PART0]], %[[PART1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto new file mode 100644 index 0000000000..cc3f439e62 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_nonprefix() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, false, true, false]> : tensor<4xi1> + } : () -> !pto.vmi.mask<4xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<4xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_nonprefix +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[RUN0:.*]] = pto.pset_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[P3:.*]] = pto.pset_b32 "PAT_VL3" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[RUN1:.*]] = pto.pand %[[P3]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[OUT:.*]] = pto.por %[[RUN0]], %[[RUN1]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto new file mode 100644 index 0000000000..3b2fc0d080 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_rematerialize( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_rematerialize( +// CHECK: %[[M32_0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M32_1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[S16:.*]] = pto.vsel %arg0, %arg1, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg2, %arg4, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg3, %arg5, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto new file mode 100644 index 0000000000..1d9fe8377e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_nonsplat_invalid() + -> (!pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + } : () -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %p0 : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{.*}}non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan diff --git a/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto new file mode 100644 index 0000000000..1c6fdea4b0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_f64_unsupported(%scalar: f64) { + %value = pto.vmi.broadcast %scalar + : f64 -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.broadcast direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_constant_f64_unsupported() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<32xf64> + } : () -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.constant direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_create_mask.pto b/test/lit/vmi/vmi_to_vpto_create_mask.pto new file mode 100644 index 0000000000..63417a8a99 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask.pto @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 96 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b8_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 320 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b16_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_contiguous +// CHECK: %[[M0:.*]] = pto.pge_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_deint2 +// CHECK: %[[P0:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: %[[P1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b8_contiguous +// CHECK: %[[B8_0:.*]] = pto.pge_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pge_b8 "PAT_VL64" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b16_deint2 +// CHECK: %[[B16_0:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto new file mode 100644 index 0000000000..c702d80529 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_dynamic_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint4(%active: index) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_contiguous +// CHECK: 
%[[ACTIVE:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG:.*]] = arith.maxsi %[[ACTIVE]], {{.*}} : i32 +// CHECK: %[[CLAMPED:.*]] = arith.minui %[[NONNEG]], {{.*}} : i32 +// CHECK: %[[P0:.*]], %[[REM:.*]] = pto.plt_b32 %[[CLAMPED]] : i32 -> !pto.mask, i32 +// CHECK: %[[P1:.*]], %{{.*}} = pto.plt_b32 %[[REM]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint2 +// CHECK: %[[ACTIVE2:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG2:.*]] = arith.maxsi %[[ACTIVE2]], {{.*}} : i32 +// CHECK: %[[CLAMPED2:.*]] = arith.minui %[[NONNEG2]], {{.*}} : i32 +// CHECK: %[[BIAS2:.*]] = arith.addi %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[PART0:.*]] = arith.divui %[[BIAS2]], {{.*}} : i32 +// CHECK: %[[P2_0:.*]], %{{.*}} = pto.plt_b32 %[[PART0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART1:.*]] = arith.divui %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[P2_1:.*]], %{{.*}} = pto.plt_b32 %[[PART1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P2_0]], %[[P2_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint4 +// CHECK: %[[ACTIVE4:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG4:.*]] = arith.maxsi %[[ACTIVE4]], {{.*}} : i32 +// CHECK: %[[CLAMPED4:.*]] = arith.minui %[[NONNEG4]], {{.*}} : i32 +// CHECK: %[[BIAS4_0:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_0:.*]] = arith.divui %[[BIAS4_0]], {{.*}} : i32 +// CHECK: %[[P4_0:.*]], %{{.*}} = pto.plt_b32 %[[PART4_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_1:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_1:.*]] = arith.divui %[[BIAS4_1]], {{.*}} : i32 +// CHECK: %[[P4_1:.*]], %{{.*}} = pto.plt_b32 %[[PART4_1]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_2:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_2:.*]] = arith.divui %[[BIAS4_2]], {{.*}} : i32 +// CHECK: %[[P4_2:.*]], %{{.*}} = pto.plt_b32 %[[PART4_2]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART4_3:.*]] = arith.divui %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[P4_3:.*]], %{{.*}} = pto.plt_b32 %[[PART4_3]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P4_0]], %[[P4_1]], %[[P4_2]], %[[P4_3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous +// CHECK: %[[ACTIVE8:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG8:.*]] = arith.maxsi %[[ACTIVE8]], {{.*}} : i32 +// CHECK: %[[CLAMPED8:.*]] = arith.minui %[[NONNEG8]], {{.*}} : i32 +// CHECK: %[[P8_0:.*]], %[[REM8:.*]] = pto.plt_b8 %[[CLAMPED8]] : i32 -> !pto.mask, i32 +// CHECK: %[[P8_1:.*]], %{{.*}} = pto.plt_b8 %[[REM8]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P8_0]], %[[P8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMPED16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[BIAS16:.*]] = arith.addi %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[PART16_0:.*]] = arith.divui %[[BIAS16]], {{.*}} : i32 +// CHECK: %[[P16_0:.*]], %{{.*}} = pto.plt_b16 %[[PART16_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART16_1:.*]] = arith.divui %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[P16_1:.*]], %{{.*}} = pto.plt_b16 %[[PART16_1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto new file mode 100644 index 0000000000..8cd9cd051c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_plt_fallback() + -> !pto.mask { + %active = arith.constant 5 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_plt_fallback( +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[MASK]] : !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto new file mode 100644 index 0000000000..74ef8194d5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_rematerialize( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_rematerialize( +// CHECK: %[[ACTIVE32:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG32:.*]] = arith.maxsi %[[ACTIVE32]], {{.*}} : i32 +// CHECK: %[[CLAMP32:.*]] = arith.minui %[[NONNEG32]], {{.*}} : i32 +// CHECK: %[[M32_0:.*]], %[[REM32:.*]] = pto.plt_b32 %[[CLAMP32]] : i32 -> !pto.mask, i32 +// CHECK: %[[M32_1:.*]], %{{.*}} = pto.plt_b32 %[[REM32]] : i32 -> !pto.mask, i32 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMP16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[M16:.*]], %{{.*}} = pto.plt_b16 %[[CLAMP16]] : i32 -> !pto.mask, i32 +// CHECK: %[[S16:.*]] = pto.vsel %arg1, %arg2, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg3, %arg5, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg4, %arg6, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_divf.pto b/test/lit/vmi/vmi_to_vpto_divf.pto new file mode 100644 index 0000000000..be21ba5fdc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_divf.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_divf( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %quotient : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_divf( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[DIV0:.*]] = pto.vdiv %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[DIV1:.*]] = pto.vdiv %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DIV0]], %[[DIV1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto new file mode 100644 index 0000000000..f88f15a8eb --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_f16_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_f8_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + 
!pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_f16_widen_add_store( +// CHECK: %[[NARROW:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: %[[CVT_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "EVEN"} +// CHECK: %[[ODD:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "ODD"} +// CHECK: %[[ADD_MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]], %[[ADD_MASK0]] +// CHECK: %[[ADD_MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]], %[[ADD_MASK1]] +// CHECK: %[[STORE_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vstsx2 %[[SUM0]], %[[SUM1]], %arg1[%arg2], "INTLV_B32", %[[STORE_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_f8_widen_add_store( +// CHECK: %[[NARROW8:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P0"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P1"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P2"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P3"} +// CHECK-COUNT-4: pto.vadd +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK-COUNT-4: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto new file mode 100644 index 0000000000..958e6f1f5a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_f64_unsupported( + %a: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %b: !pto.vmi.vreg<32xf64, #pto.vmi.layout>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + !pto.vmi.vreg<32xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type and physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_addi_index_unsupported( + %a: !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + %b: !pto.vmi.vreg<64xindex, #pto.vmi.layout>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + !pto.vmi.vreg<64xindex, #pto.vmi.layout> + -> !pto.vmi.vreg<64xindex, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addi direct lowering requires physical 
vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_ensure_identity.pto b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto new file mode 100644 index 0000000000..783bc3428d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_identity( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %em0 = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %em1 = "pto.vmi.ensure_mask_granularity"(%em0) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em1) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask + } + + func.func 
@vmi_to_vpto_ensure_identity_tail( + %v: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %em = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity_tail( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto new file mode 100644 index 0000000000..4c15af9f19 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( +// CHECK: %[[A0:.*]], 
%[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto new file mode 100644 index 0000000000..bbfd4dcfd8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_partial_invalid( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_layout cannot materialize the requested data layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto new file mode 100644 index 0000000000..03661ac669 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto new file mode 100644 index 0000000000..e4506c86c2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( + %input: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto new file mode 100644 index 0000000000..989fd0cb74 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity( +// CHECK: %[[LO:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO]], 
%[[HI]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsel %arg2, %arg3, %[[M16]] +// CHECK: pto.vsel %arg4, %arg6, %arg0 +// CHECK: pto.vsel %arg5, %arg7, %arg1 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto new file mode 100644 index 0000000000..2512367b64 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_direct( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_direct( +// CHECK: %[[P0:.*]] = pto.punpack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[P1:.*]] = pto.punpack %arg0, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. 
+// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto new file mode 100644 index 0000000000..29bb147489 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_multistep( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.mask { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb8, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb8, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_multistep( +// CHECK: %[[LO16:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI16:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO16]], %[[HI16]], %[[ALL16]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[M8:.*]] = pto.ppack %[[M16]], "LOWER" : !pto.mask -> !pto.mask +// CHECK: return %[[M8]] +// CHECK-NOT: pto.vmi. 
+// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto new file mode 100644 index 0000000000..17a644834b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_deint2_to_contiguous( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint2( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + 
%dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<100xb32, #pto.vmi.layout> + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint4_to_contiguous( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint4( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint4_to_contiguous( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.pintlv_b32 %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.pintlv_b32 %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b32 %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.pdintlv_b32 %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.pdintlv_b32 %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.pdintlv_b32 %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto new file mode 100644 index 0000000000..87edcee933 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_layout_partial_invalid( + %input: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_mask_layout %input + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail predicate layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto new file mode 100644 index 0000000000..0c8b9a4120 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b8 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( +// CHECK: 
%[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b8 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b16 %arg0, %arg1 +// CHECK: return %[[D2]], %[[D3]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( +// CHECK: %[[P2:.*]], %[[P3:.*]] = pto.pdintlv_b16 %arg0, %arg1 +// CHECK: return %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto new file mode 100644 index 0000000000..836e33dec0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[LOAD]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto new file mode 100644 index 0000000000..d8733b4641 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active_negative_offset_invalid( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant -1 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: all-active path requires full physical chunks or statically safe full-read footprint +// CHECK-SAME: safe-read proof requires non-negative offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto new file mode 100644 index 0000000000..cdab169262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_partial_mask_invalid( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: one physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto new file mode 100644 index 0000000000..7c9d8a3a5b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_runtime_mask( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_runtime_mask( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK-DAG: %[[BASE:.*]] = pto.addptr %arg0, %arg1 +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[ALL]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[LOAD:.*]] = pto.vgather2_bc %[[BASE]], %[[IDX]], %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf.pto b/test/lit/vmi/vmi_to_vpto_extf.pto new file mode 100644 index 0000000000..af4fbca903 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f16_to_f32( + %input: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f16_tail_to_f32( + %input: !pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_bf16_to_f32( + %input: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) 
{ + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_bf16_to_f32( +// CHECK-SAME: %[[BF16_INPUT:.*]]: !pto.vreg<128xbf16> +// CHECK: %[[BF16_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "EVEN"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "ODD"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_f8.pto b/test/lit/vmi/vmi_to_vpto_extf_f8.pto new file mode 100644 index 0000000000..c9ab157d0d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_f8.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f8_to_f32( + %input: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f8_tail_to_f32( + %input: !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto new file mode 100644 index 0000000000..0803ccde1c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_multichunk( + %input: !pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_multichunk( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[EVEN1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[EVEN0]], %[[EVEN1]], %[[ODD0]], %[[ODD1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma.pto b/test/lit/vmi/vmi_to_vpto_fma.pto new file mode 100644 index 0000000000..d222c1ecb9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_fma( + %lhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_fma_f16( + %lhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf16, 
#pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_fma_bf16( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %part : !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_fma( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_f16( +// CHECK: %[[MASK16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK16]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_bf16( +// CHECK: %[[MASKBF16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUTBF16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASKBF16]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[OUTBF16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto new file mode 100644 index 0000000000..877568258b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_fma_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 element types +// CHECK-SAME: requires f16, bf16, or f32 element type for pto.vmula diff --git a/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto new file mode 100644 index 0000000000..0fedf0d694 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies 
Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> +} + +// CHECK: VMI-PASS-INVARIANT: vmi-to-vpto requires layout-assigned VMI types diff --git a/test/lit/vmi/vmi_to_vpto_gather.pto b/test/lit/vmi/vmi_to_vpto_gather.pto new file mode 100644 index 0000000000..d68e72c1d2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_gather( +// CHECK: %[[GATHER:.*]] = pto.vgather2_bc %arg0, %arg1, %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[GATHER]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto new file mode 100644 index 0000000000..83bf5db675 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_f16_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: 32-bit result element type diff --git a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto new file mode 100644 index 0000000000..c271e9f446 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_deint_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: contiguous result, indices, passthru, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_gather_tail_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<32xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout>, + !pto.vmi.vreg<32xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: result requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_scatter_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only 
+// CHECK-SAME: contiguous value, indices, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_scatter_tail_invalid( + %value: !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only +// CHECK-SAME: value requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_iota.pto b/test/lit/vmi/vmi_to_vpto_iota.pto new file mode 100644 index 0000000000..a46f767b59 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota.pto @@ -0,0 +1,120 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_i32_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_desc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base {order = "DESC"} + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_deint2_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i16_asc(%base: i16) + -> !pto.vreg<128xi16> { + %value = pto.vmi.iota %base + : i16 -> !pto.vmi.vreg<128xi16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi16, #pto.vmi.layout>) + -> !pto.vreg<128xi16> + return %part : !pto.vreg<128xi16> + } + + func.func @vmi_to_vpto_iota_f16_deint2_asc(%base: f16) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.iota %base + : f16 -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_asc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: 
%[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_desc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.subi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_deint2_asc( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[FACTOR:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[PART1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LOCAL0:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED0:.*]] = pto.vmuls %[[LOCAL0]], %[[FACTOR]], +// CHECK: %[[P0:.*]] = pto.vadds %[[SCALED0]], %arg0, +// CHECK: %[[LOCAL1:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED1:.*]] = pto.vmuls %[[LOCAL1]], %[[FACTOR]], +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[PART1]] : i32 +// CHECK: %[[P1:.*]] = pto.vadds %[[SCALED1]], %[[BASE1]], +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i16_asc( +// CHECK: %[[P16:.*]] = pto.vci %arg0 : i16 -> !pto.vreg<128xi16> +// CHECK: return %[[P16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_f16_deint2_asc( +// CHECK-DAG: %[[ZERO16:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[FACTOR16:.*]] = arith.constant 2.000000e+00 : f16 +// CHECK-DAG: %[[PART16_1:.*]] = arith.constant 1.000000e+00 : f16 +// CHECK: %[[LOCAL16_0:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_0:.*]] = pto.vmuls %[[LOCAL16_0]], %[[FACTOR16]], +// CHECK: %[[P16_0:.*]] = pto.vadds %[[SCALED16_0]], %arg0, +// CHECK: %[[LOCAL16_1:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_1:.*]] = pto.vmuls %[[LOCAL16_1]], %[[FACTOR16]], +// CHECK: %[[BASE16_1:.*]] = arith.addf %arg0, %[[PART16_1]] : f16 +// CHECK: %[[P16_1:.*]] = pto.vadds %[[SCALED16_1]], %[[BASE16_1]], +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_iota_tail.pto b/test/lit/vmi/vmi_to_vpto_iota_tail.pto new file mode 100644 index 0000000000..7ba8a31f11 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota_tail.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_contiguous_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<100xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_deint2_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<130xi32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<130xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_contiguous_tail( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_deint2_tail( +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C129:.*]] = arith.constant 129 : i32 +// CHECK: %[[BASE128:.*]] = arith.addi %arg0, %[[C128]] : i32 +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[C1]] : i32 +// CHECK: %[[BASE129:.*]] = arith.addi %arg0, %[[C129]] : i32 +// CHECK: return {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint.pto b/test/lit/vmi/vmi_to_vpto_load_deint.pto new file mode 100644 index 0000000000..715dacdfa6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_deint4(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2( +// CHECK: %[[P0:.*]], 
%[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4( +// CHECK: %[[D0:.*]] = pto.vlds %arg0[%arg1] +// CHECK: %[[D1:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[D2:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[D3:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %[[D2]], %[[D3]] +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto new file mode 100644 index 0000000000..433f222af3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_multichunk( +// CHECK: %[[P0_0:.*]], %[[P1_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto new file mode 100644 index 0000000000..f87e3753ca --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_nonfull_invalid( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: requires constant index offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback +// CHECK-SAME: scratch memory fallback resource allocation is not implemented +// CHECK-SAME: guarded memory fallback control-flow lowering is not implemented diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto new file mode 100644 index 0000000000..157f57f84d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %value = pto.vmi.load %src[%c4] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_tile_read_safe_tail_memref(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref( +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_safe_tail_memref( +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto new file mode 100644 index 0000000000..07975ea70d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref_invalid(%src: memref<100xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<100xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: full physical read footprint [0, 128) exceeds static memref element count 100 diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto new file mode 100644 index 0000000000..863c2b4fa5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %value = pto.vmi.load %src[%cm1] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: requires non-negative offset diff --git a/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto new file mode 100644 index 0000000000..891cb20567 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_store_contiguous( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_store_contiguous( +// CHECK: %[[C64_LOAD:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OFF1_LOAD:.*]] = arith.addi %arg2, %[[C64_LOAD]] : index +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[OFF1_LOAD]]] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L0]], %arg1[%arg2], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_mask_logic.pto b/test/lit/vmi/vmi_to_vpto_mask_logic.pto new file mode 100644 index 0000000000..cf220cc6a4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_mask_logic.pto @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_logic( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %c: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>) { + %lt = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %gt = pto.vmi.cmpf "ogt", %a, %c + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %and = pto.vmi.mask_and %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %or = pto.vmi.mask_or %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %xor = pto.vmi.mask_xor %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %not = pto.vmi.mask_not %lt + : !pto.vmi.mask<128xpred> -> !pto.vmi.mask<128xpred> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + } + + func.func @vmi_to_vpto_mask_logic_b8( + %lhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<256xb8, 
#pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + } + + func.func @vmi_to_vpto_mask_logic_b16( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND0:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[AND1:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR0:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR1:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR0:.*]] = pto.pxor {{.*}} : 
!pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR1:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT0:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT1:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b8( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B8:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B8:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B8:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B8:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B8]], %[[OR_B8]], %[[XOR_B8]], %[[NOT_B8]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b16( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B16:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B16:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B16:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B16:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B16]], %[[OR_B16]], %[[XOR_B16]], %[[NOT_B16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load.pto b/test/lit/vmi/vmi_to_vpto_masked_load.pto new file mode 100644 index 0000000000..bc46f591ac --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto new file mode 100644 index 0000000000..9b79049c1a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_nonfull_invalid( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout>, + !pto.vmi.vreg<4xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires constant index offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback +// CHECK-SAME: target true masked/non-faulting load is unavailable diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto new file mode 100644 index 0000000000..d4b9f23c23 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref( + %src: memref<128xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %out = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %out = pto.vmi.masked_load %src[%c4], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + 
+// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto new file mode 100644 index 0000000000..ab22618d3e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %out = pto.vmi.masked_load %src[%cm1], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires non-negative offset diff --git a/test/lit/vmi/vmi_to_vpto_masked_store.pto b/test/lit/vmi/vmi_to_vpto_masked_store.pto new file mode 100644 index 0000000000..01e8d53d89 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_contiguous( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_contiguous( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[OFF1:.*]] = arith.addi %[[OFF]], %[[C64]] : index +// CHECK: pto.vsts %[[V1]], %[[DST]][%[[OFF1]]], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto new file mode 100644 index 0000000000..e874e8d90d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], %[[V1]] +// CHECK: %[[USER:.*]], %{{.*}} = pto.pintlv_b32 %[[M0]], %[[M1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" 
: !pto.mask +// CHECK: %[[MASK:.*]] = pto.pand %[[USER]], %[[TAIL]], %[[ALL]] +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto new file mode 100644 index 0000000000..375f44c894 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_nonfull_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<129xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<129xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto new file mode 100644 index 0000000000..361277c4fd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<100xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[COMBINED:.*]] = pto.pand %[[M1]], %[[TAIL]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[COMBINED]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto new file mode 100644 index 0000000000..1102d992ef --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_addf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_divf_bf16_invalid( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.divf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_sqrt_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.sqrt %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.sqrt direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_exp_f8_invalid( + %source: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + 
%out = pto.vmi.exp %source + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.exp direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_negf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.negf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.negf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_ln_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.ln %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ln direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.absf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absi_unsigned_invalid( + %source: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) { + %out = pto.vmi.absi %source + : !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absi direct lowering requires signless/signed i8/i16/i32 element type +// CHECK-SAME: requires 
signless/signed i8/i16/i32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto new file mode 100644 index 0000000000..7a222a35ad --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_gm_unsupported(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_masked_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, 
contiguous result/passthru/mask layouts +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_expand_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed + +// ----- + +module { + func.func @vmi_masked_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed + +// ----- + +module { + func.func @vmi_tile_read_gm_unsupported( + %src: memref<64xf32, #pto.address_space>) { + %value = pto.vmi.tile_read %src + : memref<64xf32, #pto.address_space> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + 
return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_tile_write_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<64xf32, #pto.address_space>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<64xf32, #pto.address_space> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed diff --git a/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto new file mode 100644 index 0000000000..98d92a3262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_f16( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_store_deint2_i8( + %value: !pto.vmi.vreg<512xi8, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<512xi8, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_f16( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B16" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_i8( +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B8", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto new file mode 100644 index 0000000000..489891c72a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto @@ -0,0 +1,177 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32, strided<[2], offset: 0>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_load_memref_subview_unsupported(%src: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %src[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>> + %value = pto.vmi.load %view[%c0] + : memref<64xf32, strided<[1], offset: ?>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32, 
strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_expand_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.expand_load %src[%c0], %mask, %passthru + : memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.store %value, %dst[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_memref_subview_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %dst[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, 
strided<[1], offset: ?>> + pto.vmi.store %value, %view[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<64xf32, strided<[1], offset: ?>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.masked_store %value, %dst[%c0], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_tile_read_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>) { + %value = pto.vmi.tile_read %src + : memref<128xf32, strided<[2], offset: 0>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_tile_write_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, 
strided<[2], offset: 0>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps diff --git a/test/lit/vmi/vmi_to_vpto_min_max.pto b/test/lit/vmi/vmi_to_vpto_min_max.pto new file mode 100644 index 0000000000..eeefc6ee94 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_min_max.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_min_max( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %min, %max : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_min_max( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[MIN0:.*]] = pto.vmin %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MIN1:.*]] = pto.vmin %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX0:.*]] = pto.vmax %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX1:.*]] = pto.vmax %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[MIN0]], %[[MIN1]], %[[MAX0]], %[[MAX1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_negf.pto b/test/lit/vmi/vmi_to_vpto_negf.pto new file mode 100644 index 0000000000..1aafa02c9a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_negf.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_negf(%a: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %neg = pto.vmi.negf %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_negf( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[A0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[A1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_pack_unpack.pto b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto new file mode 100644 index 0000000000..e4caa3cc0a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unpack( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %p0, %p1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_pack_unpack( + %p0: !pto.vreg<64xf32>, + %p1: !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %v = "pto.vmi.pack"(%p0, %p1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %q0, %q1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %q0, %q1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unpack( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_pack_unpack( +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi.pack +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto new file mode 100644 index 0000000000..7d302805d6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -0,0 +1,310 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dequant_matrix_f16_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = 
arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + } + return + } + + func.func @vmi_to_vpto_quant_matrix_f32_to_f16( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + 
%tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + } + return + } + + func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail 
+ : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } + + func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> 
!pto.vmi.mask<256xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_f16_to_f32( +// CHECK-SAME: %[[DSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[DDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_f16( +// CHECK-SAME: %[[QSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[INV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", 
sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: scf.if +// CHECK: pto.plt_b16 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( +// CHECK-SAME: %[[FSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FSCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK-SAME: %[[FQSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FINV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: scf.if +// CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto new file mode 100644 index 0000000000..c44de2ec84 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", 
rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto new file mode 100644 index 0000000000..6f2fadfdba --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto new file mode 100644 index 0000000000..4e24ee12a8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_f16_invalid( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only with reassoc +// CHECK-SAME: currently supports only f32 elements diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto new file mode 100644 index 0000000000..0389c17e25 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcadd %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto new file mode 100644 index 0000000000..fd6c461b2c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi( + %source: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..466374c65c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_i16_invalid( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: currently supports only 32-bit integer elements diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto new file mode 100644 index 0000000000..8275a80790 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_multichunk( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[RED1:.*]] = pto.vcadd %arg1, %arg4 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto new file mode 100644 index 0000000000..51782e8462 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_reduce_minf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_maxf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmax %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmax %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmax %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmax %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmin %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmin %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmin %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmin %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto new file mode 100644 index 0000000000..a926b48d70 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_tail_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<65xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<65xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI{{.*}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires full source physical chunks diff --git a/test/lit/vmi/vmi_to_vpto_reduce_minf.pto b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto new file mode 100644 index 0000000000..96a70a03f3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_minf( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf( +// CHECK: %[[FIRST:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcmin %arg0, %arg2 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vmin %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto new file mode 100644 index 0000000000..1b2cf33ffa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_tail_invalid( + %source: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_addf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_reduce_minf_tail_invalid( + %source: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<64xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_minf lowers through pto.vcmin only 
+// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_maxf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts diff --git a/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto new file mode 100644 index 0000000000..ab4f204979 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_relu_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %relu = pto.vmi.relu %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.relu direct lowering requires physical vreg parts with b8/b16/b32 predicate masks and f16/f32 element type +// CHECK-SAME: pto.vrelu direct lowering supports only f16/f32 VMI floating-point element types diff --git a/test/lit/vmi/vmi_to_vpto_scatter.pto b/test/lit/vmi/vmi_to_vpto_scatter.pto new file mode 100644 index 0000000000..12799c01fc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scatter.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scatter( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scatter( +// CHECK: pto.vscatter %arg0, %arg1, %arg2, %arg3 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto new file mode 100644 index 0000000000..027162ac68 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_scatter_missing_unique_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only with an indices_unique proof +// CHECK-SAME: requires indices_unique proof diff --git a/test/lit/vmi/vmi_to_vpto_scf_for.pto b/test/lit/vmi/vmi_to_vpto_scf_for.pto new file mode 100644 index 0000000000..253432b6dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_for.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_for(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_for( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[P1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[RESULT:.*]]:2 = scf.for +// CHECK-SAME: iter_args(%[[ACC0:.*]] = %[[P0]], %[[ACC1:.*]] = %[[P1]]) +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[N0:.*]] = pto.vadd %[[ACC0]], %[[ACC0]] +// CHECK: %[[N1:.*]] = pto.vadd %[[ACC1]], %[[ACC1]] +// CHECK: scf.yield %[[N0]], %[[N1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scf_if.pto b/test/lit/vmi/vmi_to_vpto_scf_if.pto new file mode 100644 index 0000000000..dcc7497ee4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_if.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %selected : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_if( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[IF:.*]]:4 = scf.if %[[COND]] -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) +// CHECK: pto.vcvt %[[A]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: else +// CHECK: pto.vcvt %[[B]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: 
scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: pto.vsel %[[IF]]#0, %[[IF]]#0, %[[IF]]#2 +// CHECK: pto.vsel %[[IF]]#1, %[[IF]]#1, %[[IF]]#3 +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shli.pto b/test/lit/vmi/vmi_to_vpto_shli.pto new file mode 100644 index 0000000000..eb5fa7d64d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shli.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shli( + %value: !pto.vmi.vreg<256xi16>, + %amount: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %shifted = pto.vmi.shli %value, %amount + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + return %shifted : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shli( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[SHL0:.*]] = pto.vshl %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[SHL1:.*]] = pto.vshl %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[SHL0]], %[[SHL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shrui.pto b/test/lit/vmi/vmi_to_vpto_shrui.pto new file mode 100644 index 0000000000..46ccbf8d86 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shrui.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shrui( + %value: !pto.vmi.vreg<256xui16>, + %amount: !pto.vmi.vreg<256xui16>) -> !pto.vmi.vreg<256xui16> { + %shifted = pto.vmi.shrui %value, %amount + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<256xui16> + -> !pto.vmi.vreg<256xui16> + return %shifted : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shrui( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xui16> +// CHECK-SAME: -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) +// CHECK-DAG: %[[SHR0:.*]] = pto.vshr %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK-DAG: %[[SHR1:.*]] = pto.vshr %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: return %[[SHR0]], %[[SHR1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto new file mode 100644 index 0000000000..dc237c02dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto @@ -0,0 +1,159 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_identity( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_second_chunk( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } + + func.func @vmi_shuffle_tail_prefix( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<4xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_chunk_swap( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_reverse_one_chunk( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func 
@vmi_shuffle_identity( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_second_chunk( +// CHECK-SAME: %{{[^,]+}}: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D1]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_tail_prefix( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %{{[^)]+}}: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S0]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_chunk_swap( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S1]], %[[S0]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_reverse_one_chunk( +// CHECK-SAME: %[[SRC:[^)]+]]: !pto.vreg<64xf32> +// CHECK-DAG: %[[C63:.*]] = arith.constant 63 : i32 +// CHECK: %[[IDX:.*]] = pto.vci %[[C63]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vselr %[[SRC]], %[[IDX]] : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK-NEXT: return %[[OUT]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto new file mode 100644 index 0000000000..264b7b6a6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_lane0_splat( + %src: !pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_shuffle_lane0_splat( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP0:.*]] = pto.vdup %arg0, %[[MASK0]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP1:.*]] = pto.vdup %arg0, %[[MASK1]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DUP0]], %[[DUP1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto new file mode 100644 index 0000000000..6e89595596 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto='enable-stable-gather-masked-load=true' 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_stable_gather_masked_load_todo( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + pto.vmi.store %out, %src[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load stable VGATHER-based lowering is reserved for strict masked/tail loads but is not implemented yet diff --git a/test/lit/vmi/vmi_to_vpto_store_deint.pto b/test/lit/vmi/vmi_to_vpto_store_deint.pto new file mode 100644 index 0000000000..cafebbf14d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint2( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint2_multichunk( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B32", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint4( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: pto.vsts %[[D0]], %arg4[%arg5] +// CHECK: pto.vsts %[[D1]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D2]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D3]], %arg4[{{.*}}] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_multichunk( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg2, %arg4[%arg5], "INTLV_B32", %[[MASK0]] +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg1, %arg3, %arg4[{{.*}}], "INTLV_B32", %[[MASK1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto new file mode 100644 index 0000000000..e1068f813c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto new file mode 100644 index 0000000000..653d9b6f33 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint_tail( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_tail.pto b/test/lit/vmi/vmi_to_vpto_store_tail.pto new file mode 100644 index 0000000000..34058b925c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_tail.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_tail( +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %arg0, %arg2[%arg3], %[[FULL_MASK]] +// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %arg1, %arg2[{{.*}}], %[[TAIL_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto new file mode 100644 index 0000000000..b412afdcca --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_store_f64_unsupported( + %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %dst: memref<32xf64>, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_tile_write_f64_unsupported( + %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %dst: memref<32xf64>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_sub_mul.pto b/test/lit/vmi/vmi_to_vpto_sub_mul.pto new file mode 100644 index 0000000000..d76a6bfd3c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_sub_mul.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_subf_mulf( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %wa = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %wb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %diff = pto.vmi.subf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %diff, %prod : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_subi_muli( + %a: !pto.vmi.vreg<128xi32>, + %b: !pto.vmi.vreg<128xi32>) + -> (!pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32>) { + %diff = pto.vmi.subi %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %prod = pto.vmi.muli %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return %diff, %prod : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_subf_mulf( +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[SUB0:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[SUB0]], %[[SUB1]], %[[MUL0]], %[[MUL1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_subi_muli( +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[ISUB0:.*]] 
= pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[ISUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ISUB0]], %[[ISUB1]], %[[IMUL0]], %[[IMUL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_read_write.pto b/test/lit/vmi/vmi_to_vpto_tile_read_write.pto new file mode 100644 index 0000000000..5b7e6dbe00 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_read_write.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_read_write_contiguous(%src: memref<128xf32>, %dst: memref<128xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> + return + } + + func.func @vmi_to_vpto_tile_read_deint2(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_tile_write_deint2( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: memref<128xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_write_contiguous( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[ZERO]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: pto.vsts %[[L0]], %arg1[%[[ZERO]]] +// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_deint2( +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint2( +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%[[ZERO]]], "INTLV_B32", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto new file mode 100644 index 0000000000..701d921186 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %dst: memref<4xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + memref<4xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint_tail( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^)]+]]: memref<4xf32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[ZERO]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto new file mode 100644 index 0000000000..d4f37d48fc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %dst: memref<100xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, memref<100xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^)]+]]: memref<100xf32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[ZERO]]], %[[FULL_MASK]] +// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[TAIL_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto new file mode 100644 index 0000000000..4d4dac9d6d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_tail_deint_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: memref<129xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, memref<129xf32> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_truncf.pto b/test/lit/vmi/vmi_to_vpto_truncf.pto new file mode 100644 index 0000000000..e8d8340c83 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_f32_to_f16( + %even: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %odd: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %wide = pto.vmi.addf %even, %odd + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_truncf_f32_tail_to_f16( + %wide: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_tail_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto new file mode 100644 index 0000000000..5297123e5a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( + %input: !pto.vmi.vreg<128xf32>) { + %packed = pto.vmi.truncf %input + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion +// CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto new file mode 100644 index 0000000000..9d8cb972aa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_unsupported_shape_invalid( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) { + %narrow = pto.vmi.truncf %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf supports only f32 deinterleaved=2 source parts +// CHECK-SAME: one contiguous f16 result chunk +// CHECK-SAME: f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk diff --git a/test/lit/vmi/vmi_to_vpto_type_arity.pto b/test/lit/vmi/vmi_to_vpto_type_arity.pto new file mode 100644 index 0000000000..e99e8e9ea0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_arity.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_arity_contiguous_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint2_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_contiguous_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint4( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint2_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto new file mode 100644 index 0000000000..afc8502caf --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = [{nested = !pto.vmi.vreg<128xf32, #pto.vmi.layout>}] +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto new file mode 100644 index 0000000000..c115c1c3d8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_only.pto b/test/lit/vmi/vmi_to_vpto_type_only.pto new file mode 100644 index 0000000000..777afaf124 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_only.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_only( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_only( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_unary_math.pto b/test/lit/vmi/vmi_to_vpto_unary_math.pto new file mode 100644 index 0000000000..5a4419bad2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unary_math.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unary_math( + %value: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg, %sqrt, %exp, %ln, %relu + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + %value: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %abs : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unary_math( +// CHECK-SAME: %[[V0:[^,]+]]: 
!pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT0:.*]] = pto.vsqrt %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT1:.*]] = pto.vsqrt %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP0:.*]] = pto.vexp %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP1:.*]] = pto.vexp %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN0:.*]] = pto.vln %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN1:.*]] = pto.vln %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU0:.*]] = pto.vrelu %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU1:.*]] = pto.vrelu %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]], %[[SQRT0]], %[[SQRT1]], %[[EXP0]], %[[EXP1]], %[[LN0]], %[[LN1]], %[[RELU0]], %[[RELU1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[ABSF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[ABSF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ABSF0]], %[[ABSF1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^)]+]]: !pto.vreg<64xi32> +// CHECK-SAME: -> !pto.vreg<64xi32> +// CHECK: %[[ABSI:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ABSI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto new file mode 100644 index 0000000000..9bf8f25949 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unrealized_cast_residual_invalid( + %arg0: i32) -> f32 { + %0 = builtin.unrealized_conversion_cast %arg0 + : i32 to f32 + return %0 : f32 + } +} + +// CHECK: VMI-RESIDUAL-OP: unrealized conversion cast remains after vmi-to-vpto diff --git a/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto new file mode 100644 index 0000000000..df51608b08 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unsupported_op_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %shuffled = "pto.vmi.shuffle"(%a) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.shuffle requires physical chunk forwarding or lane0 splat or vci-materializable vselr indices +// CHECK-SAME: forwarding: +// CHECK-SAME: lane0 splat: +// CHECK-SAME: vselr: diff --git a/test/lit/vmi/vmi_truncf_direction_invalid.pto b/test/lit/vmi/vmi_truncf_direction_invalid.pto new file mode 100644 index 0000000000..934f1e4ba3 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_direction_invalid(%source: !pto.vmi.vreg<128xf16>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires result element type to be narrower than source element type diff --git a/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..56e07a9892 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf32>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto new file mode 100644 index 0000000000..04613c441d --- /dev/null +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module attributes { + pto.vmi_contiguous = #pto.vmi.layout, + pto.vmi_deinterleaved2 = #pto.vmi.layout, + pto.vmi_deinterleaved4 = #pto.vmi.layout +} { + func.func @vmi_type_attr_parse( + %surface: !pto.vmi.vreg<128xf32>, + %contiguous: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %surface_mask: !pto.vmi.mask<128xpred>, + %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: pto.vmi_contiguous = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved2 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout +// CHECK-LABEL: func.func @vmi_type_attr_parse( +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_type_element_count_invalid.pto b/test/lit/vmi/vmi_type_element_count_invalid.pto new file mode 100644 index 0000000000..a7548528c9 --- /dev/null +++ b/test/lit/vmi/vmi_type_element_count_invalid.pto @@ -0,0 
+1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_type_element_count_invalid( + %arg0: !pto.vmi.vreg<0xf32>) { + return + } +} + +// CHECK: expected a positive element count diff --git a/test/lit/vmi/vmi_unary_math_integer_invalid.pto b/test/lit/vmi/vmi_unary_math_integer_invalid.pto new file mode 100644 index 0000000000..8f3af3092e --- /dev/null +++ b/test/lit/vmi/vmi_unary_math_integer_invalid.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_sqrt_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.sqrt' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_exp_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.exp' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_ln_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.ln' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_relu_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.relu' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_unpack_arity_invalid.pto b/test/lit/vmi/vmi_unpack_arity_invalid.pto new file mode 100644 index 0000000000..5cd224a6e6 --- /dev/null +++ b/test/lit/vmi/vmi_unpack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_unpack_arity_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %p0 = "pto.vmi.unpack"(%a) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` in the working directory.

    Both files are read as flat float32 arrays.  The process exits with
    status 2 (after printing an ``[ERROR]`` line) when the arrays differ in
    length or when any element differs beyond ``atol=1e-5, rtol=1e-5``;
    otherwise an ``[INFO]`` line is printed and the function returns.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Report a length mismatch before any element-wise comparison:
    # np.isclose raises ValueError for arrays that cannot be broadcast
    # together, which would replace the diagnostic with a traceback.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed shape mismatch golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-5, rtol=1e-5)
    if not close.all():
        # Point at the first offending element to make triage easy.
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Buffer length written to every file, and the number of leading elements
# that carry computed results; the remainder keeps the sentinel value.
ELEMS = 1024
LOGICAL_ELEMS = 1000
SEED = 23
SCALE = np.float32(2.0)
SENTINEL = np.float32(-123.25)


def generate(output_dir: Path, seed: int) -> None:
    """Write the input, initial-output and golden buffers for this case.

    Args:
        output_dir: Directory to receive ``v1.bin`` (float16 input),
            ``v2.bin`` (float32 output pre-filled with ``SENTINEL``) and
            ``golden_v2.bin`` (float32 expected output).  Created if missing.
        seed: Seed for the NumPy random generator producing the input.
    """
    source = np.random.default_rng(seed).uniform(-4.0, 4.0, size=ELEMS).astype(np.float16)
    initial = np.full(ELEMS, SENTINEL, dtype=np.float32)

    # Only the first LOGICAL_ELEMS entries are scaled; the tail stays at the
    # sentinel so untouched output elements can be checked.
    expected = initial.copy()
    expected[:LOGICAL_ELEMS] = source[:LOGICAL_ELEMS].astype(np.float32) * SCALE

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", source),
        ("v2.bin", initial),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` / ``--seed`` and generate the data files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    options = cli.parse_args()
    generate(options.output_dir, options.seed)


if __name__ == "__main__":
    main()
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f16_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %next = arith.subi %remaining, %c128 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c128_i64 + 
nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..3c329a34bb --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f16_to_f32_tail_kernel(__gm__ half *src, __gm__ float *dst); + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_dequant_f16_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp new file mode 100644 index 0000000000..7797fe7fb0 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f16_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` in the working directory.

    Both files are read as flat float32 arrays.  The process exits with
    status 2 (after printing an ``[ERROR]`` line) when the arrays differ in
    length or when any element differs beyond ``atol=1e-5, rtol=1e-5``;
    otherwise an ``[INFO]`` line is printed and the function returns.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Report a length mismatch before any element-wise comparison:
    # np.isclose raises ValueError for arrays that cannot be broadcast
    # together, which would replace the diagnostic with a traceback.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed shape mismatch golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-5, rtol=1e-5)
    if not close.all():
        # Point at the first offending element to make triage easy.
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

# Buffer length written to every file, and the number of leading elements
# that carry computed results; the remainder keeps the sentinel value.
ELEMS = 1024
LOGICAL_ELEMS = 1000
SCALE = np.float32(2.0)
# VALUES[i] is the float32 value paired with the byte F8E4M3FN_BYTES[i].
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
SENTINEL = np.float32(-123.25)


def generate(output_dir: Path) -> None:
    """Write the input, initial-output and golden buffers for this case.

    The input is the 8-entry byte table repeated cyclically to ``ELEMS``
    bytes, so the data is fully deterministic.

    Args:
        output_dir: Directory to receive ``v1.bin`` (uint8 input bytes),
            ``v2.bin`` (float32 output pre-filled with ``SENTINEL``) and
            ``golden_v2.bin`` (float32 expected output).  Created if missing.
    """
    # np.resize fills the longer array with repeated copies of the table,
    # keeping bytes and their decoded values aligned index by index.
    encoded = np.resize(F8E4M3FN_BYTES, ELEMS)
    values = np.resize(VALUES, ELEMS)

    initial = np.full(ELEMS, SENTINEL, dtype=np.float32)
    # Only the first LOGICAL_ELEMS entries are scaled; the tail stays at the
    # sentinel so untouched output elements can be checked.
    expected = initial.copy()
    expected[:LOGICAL_ELEMS] = values[:LOGICAL_ELEMS] * SCALE

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, payload in (
        ("v1.bin", encoded),
        ("v2.bin", initial),
        ("golden_v2.bin", expected),
    ):
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` and generate the data files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    options = cli.parse_args()
    generate(options.output_dir)


if __name__ == "__main__":
    main()
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f8_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %ub_src_f8[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, !pto.vmi.mask<256xpred> + %next = arith.subi %remaining, %c256 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + 
pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..02688457e3 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f8_to_f32_tail_kernel(__gm__ uint8_t *src, __gm__ float *dst); + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream) { + vmi_dequant_f8_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp new file mode 100644 index 0000000000..ee62749258 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint8_t *srcHost = nullptr; + uint8_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f8_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py b/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py new file mode 100644 index 0000000000..7938574cd5 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 29 +SCALE = np.float32(0.5) +SENTINEL = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden[:LOGICAL_ELEMS] = (src[:LOGICAL_ELEMS] * SCALE).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto new file mode 100644 index 0000000000..2920617624 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f16_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 5.000000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred> + %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %packed, %ub_dst[%offset], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + %next = arith.subi %remaining, %c128 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64 
+ nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp new file mode 100644 index 0000000000..bf3aa91f10 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f16_tail_kernel(__gm__ float *src, __gm__ half *dst); + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream) { + vmi_quant_f32_to_f16_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp new file mode 100644 index 0000000000..b03ccbce5c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f16_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + 
aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py b/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py new file mode 100644 index 0000000000..9c36f02c73 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + golden = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto new file mode 100644 index 0000000000..4c7193f970 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f8_full_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %wide = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %ub_dst_f8[%c0] : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp new file mode 100644 index 0000000000..18bc01e2d1 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_full_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_full_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp new file mode 100644 index 0000000000..6e3aae53f2 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kSrcElems = 256; + constexpr size_t kDstElems = 256; + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_full_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py b/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py new file mode 100644 index 0000000000..b662cd604f --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.uint8(0xA5) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + packed = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, SENTINEL, dtype=np.uint8) + golden = np.full(ELEMS, SENTINEL, dtype=np.uint8) + golden[:LOGICAL_ELEMS] = packed[:LOGICAL_ELEMS] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto new file mode 100644 index 0000000000..bb3db56ff2 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f8_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred> + %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.masked_store %packed, %ub_dst_f8[%offset], %mask + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, !pto.vmi.mask<256xpred> + %next = arith.subi %remaining, %c256 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git 
a/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp new file mode 100644 index 0000000000..cf40a3fc57 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_tail_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp 
b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp new file mode 100644 index 0000000000..5f5bda8502 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + 
ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v3.bin`` against ``golden_v3.bin`` as float32 with tolerance.

    Both files are read from the current working directory. Values match when
    ``np.isclose`` accepts them with ``atol=1e-5`` and ``rtol=1e-5``.
    Prints ``[INFO] compare passed`` on success. Otherwise prints one
    ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v3.bin", dtype=np.float32)
    output = np.fromfile("v3.bin", dtype=np.float32)

    # Handle a length mismatch up front: np.isclose would otherwise fail to
    # broadcast and abort with a traceback instead of the expected report.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed size mismatch golden={golden.size} output={output.size}")
        sys.exit(2)

    # Compute the tolerance mask once and reuse it for the verdict and the index.
    close = np.isclose(golden, output, atol=1e-5, rtol=1e-5)
    if not close.all():
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ELEMS = 256
F16_VALUE = np.float16(0.125)
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
SENTINEL = np.float32(-123.25)


def generate(output_dir: Path, elems: int = ELEMS) -> None:
    """Write the inputs, initial output and golden result for the reduce test.

    Files written into *output_dir* (created if missing):

    * ``v1.bin``        - ``elems`` float16 values, all ``F16_VALUE``.
    * ``v2.bin``        - ``elems`` bytes, ``F8E4M3FN_BYTES`` repeated.
    * ``v3.bin``        - ``elems`` float32 values, all ``SENTINEL``.
    * ``golden_v3.bin`` - float32 ``VALUES`` pattern (the decoded form of the
      f8 bytes) multiplied by the float32 sum of the f16 input.

    Args:
        output_dir: Destination directory (``Path`` or path string).
        elems: Number of elements in every buffer.

    Raises:
        ValueError: If ``elems`` is negative.
    """
    if elems < 0:
        raise ValueError(f"elems must be non-negative, got {elems}")

    src_f16 = np.full(elems, F16_VALUE, dtype=np.float16)
    # np.resize fills the requested length with repeated copies of the pattern.
    src_f8 = np.resize(F8E4M3FN_BYTES, elems)
    decoded_f8 = np.resize(VALUES, elems)
    # Accumulate in float32 so the reference sum is not rounded to float16.
    reduction = np.sum(src_f16.astype(np.float32), dtype=np.float32)
    dst = np.full(elems, SENTINEL, dtype=np.float32)
    golden = decoded_f8 * reduction

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    src_f16.tofile(output_dir / "v1.bin")
    src_f8.tofile(output_dir / "v2.bin")
    dst.tofile(output_dir / "v3.bin")
    golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin")


def main(argv=None) -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=Path, default=Path("."))
    args = parser.parse_args(argv)
    generate(args.output_dir)


if __name__ == "__main__":
    main()
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + %c256 = arith.constant 256 : index + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %src_f16 = pto.vmi.load %ub_f16[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %init = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<1xf32> + %sum = pto.vmi.reduce_addf %src_f16_f32, %init, 
%mask {reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<1xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<1xf32> + %sum_vec = pto.vmi.broadcast %sum + : !pto.vmi.vreg<1xf32> -> !pto.vmi.vreg<256xf32> + %src_f8 = pto.vmi.load %ub_f8[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %out, %ub_dst[%c0] : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..b882f9e0e2 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream) { + vmi_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e48cd97661 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, 
srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 5bde9442e7..3732f72313 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -13,3 +13,4 @@ add_subdirectory(ptobc) add_subdirectory(ptoas) +add_subdirectory(pto-test-opt) diff --git a/tools/pto-test-opt/CMakeLists.txt b/tools/pto-test-opt/CMakeLists.txt new file mode 100644 index 0000000000..8f72f0383d --- /dev/null +++ b/tools/pto-test-opt/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set(LLVM_LINK_COMPONENTS + Support +) + +add_llvm_executable(pto-test-opt + pto-test-opt.cpp +) + +target_link_libraries(pto-test-opt PRIVATE + PTOIR + PTOTransforms + MLIRMlirOptMain + MLIRIR + MLIRParser + MLIRPass + MLIRSupport + MLIRFuncDialect + MLIRArithDialect + MLIRMemRefDialect + MLIRSCFDialect + MLIRControlFlowDialect +) + +add_dependencies(pto-test-opt + PTOOpsIncGen + PTOPassesIncGen +) diff --git a/tools/pto-test-opt/pto-test-opt.cpp b/tools/pto-test-opt/pto-test-opt.cpp new file mode 100644 index 0000000000..6ec1dc70ef --- /dev/null +++ b/tools/pto-test-opt/pto-test-opt.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- pto-test-opt.cpp - PTO lit pass runner -----------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + + mlir::registerAllPasses(); + mlir::pto::registerPTOPasses(); + + return failed(mlir::MlirOptMain(argc, argv, "PTO lit pass runner\n", + registry)); +} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 5bb6821677..60e88b4276 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -8,6 +8,7 @@ #include "ptoas.h" #include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" @@ -433,6 +434,12 @@ static llvm::cl::opt disableInferLayout( llvm::cl::desc("Disable PTO layout inference pass (static-only)"), llvm::cl::init(false)); +static llvm::cl::opt enableVMI( + "enable-vmi", + llvm::cl::desc("Run the experimental VMI-to-VPTO semantic pipeline " + "(requires --pto-backend=vpto or pto.backend = \"vpto\")"), + llvm::cl::init(false)); + static llvm::cl::opt emitAddPtrTrace( "emit-addptr-trace", llvm::cl::desc("Emit addptr trace comments in generated C++ output"), @@ -1585,6 +1592,51 @@ static LogicalResult runVPTOBackendPipeline(OwningOpRef &module, return success(); } +static bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), containsVMIType) || + 
llvm::any_of(functionType.getResults(), containsVMIType); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +static LogicalResult verifyNoPublicVMISignature(ModuleOp module) { + WalkResult result = module.walk([&](func::FuncOp func) { + if (!func.isPublic() || !containsVMIType(func.getFunctionType())) + return WalkResult::advance(); + func.emitError() + << pto::kVMIDiagLayoutContractPrefix + << "public VMI typed function requires an explicit external ABI " + "materialization plan"; + return WalkResult::interrupt(); + }); + return failure(result.wasInterrupted()); +} + +static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { + if (failed(verifyNoPublicVMISignature(module.get()))) + return failure(); + + PassManager pm(module->getContext()); + pm.enableVerifier(); + pm.addPass(pto::createPTOValidateVMIIRPass()); + pm.addPass(pto::createVMILayoutAssignmentPass()); + pm.addPass(pto::createPTOValidateVMILayoutIRPass()); + pm.addPass(pto::createVMIToVPTOPass()); + if (failed(applyConfiguredPassManagerCLOptions(pm, + "VMI-to-VPTO pipeline"))) + return failure(); + if (failed(pm.run(module.get()))) { + llvm::errs() << "Error: VMI-to-VPTO pipeline failed.\n"; + return failure(); + } + return success(); +} + int mlir::pto::compilePTOASModule( OwningOpRef &module, PTOASContext &context, PTOBackend effectiveBackend, PTOASCompileResult &result, @@ -1600,6 +1652,11 @@ int mlir::pto::compilePTOASModule( "--pto-backend=vpto or pto.backend = \"vpto\".\n"; return 1; } + if (enableVMI && effectiveBackend != PTOBackend::VPTO) { + llvm::errs() << "Error: --enable-vmi requires --pto-backend=vpto or " + "pto.backend = \"vpto\".\n"; + return 1; + } PTOBuildLevel effectiveLevel = defaultBuildLevel(); if (!parseBuildLevel(ptoBuildLevel, effectiveLevel)) { @@ -1718,6 +1775,11 @@ int mlir::pto::compilePTOASModule( const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); const bool hasTilelangHelpers = 
hasTilelangInlineHelpers(*module); + if (enableVMI) { + if (failed(runVMISemanticPipeline(module))) + return 1; + } + if (effectiveBackend == PTOBackend::VPTO && !hasTileOpsToExpand) { if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { llvm::errs() << "Error: shared pre-backend seam IR is unavailable when " From ab6bc048999bd307378a177badc4f5bba098df1e Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Thu, 18 Jun 2026 13:03:02 +0800 Subject: [PATCH 02/31] feat: support num_groups layout --- docs/designs/vmi-implementation-manual.md | 121 +++ include/PTO/IR/VMIAttrs.td | 4 + include/PTO/IR/VMIOps.td | 38 + lib/PTO/IR/VMI.cpp | 139 ++- lib/PTO/Transforms/VMILayoutAssignment.cpp | 39 + lib/PTO/Transforms/VMIToVPTO.cpp | 910 +++++++++++++++++- .../vmi/vmi_to_vpto_group_broadcast_deint.pto | 33 + .../vmi/vmi_to_vpto_group_broadcast_vselr.pto | 42 + test/lit/vmi/vmi_to_vpto_group_ops.pto | 41 + .../vmi/vmi_to_vpto_group_reduce_vcgadd.pto | 33 + ...to_vpto_group_reduce_vcgadd_multichunk.pto | 45 + .../group-reduce-f16-f8-mul-store/compare.py | 27 + .../group-reduce-f16-f8-mul-store/golden.py | 59 ++ .../group-reduce-f16-f8-mul-store/kernel.pto | 71 ++ .../group-reduce-f16-f8-mul-store/launch.cpp | 43 + .../group-reduce-f16-f8-mul-store/main.cpp | 91 ++ .../group-reduce-f16-f8-mul-store/ptoas.flags | 1 + 17 files changed, 1725 insertions(+), 12 deletions(-) create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_ops.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto 
create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 772194f64d..cd674db32a 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -748,6 +748,11 @@ deinterleaved=4: part1 chunks for lanes 1,5,9,... part2 chunks for lanes 2,6,10,... part3 chunks for lanes 3,7,11,... + +num_groups=G: + sparse group-slot reduce result layout + physical storage is contiguous chunk order + only canonical group_slot(g) lanes contain semantic values ``` 每个 semantic pattern 必须从 adaptor 拿 physical parts,不允许从 defining op 反推: @@ -1596,6 +1601,11 @@ The type converter must define one canonical physical ordering and every pattern part2 lanes [2,6,10,...] part3 lanes [3,7,11,...] +!pto.vmi.vreg + -> chunks in contiguous physical storage order + only derived group_slot(g) lanes contain semantic values + this layout is valid only for group reduce/broadcast exchange values + !pto.vmi.mask -> same part/chunk ordering as its data layout, one !pto.mask per physical part/chunk ``` @@ -2932,6 +2942,117 @@ pto.vmi.reduce_addf: f16 until accumulator precision and rounding contract are designed partial/tail source chunks because padding lanes must not participate +pto.vmi.group_load / pto.vmi.group_store: + semantic: + num_groups is the only static grouping attribute. + N = logical lane count; G = num_groups; S = N / G. + group_load reads each logical group as one contiguous row: + result[g * S + i] = source[offset + g * row_stride + i] + for 0 <= g < G and 0 <= i < S + group_store writes the inverse row mapping: + destination[offset + g * row_stride + i] = value[g * S + i] + row_stride is an index operand, measured in elements, and may be dynamic. 
+ Tail/valid-lane information is not an attr; it must be represented by a + mask in the producing/consuming computation. The current direct + group_load/group_store path is for full physical chunks. + layout assignment: + group_load result natural layout is contiguous + group_store value use is requested as contiguous + current direct lowering: + source/value element width must be maskable by b8/b16/b32 + layout must be contiguous with full physical chunks + num_groups must evenly divide N, and the derived group size S must be a + multiple of the physical lanes + per part, so every physical chunk belongs to exactly one group + lower each physical chunk with pto.vlds/pto.vsts at: + offset + group * row_stride + chunk_in_group * lanes_per_part + unsupported cases: + derived group size splitting a physical chunk, because this needs partial-vreg + lane insertion/extraction or a gather/scatter plan + partial/tail physical chunks + GM-backed direct vector load/store paths not already accepted by the normal + VMI memory access plan + +pto.vmi.group_reduce_addf: + semantic: + requires {reassoc} + N = logical lane count; G = num_groups; S = N / G + L = physical lanes per 256B chunk for the element type. + The result carries #pto.vmi.layout, a sparse group-slot + layout. It is not a dense vector layout: only group_slot(g) lanes have + semantic values. + group_slot(g) is canonical and derived from N, G, and L: + if S < L: + low_elems = L / S + chunk_stride = 1 + if S >= L: + low_elems = 1 + chunk_stride = S / L + group_slot(g) = (g / low_elems) * chunk_stride * L + (g % low_elems) + for each group g: + result[group_slot(g)] = + reduce_add(source[g * S .. (g + 1) * S), mask in same range) + Non-slot lanes are not consumed by pto.vmi.group_broadcast. The current + direct lowering materializes them as zero where the hardware path does not + already define them. 
+ The result remains a VMI vector with the same element type and logical lane + count as the source, but its layout is #pto.vmi.layout. + layout assignment: + source use is requested as contiguous + result natural layout is #pto.vmi.layout + mask use is requested as contiguous with granularity derived from source + element width + current direct lowering: + source/result element type must be f32 + source, result, and mask must have matching physical arity and full chunks + if S=8 for f32, lower each physical chunk with pto.vcgadd. This is the + hardware 32B VLane group reduction path for f32: each source chunk produces + eight 8-lane group sums in the low lanes of that physical chunk. The + lowering preserves this natural no-pack result. + Otherwise: + derived group size S must be a multiple of physical lanes per part + lower each source chunk with pto.vcadd, combine chunks in the same group + with pto.vadd under PAT_VL1, then place group g at group_slot(g) in the + #pto.vmi.layout result. All other result chunks/lane values + are zero. + unsupported cases: + missing reassoc attr + f16 or integer group reductions until accumulator and result contracts are + designed + derived group size S that neither divides nor is a multiple of L + +pto.vmi.group_broadcast: + semantic: + N = logical lane count; G = num_groups; S = N / G + source must carry #pto.vmi.layout. For each group g, the + source value is read from group_slot(g), using the same canonical group_slot + definition as pto.vmi.group_reduce_addf. The result broadcasts it back to + each logical group: + result[g * S + i] = source[group_slot(g)] + layout assignment: + source use is requested as #pto.vmi.layout + result is consumer-driven. If no consumer requests another layout, it + defaults to contiguous. 
+ current direct lowering: + source must carry #pto.vmi.layout with full physical chunks + result may be contiguous with full physical chunks + result may also be deinterleaved when S is large enough that every physical + result chunk stays inside one logical group, for example N=512, G=2, S=256, + L=64, deinterleaved=4 + derived group size S must divide or be a multiple of L for canonical + group-slot addressing + if result is contiguous and S < L, each physical chunk contains multiple group + slots. Lower by + creating an index vector [0...0, 1...1, ...] and applying pto.vselr to the + corresponding source chunk. + if S >= L and each result physical chunk belongs to one group, lower by + duplicating the first lane of that group's source chunk with pto.vdup LOWEST. + unsupported cases: + partial/tail physical chunks + derived group size S that neither divides nor is a multiple of L + deinterleaved small-group broadcast where one physical result chunk needs + values from multiple source chunks + pto.vmi.reduce_maxf / pto.vmi.reduce_minf: semantic: acc = init[0] diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index fc2a7f2f5b..da8428dd23 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -25,9 +25,13 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, int64_t factor); + static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, + int64_t numGroups); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + bool isGroupSlots() const { return getKind() == "num_groups"; } + int64_t getNumGroups() const { return getFactor(); } }]; } diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 6f567bb8a5..7bd7524118 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -391,6 
+391,26 @@ def VMIReduceMinFOp : VMI_Op<"reduce_minf"> { let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { + let summary = "VMI masked floating-point add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups, + OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { + let summary = "VMI broadcast group-slot values back to each logical group"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + def VMIExtFOp : VMI_Op<"extf"> { let summary = "VMI floating-point elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); @@ -423,6 +443,15 @@ def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector load with a row stride between groups"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$row_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $row_stride attr-dict `:` type($source) `->` type($result)"; +} + def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector load with passthrough lanes"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, @@ -462,6 +491,15 @@ def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector store with a row 
stride between groups"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, Index:$row_stride, I64Attr:$num_groups); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $row_stride attr-dict `:` type($value) `,` type($destination)"; +} + def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector store"; let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 1f9a43f51a..e26982e347 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -143,7 +143,7 @@ static FailureOr getLayoutFactor(Type type) { FailureOr layout = getAssignedVMILayout(type); if (failed(layout)) return failure(); - return (*layout).isContiguous() ? 1 : (*layout).getFactor(); + return (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; } static FailureOr getPhysicalLanesPerPart(Type type) { @@ -294,6 +294,17 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, return success(); } +static LogicalResult verifyNumGroups(Operation *op, VMIVRegType type, + int64_t numGroups) { + if (numGroups <= 0) + return op->emitOpError("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return op->emitOpError() + << "requires num_groups to evenly divide VMI logical lane count " + << type.getElementCount(); + return success(); +} + static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, TypeRange physicalTypes) { FailureOr expectedArity = getVMIPhysicalArity(vmiType); @@ -354,6 +365,11 @@ VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, return VMILayoutAttr::get(context, "deinterleaved", factor); } +VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, + int64_t numGroups) { + return VMILayoutAttr::get(context, "num_groups", numGroups); +} + Attribute 
VMILayoutAttr::parse(AsmParser &parser, Type) { SMLoc loc = parser.getCurrentLocation(); StringRef kind; @@ -367,10 +383,13 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { } else if (kind == "deinterleaved") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + } else if (kind == "num_groups") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " - "'deinterleaved'"); + "'deinterleaved' or 'num_groups'"); return {}; } @@ -383,7 +402,7 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { void VMILayoutAttr::print(AsmPrinter &printer) const { printer << "<" << getKind(); - if (isDeinterleaved()) + if (isDeinterleaved() || isGroupSlots()) printer << " = " << getFactor(); printer << ">"; } @@ -406,8 +425,16 @@ VMILayoutAttr::verify(function_ref emitError, return success(); } + if (kind == "num_groups") { + if (factor <= 0) + return emitError() + << "#pto.vmi.layout requires num_groups to be positive"; + return success(); + } + return emitError() << "expected VMI layout kind to be 'contiguous' or " - "'deinterleaved'"; + "'deinterleaved' or 'num_groups'"; } Type VMIVRegType::parse(AsmParser &parser) { @@ -454,6 +481,14 @@ LogicalResult VMIVRegType::verify(function_ref emitError, return emitError() << "'" << formatVMIVRegType(elementCount, elementType, layout) << "' expected layout to be #pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots() && + elementCount % layoutAttr.getNumGroups() != 0) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected num_groups layout to evenly divide " + "the VMI logical lane count"; + } return success(); } @@ -509,6 +544,12 @@ LogicalResult VMIMaskType::verify(function_ref emitError, return emitError() << "'" << formatVMIMaskType(elementCount, granularity, 
layout) << "' expected layout to be #pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots()) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' mask type must not carry num_groups layout"; + } if (granularity == "pred" && layout) return emitError() << "'" << formatVMIMaskType(elementCount, granularity, @@ -958,6 +999,65 @@ LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } +LogicalResult VMIGroupReduceAddFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!getOperation()->hasAttr("reassoc")) + return emitOpError( + "requires reassoc attr because grouped lowering uses pair-wise " + "floating-point reductions"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like VMI source element type"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isContiguous()) + return emitOpError( + "requires layout-assigned source to use contiguous layout"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() + << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupBroadcastOp::verify() { + auto sourceType = cast(getSource().getType()); + 
auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() + << "requires layout-assigned source to use " + "#pto.vmi.layout"; + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (resultLayout.isGroupSlots()) + return emitOpError( + "requires layout-assigned result to use a dense VMI layout"); + } + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); @@ -1025,6 +1125,21 @@ void VMILoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIGroupLoadOp::verify() { + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + return verifyNumGroups(getOperation(), resultType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedLoadOp::verify() { auto maskType = cast(getMask().getType()); auto passthruType = cast(getPassthru().getType()); @@ -1109,6 +1224,22 @@ void VMIStoreOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } +LogicalResult VMIGroupStoreOp::verify() { + auto valueType = cast(getValue().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + 
getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyNumGroups(getOperation(), valueType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + LogicalResult VMIMaskedStoreOp::verify() { auto valueType = cast(getValue().getType()); auto maskType = cast(getMask().getType()); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index e4d201d45c..27d6b806fe 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -224,6 +224,10 @@ struct LayoutSolver { return VMILayoutAttr::getContiguous(ctx); } + VMILayoutAttr getGroupSlotsLayout(int64_t numGroups) { + return VMILayoutAttr::getGroupSlots(ctx, numGroups); + } + VMILayoutAttr getDataLayout(Value value) { unsigned id = addDataValue(value); if (id == ~0u) @@ -537,6 +541,21 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getGroupSlotsLayout( + reduce.getNumGroupsAttr().getInt()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto broadcast = dyn_cast(op)) { + requestDataUse(broadcast.getSourceMutable(), + getGroupSlotsLayout( + broadcast.getNumGroupsAttr().getInt())); + return WalkResult::advance(); + } if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); @@ -607,10 +626,20 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } 
if (auto store = dyn_cast(op)) { requestDataUse(store.getValueMutable(), getContiguousLayout()); return WalkResult::advance(); } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); @@ -1136,6 +1165,16 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed(requestMaskUse(load.getMaskMutable(), diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index db19c2846b..cf91af1142 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -488,7 +488,7 @@ FailureOr getDataLayoutFactor(VMIVRegType type) { VMILayoutAttr layout = type.getLayoutAttr(); if (!layout) return failure(); - return layout.isContiguous() ? 1 : layout.getFactor(); + return layout.isDeinterleaved() ? layout.getFactor() : 1; } FailureOr getDataChunksInPart(VMIVRegType type, int64_t part) { @@ -568,7 +568,7 @@ FailureOr getVMITypeLayoutFactor(Type type) { auto layoutAttr = dyn_cast_or_null(layout); if (!layoutAttr) return failure(); - return layoutAttr.isContiguous() ? 1 : layoutAttr.getFactor(); + return layoutAttr.isDeinterleaved() ? 
layoutAttr.getFactor() : 1; } FailureOr getVMITypeElementCount(Type type) { @@ -1087,6 +1087,80 @@ LogicalResult checkSupportedStoreShape( materializationReason); } +FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, + int64_t groupSize, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("requires assigned contiguous layout"); + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("requires derived group size to evenly divide logical lane " + "count"); + if (groupSize % *lanesPerPart != 0) + return fail("currently requires group size to be a multiple of physical " + "lanes per part"); + return success(); +} + +LogicalResult checkSupportedGroupLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, + std::string *reason) { + auto resultType = cast(op.getResult().getType()); + FailureOr groupSize = + getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return 
failure(); + if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), + op.getSource().getType(), + std::nullopt, reason))) + return failure(); + return checkSupportedGroupChunkShape(resultType, *groupSize, reason); +} + +LogicalResult checkSupportedGroupStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, + std::string *reason) { + auto valueType = cast(op.getValue().getType()); + FailureOr groupSize = + getGroupSizeFromNumGroups(valueType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (failed(checkSupportedStoreShape(capabilities, valueType, + op.getDestination(), + op.getDestination().getType(), reason))) + return failure(); + return checkSupportedGroupChunkShape(valueType, *groupSize, reason); +} + LogicalResult checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, VMIMaskedLoadOp op, std::string *reason) { @@ -1766,7 +1840,7 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { return fail("requires known physical mask lanes per part"); auto boolValues = denseAttr.getValues(); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? 
layout.getFactor() : 1; SmallVector materializations; for (int64_t part = 0; part < factor; ++part) { for (int64_t chunk = 0;; ++chunk) { @@ -1897,6 +1971,10 @@ materializeConstantMaskChunk(Location loc, MaskType maskType, return materializePrefixMask(loc, maskType, 0, *lanesPerPart, rewriter); } +FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter); + Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, PatternRewriter &rewriter) { if (laneOffset == 0) @@ -1905,6 +1983,234 @@ Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, return rewriter.create(loc, baseOffset, delta).getResult(); } +Value createGroupChunkOffset(Location loc, Value baseOffset, Value rowStride, + int64_t group, int64_t inGroupLaneOffset, + PatternRewriter &rewriter) { + Value offset = baseOffset; + if (group != 0) { + Value groupIndex = rewriter.create(loc, group); + Value rowOffset = + rewriter.create(loc, rowStride, groupIndex).getResult(); + offset = rewriter.create(loc, offset, rowOffset).getResult(); + } + return createChunkOffset(loc, offset, inGroupLaneOffset, rewriter); +} + +LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, + int64_t groupSize, + int64_t *lanesPerPart, + int64_t *groupCount, + int64_t *chunksPerGroup, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("group op requires contiguous VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group op requires full physical chunks"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group op requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("group op requires derived 
group size to evenly divide lane " + "count"); + if (groupSize % *lanes != 0) + return fail("group op currently requires group size to be a multiple of " + "physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = type.getElementCount() / groupSize; + *chunksPerGroup = groupSize / *lanes; + return success(); +} + +LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, + int64_t groupSize, + int64_t numGroups, + int64_t *lanesPerPart, + int64_t *groupCount, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != numGroups) + return fail("group slot op requires matching num_groups VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group slot op requires full physical chunks"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group slot op requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail( + "group slot op requires derived group size to evenly divide lane count"); + if (*lanes % groupSize != 0 && groupSize % *lanes != 0) + return fail("group slot op requires group size to divide or be a " + "multiple of physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = type.getElementCount() / groupSize; + return success(); +} + +LogicalResult checkFullGroupBroadcastResultShape(Operation *op, + VMIVRegType type, + int64_t groupSize, + int64_t lanesPerPart, + int64_t *layoutFactor, + int64_t *groupCount, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("group_broadcast result requires assigned VMI layout"); + if 
(layout.isGroupSlots()) + return fail("group_broadcast result requires a dense VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group_broadcast result requires full physical chunks"); + FailureOr resultLanes = + getDataLanesPerPart(type.getElementType()); + if (failed(resultLanes) || *resultLanes != lanesPerPart) + return fail("group_broadcast result requires matching physical lanes"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("group_broadcast result requires derived group size to evenly " + "divide lane count"); + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("group_broadcast result requires known layout factor"); + + if (*factor == 1) { + if (lanesPerPart % groupSize != 0 && groupSize % lanesPerPart != 0) + return fail("group_broadcast contiguous result requires group size to " + "divide or be a multiple of physical lanes per part"); + } else { + int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; + if (groupSize < lanesPerPart || + groupSize % logicalSpanPerResultChunk != 0) + return fail("group_broadcast deinterleaved result requires every " + "physical result chunk to stay within one logical group"); + } + + *layoutFactor = *factor; + *groupCount = type.getElementCount() / groupSize; + return success(); +} + +FailureOr createZeroVector(Location loc, VRegType type, + PatternRewriter &rewriter) { + FailureOr zero = + createScalarOffsetConstant(loc, type.getElementType(), 0, rewriter); + FailureOr mask = createAllTrueMaskForVReg(loc, type, rewriter); + if (failed(zero) || failed(mask)) + return failure(); + return rewriter.create(loc, type, *zero, *mask, + /*position=*/nullptr) + .getResult(); +} + +FailureOr createLaneRangeMask(Location loc, MaskType maskType, + int64_t begin, int64_t end, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || begin < 0 || begin > 
end || + end > *lanesPerPart) + return failure(); + SmallVector active(*lanesPerPart, 0); + for (int64_t lane = begin; lane < end; ++lane) + active[lane] = 1; + return materializeConstantMaskChunk(loc, maskType, active, rewriter); +} + +FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, + int64_t groupSize, + PatternRewriter &rewriter) { + int64_t lanesPerPart = indexType.getElementCount(); + FailureOr zero = + createZeroVector(loc, indexType, rewriter); + FailureOr maskType = getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); + if (failed(zero) || failed(maskType) || failed(allMask)) + return failure(); + if (groupSize >= lanesPerPart) + return *zero; + if (lanesPerPart % groupSize != 0) + return failure(); + + Value result = *zero; + int64_t groupsPerChunk = lanesPerPart / groupSize; + for (int64_t localGroup = 1; localGroup < groupsPerChunk; ++localGroup) { + FailureOr groupScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), localGroup, rewriter); + FailureOr laneMask = + createLaneRangeMask(loc, *maskType, localGroup * groupSize, + (localGroup + 1) * groupSize, rewriter); + if (failed(groupScalar) || failed(laneMask)) + return failure(); + Value splat = + rewriter + .create(loc, indexType, *groupScalar, *allMask, + /*position=*/nullptr) + .getResult(); + result = rewriter.create(loc, indexType, splat, result, *laneMask) + .getResult(); + } + return result; +} + +LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, + VMIMaskType maskType, + VMIVRegType resultType, + int64_t groupSize, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("vcgadd group_reduce_addf path requires f32 source/result"); + if (groupSize != 8) 
+ return fail("vcgadd group_reduce_addf path requires group size = 8 for " + "f32 32-byte VLane groups"); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + int64_t numGroups = sourceType.getElementCount() / groupSize; + if (!sourceLayout || !resultLayout || !maskLayout || + !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || + !maskLayout.isContiguous()) + return fail("vcgadd group_reduce_addf path requires contiguous source/mask " + "layouts and matching num_groups result layout"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("vcgadd group_reduce_addf path requires full source " + "chunks; ") + + sourceFullReason); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("vcgadd group_reduce_addf path requires computable physical " + "arity"); + if (*sourceArity < 1 || *sourceArity != *maskArity || + *sourceArity != *resultArity) + return fail("vcgadd group_reduce_addf path requires matching non-empty " + "source/mask/result physical arity"); + return success(); +} + std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -2713,9 +3019,10 @@ FailureOr createScalarOffsetConstant(Location loc, Type type, } if (auto floatType = dyn_cast(type)) { return rewriter - .create( - loc, FloatAttr::get(floatType, - llvm::APFloat(static_cast(value)))) + .create(loc, + rewriter.getFloatAttr(floatType, + static_cast( + value))) .getResult(); } return failure(); @@ -2987,7 +3294,7 @@ struct OneToNVMICreateMaskOpPattern return failure(); 
TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; if (resultTypes.size() % factor != 0) return rewriter.notifyMatchFailure( op, "dynamic create_mask physical result count does not match " @@ -3034,7 +3341,7 @@ struct OneToNVMICreateMaskOpPattern activeLanes = resultVMIType.getElementCount(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; SmallVector results; results.reserve(resultTypes.size()); @@ -3184,6 +3491,73 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIGroupLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_load offset must convert to one value", + rewriter); + FailureOr rowStride = + getSingleValue(op, adaptor.getRowStride(), + "group_load row_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(rowStride)) + return failure(); + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks( + op, resultVMIType, 
*groupSize, &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (static_cast(resultTypes.size()) != + groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_load arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_load result must be vreg"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIMaskedLoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern< @@ -3502,6 +3876,73 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIGroupStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = getGroupSizeFromNumGroups( + valueVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_store requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks( + op, valueVMIType, *groupSize, &lanesPerPart, &groupCount, + 
&chunksPerGroup, rewriter))) + return failure(); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "group_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_store offset must convert to one value", + rewriter); + FailureOr rowStride = + getSingleValue(op, adaptor.getRowStride(), + "group_store row_stride must convert to one value", + rewriter); + if (failed(destination) || failed(offset) || failed(rowStride)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != + groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_store arity mismatch"); + + for (auto [index, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + struct OneToNVMIMaskedStoreOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern< @@ -4346,6 +4787,284 @@ struct OneToNVMIReduceAddFOpPattern } }; +struct OneToNVMIGroupReduceAddFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupReduceAddFOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupReduceAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto 
sourceVMIType = cast(op.getSource().getType()); + auto maskVMIType = cast(op.getMask().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, + "group_reduce_addf requires num_groups to evenly divide lane count"); + if (succeeded(checkVcgaddGroupReduceShape( + sourceVMIType, maskVMIType, resultVMIType, + *groupSize, nullptr))) { + if (sourceParts.size() != maskParts.size() || + sourceParts.size() != resultTypes.size() || sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires matching physical " + "arity"); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires physical vreg/mask"); + for (auto [sourcePart, maskPart, physicalResultType] : + llvm::zip_equal(sourceParts, maskParts, resultTypes)) { + if (sourcePart.getType() != resultType || + maskPart.getType() != maskType || physicalResultType != resultType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires uniform physical " + "chunk types"); + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + maskParts[sourceIndex]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + if (failed(checkContiguousFullGroupChunks( + op, 
sourceVMIType, *groupSize, &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + if (sourceParts.size() != maskParts.size() || + static_cast(sourceParts.size()) != + groupCount * chunksPerGroup || + resultTypes.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires matching source/mask/result arity"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf result must be vreg"); + FailureOr zero = createZeroVector(op.getLoc(), vregType, rewriter); + if (failed(zero)) + return rewriter.notifyMatchFailure( + op, "failed to materialize group_reduce_addf zero result"); + results.push_back(*zero); + } + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires physical vreg result and mask"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_reduce_addf masks"); + + for (int64_t group = 0; group < groupCount; ++group) { + FailureOr accumulator = + createZeroVector(op.getLoc(), resultType, rewriter); + if (failed(accumulator)) + return rewriter.notifyMatchFailure( + op, "failed to create group_reduce_addf accumulator"); + + for (int64_t chunk = 0; chunk < chunksPerGroup; ++chunk) { + int64_t index = group * chunksPerGroup + chunk; + if (sourceParts[index].getType() != resultType || + maskParts[index].getType() != maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires uniform physical chunk types"); + Value reduced = + rewriter + .create(op.getLoc(), resultType, sourceParts[index], + maskParts[index]) + .getResult(); + 
*accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, + *accumulator, *firstLaneMask) + .getResult(); + } + + int64_t destChunk = group * chunksPerGroup; + results[destChunk] = + rewriter + .create(op.getLoc(), resultType, *accumulator, + results[destChunk], *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGroupBroadcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupBroadcastOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, + "group_broadcast requires num_groups to evenly divide lane count"); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + if (failed(checkFullGroupSlotSourceShape( + op, sourceVMIType, *groupSize, op.getNumGroupsAttr().getInt(), + &lanesPerPart, &groupCount, rewriter))) + return failure(); + int64_t resultLayoutFactor = 0; + int64_t resultGroupCount = 0; + if (failed(checkFullGroupBroadcastResultShape( + op, resultVMIType, *groupSize, lanesPerPart, &resultLayoutFactor, + &resultGroupCount, rewriter))) + return failure(); + if (resultGroupCount != groupCount) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires matching source/result group slots"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "group_broadcast arity mismatch"); + + auto firstSourceType = dyn_cast(sourceParts.front().getType()); + if (!firstSourceType) + 
return rewriter.notifyMatchFailure( + op, "group_broadcast source must be vreg"); + unsigned indexBits = + pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires 8/16/32-bit index elements"); + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = + VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), + indexElementType); + std::optional groupSlotIndex; + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstSourceType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast all mask"); + if (*groupSize < lanesPerPart) { + FailureOr index = createGroupSlotIndexVector( + op.getLoc(), indexType, *groupSize, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + groupSlotIndex = *index; + } + + SmallVector results; + results.resize(resultTypes.size()); + for (auto [flatIndex, resultType] : llvm::enumerate(resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires uniform physical vreg types"); + int64_t sourceChunk = flatIndex; + if (resultLayoutFactor == 1) { + if (*groupSize >= lanesPerPart) { + int64_t chunksPerGroup = *groupSize / lanesPerPart; + int64_t group = flatIndex / chunksPerGroup; + sourceChunk = group * chunksPerGroup; + } + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; 
++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + sourceChunk = firstGroup * chunksPerGroup; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } + if (*groupSize >= lanesPerPart) { + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, sourceParts[sourceChunk], + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + } else { + if (resultLayoutFactor != 1) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group deinterleaved result is not " + "supported"); + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + template struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { @@ -5032,10 +5751,12 @@ void populateVMIOneToNConversionPatterns( 
OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, + OneToNVMIGroupLoadOpPattern, OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, + OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, OneToNVMIScatterOpPattern, OneToNVMITileReadOpPattern, @@ -5071,6 +5792,8 @@ void populateVMIOneToNConversionPatterns( OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, + OneToNVMIGroupReduceAddFOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point reduction"); + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !maskLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || + !maskLayout.isContiguous()) + return fail("requires contiguous source/mask layouts and matching " + "num_groups result layout"); + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(VMIReductionKind::AddF, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("requires source/result element type to match"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires 
source/result lane count to match"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(sourceArity) || failed(resultArity) || failed(maskArity)) + return fail("requires computable source/result/mask physical arity"); + if (*sourceArity != *resultArity || *sourceArity != *maskArity) + return fail("requires source/result/mask physical arity to match"); + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (succeeded(checkVcgaddGroupReduceShape( + sourceType, maskType, resultType, *groupSize, nullptr))) + return success(); + return checkSupportedGroupChunkShape(sourceType, *groupSize, reason); +} + +LogicalResult checkSupportedGroupBroadcastShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason = nullptr) { + (void)capabilities; + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType() || + sourceType.getElementCount() != resultType.getElementCount()) { + if (reason) + *reason = "requires source/result shape and element type to match"; + return failure(); + } + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + + std::string fullChunkReason; + if 
(failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a multiple of " + "physical lanes per part"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return success(); + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || + *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + return success(); +} + LogicalResult checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, VMIFmaOp op, std::string *reason = nullptr) { @@ -5678,11 +6517,37 @@ LogicalResult verifySupportedVMIToVPTOOps( return emitMaskableUnsupported( op, "pto.vmi.broadcast", cast(broadcast.getResult().getType())); + if (auto broadcast = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupBroadcastShape(capabilities, broadcast, + &reason))) + return WalkResult::advance(); + broadcast.emitError() + << kVMIDiagUnsupportedPrefix + << 
"pto.vmi.group_broadcast requires full source chunks with " + "#pto.vmi.layout, a dense full result layout, " + "and num_groups deriving a group size that divides or is a " + "multiple of physical chunk lanes (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) return emitMemoryUnsupported( op, "pto.vmi.load", cast(load.getResult().getType()), load.getSource(), getConstantIndexValue(load.getOffset())); + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_load requires contiguous full result chunks, a " + "supported UB source, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { if (enableStableGatherMaskedLoad) { load.emitError() @@ -5744,6 +6609,19 @@ LogicalResult verifySupportedVMIToVPTOOps( << reason << ")"; return WalkResult::interrupt(); } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupStoreShape(capabilities, store, + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_store requires contiguous full value chunks, a " + "supported UB destination, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto store = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedMaskedStoreShape( @@ -6063,6 +6941,22 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupReduceAddFShape(capabilities, reduce, + &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addf lowers 
through pto.vcgadd for f32 " + "32B groups or through pto.vcadd with reassoc, contiguous full " + "source/mask chunks, #pto.vmi.layout result " + "chunks, and num_groups deriving a group size aligned to " + "physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto new file mode 100644 index 0000000000..078b61b5bf --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_deint( + %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) + -> !pto.vmi.vreg<512xf32> { + %src_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_deint( +// CHECK-COUNT-8: {position = "LOWEST"} +// CHECK-COUNT-8: pto.vmul +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto new file mode 100644 index 0000000000..01d9711ef0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_vselr( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_vselr( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto new file mode 100644 index 0000000000..6a10e168dd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_ops( + %src: !pto.ptr, + %dst: !pto.ptr, + %row_stride: index, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %v = pto.vmi.group_load %src[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %b = pto.vmi.group_broadcast %r {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + pto.vmi.group_store %b, %dst[%c0], %row_stride {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_ops( +// CHECK-COUNT-8: pto.vlds +// CHECK-COUNT-8: pto.vcadd +// CHECK-COUNT-8: {position = "LOWEST"} +// CHECK-NOT: pto.vselr +// CHECK-COUNT-8: pto.vsts +// CHECK-NOT: pto.vmi. 
+// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto new file mode 100644 index 0000000000..27d246e6d2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto new file mode 100644 index 0000000000..d3da9416b6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<1024xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + !pto.vmi.mask<1024xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( +// CHECK-COUNT-16: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    golden = np.fromfile("golden_v3.bin", dtype=np.float32)
+    output = np.fromfile("v3.bin", dtype=np.float32)
+    close = np.isclose(golden, output, atol=1e-5, rtol=1e-5) if golden.shape == output.shape else np.zeros(0, dtype=bool)  # np.isclose raises on a size mismatch, so only compare element-wise when shapes agree
+    if golden.shape != output.shape or not close.all():
+        idx = int(np.argmin(close)) if close.size else -1  # first failing element; -1 when the sizes differ
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py
new file mode 100644
index 0000000000..69fbe13344
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 2 +ROW_ELEMS = 256 +ROW_STRIDE = 320 +TOTAL_ELEMS = ROWS * ROW_STRIDE +F16_VALUES = np.array([0.125, 0.25], dtype=np.float16) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ROW_ELEMS + len(VALUES) - 1) // len(VALUES) + row_f8 = np.tile(F8E4M3FN_BYTES, repeats)[:ROW_ELEMS].astype(np.uint8) + row_decoded_f8 = np.tile(VALUES, repeats)[:ROW_ELEMS].astype(np.float32) + + src_f16 = np.zeros(TOTAL_ELEMS, dtype=np.float16) + src_f8 = np.zeros(TOTAL_ELEMS, dtype=np.uint8) + dst = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + + for row in range(ROWS): + begin = row * ROW_STRIDE + end = begin + ROW_ELEMS + src_f16[begin:end] = F16_VALUES[row] + src_f8[begin:end] = np.roll(row_f8, row) + decoded_f8 = np.roll(row_decoded_f8, row) + reduction = np.sum(src_f16[begin:end].astype(np.float32), dtype=np.float32) + golden[begin:end] = decoded_f8 * reduction + + output_dir.mkdir(parents=True, exist_ok=True) + src_f16.tofile(output_dir / "v1.bin") + src_f8.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto new file mode 100644 index 0000000000..9cedd97e60 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c320 = arith.constant 320 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c256_i64 = arith.constant 256 : i64 + %c320_i64 = arith.constant 320 : i64 + %c512_i64 = arith.constant 512 : i64 + %c640_i64 = arith.constant 640 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c2_i64, %c640_i64, %c640_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c2_i64, %c320_i64, %c320_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = 
pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %src_f16 = pto.vmi.group_load %ub_f16[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 + : !pto.vmi.vreg<512xf16> -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %src_f16_f32, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %src_f8 = pto.vmi.group_load %ub_f8[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %out, %ub_dst[%c0], %c320 {num_groups = 2} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c2_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..03bf4d7e8f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream) { + vmi_group_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e5769e3978 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 2; + constexpr size_t kRowStride = 320; + constexpr size_t kElems = kRows * kRowStride; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 9d63f309306653f6d3b4db293dd31f271690b140 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 00:33:13 +0800 Subject: [PATCH 03/31] feat: new layout-lowering design --- docs/designs/vmi-layout-lowering-cases.md | 3103 +++++++++++++++++++++ docs/isa/micro-isa/10-reduction-ops.md | 28 +- 2 files changed, 3118 insertions(+), 13 deletions(-) create mode 100644 
docs/designs/vmi-layout-lowering-cases.md diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md new file mode 100644 index 0000000000..807baf841e --- /dev/null +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -0,0 +1,3103 @@ +# VMI Layout Lowering Cases + +本文是 VMI layout/lowering 的典型 case catalog,不是完整设计总文档。它只回答一个问题: +一个 VMI logical vector 在某个场景下选择某种 layout 后,`vmi-to-vpto` 必须生成什么 +VPTO 结果。这里不写动机式描述;每个场景都给出 layout assignment 和 lowering result。 + +## 1. Layout Families + +### 1.1 Dense Layout + +Dense layout 的每个 logical lane 都有语义值。 + +```text +#pto.vmi.layout<contiguous> +``` + +Physical ordering: + +```text +chunk c, lane l -> logical lane c * L + l +``` + +`L` is the number of physical lanes per 256B VPTO vector register for the element type. + +```text +#pto.vmi.layout<deinterleaved = F, block_elems = B> +``` + +`block_elems` defaults to `1`. Existing spellings are shorthands: + +```text +#pto.vmi.layout<contiguous> + == #pto.vmi.layout<deinterleaved = 1, block_elems = 1> + +#pto.vmi.layout<deinterleaved = F> + == #pto.vmi.layout<deinterleaved = F, block_elems = 1> +``` + +Logical-to-physical mapping: + +```text +logical lane i +block q = i / B +in_block lane r = i % B +part p = q % F +part_block t = q / F + +physical part p, physical lane t * B + r +``` + +Required invariants: + +```text +F > 0 +B > 0 +N % (F * B) == 0 for the direct full-chunk paths in this document +``` + +### 1.2 Sparse Group-Slot Layout + +Sparse group-slot layout is not dense. Only `G` lanes have semantic values. + +```text +#pto.vmi.layout +``` + +Physical slot mapping: + +```text +N = logical lane count +S = N / G // logical lanes per source group + +slot_block(g) = g / K +slot_lane(g) = g % K +``` + +Required invariants: + +```text +G > 0 +K > 0 +G % K == 0 +K must fit in the physical vreg element count +``` + +`K` is selected by the producer/consumer plan. It is not always 8. For +`VCGADD`-packed results, `K = 8` matches the eight 32B block results written to +the low lanes of one destination vreg.
For row-local reductions where each +logical group already occupies one full 256B vreg, `K = 1` keeps each group's +scalar result in lane 0 of its own physical vreg and avoids an unsupported +cross-vreg scalar pack. + +Only these lanes are semantic: + +```text +physical slot block slot_block(g), lane slot_lane(g) +``` + +All other lanes are undefined for ordinary VMI consumers. They may only be read +by group-aware ops that define how to interpret group slots. + +## 2. Plan Selection Rules + +VMI cast ops must not hard-code one physical `vcvt` plan as their semantic +layout rule. + +```text +dense cast: + source/result are dense layouts. + lowering may require deinterleaved(F, block_elems=1) around VCVT. + +group-slot cast: + source/result are both group_slots(G,K). + lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are + legal only when a slot-preserving VPTO plan is registered, or when the cast + can be commuted through a later group-aware consumer such as group_broadcast. +``` + +Illegal consumer mix: + +```text +group_slots value -> ordinary dense store/add/mul +``` + +This must fail unless an explicit semantic op converts the sparse value: + +```text +group_broadcast +group_store +future explicit group-pack op +``` + +## 3. Lowering Results + +The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true +predicate with the element granularity required by the instruction. `PAT_VLk` +means a prefix predicate for the first `k` lanes. + +Completeness rule for this section: every numbered endpoint below must contain +VMI input, assigned layouts, VPTO lowering result, and either a memory result or +an explicit diagnostic. Non-endpoint layout notes may appear only as setup for +the immediately following complete endpoints. 
+ +```text +3.1 f16 -> f32 -> store complete +3.2 f32 -> f16 -> store complete +3.3 f8 -> f32 -> compute -> f8 complete +3.4 group_reduce S=8 -> group_store complete +3.5.1 group_reduce S=16 -> group_store complete +3.5.2 group_reduce S=16 -> broadcast -> compute -> reduce -> store + complete +3.5.3 group_reduce S=16 -> elemwise(rhs) -> group_store complete +3.6.1 group_reduce S=32 -> group_store complete +3.6.2 group_reduce S=32 -> elemwise(rhs) -> group_store complete +3.6.3 group_reduce S=32 -> broadcast -> compute -> reduce -> store + complete +3.7.1 group_reduce S=64 -> group_store complete +3.7.2 group_reduce S=64 -> elemwise(rhs) -> group_store complete +3.7.3 group_reduce S=64 -> broadcast -> compute -> reduce -> store + complete +3.8 group_reduce -> truncf -> broadcast -> dense store complete +3.9 dense store of group slots illegal diagnostic +3.10 non-load producer feeding S=32 group_reduce complete +3.11 partial tail groups complete/diagnostic +3.12 control-flow join before group_reduce complete +3.13 direct group-slot f32 -> f16 cast illegal diagnostic +3.14 unsupported group size illegal diagnostic +3.15 compact S=12 written as logical S=16 complete/design +3.16 group_slot_load layout contract complete +3.17 group_broadcast physical arity alias complete +3.18 one value with dense and group-reduce consumers complete/materialization +3.19 S=16 reduce block_elems plan selection complete/diagnostic +3.20 group_slots control-flow join complete +3.21 S=32 tail with full-tile-readable source complete/design +3.22 scf.for loop-carried layout complete +3.23 group_broadcast with multiple dense consumers complete +3.24 mask with elementwise/select/store complete +3.25 function boundary layout specialization complete/design +``` + +### 3.1 `f16 -> f32 -> store` + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +pto.vmi.store %x32, 
%out[%off] +``` + +Assigned layouts: + +```text +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +pto.vstsx2 %x32_p0, %x32_p1, %out[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Alternative complete VPTO lowering result if `vstsx2 INTLV_B32` is unavailable: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%x32_d0, %x32_d1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %x32_d0, %out[%off], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x32_d1, %out[%off_plus_64], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = extf(base[off + i]) +``` + +### 3.2 Dense `f32 -> f16 -> store` + +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %x16, %out[%off] +``` + +Assigned layouts: + +```text +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x32_p0, %x32_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> 
!pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts %x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Alternative complete VPTO lowering result if the source has already been loaded +as two contiguous f32 chunks and must be materialized to `deinterleaved=2` before +the conversion: + +```text +%x32_d0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_d1 = pto.vlds %base[%off_plus_64] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_p0, %x32_p1 = pto.vdintlv %x32_d0, %x32_d1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts %x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + +### 3.3 Dense `f8 -> f32 -> compute -> f8` + +VMI input: + +```text +%x8 = pto.vmi.load %base[%off] +%x32 = pto.vmi.extf %x8 +%scale = pto.vmi.broadcast %scale_s : f32 -> !pto.vmi.vreg<256xf32> +%y32 = pto.vmi.mulf %x32, %scale +%y8 = pto.vmi.truncf %y32 +pto.vmi.store %y8, %out[%off] +``` + +Assigned layouts: + +```text +%x8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%scale : !pto.vmi.vreg<256xf32, 
#pto.vmi.layout> +%y32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%y8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x8_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%x32_p0 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P1"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%y32_p0 = pto.vmul %x32_p0, %scale_p0, PAT_ALL_B32 +%y32_p1 = pto.vmul %x32_p1, %scale_p1, PAT_ALL_B32 +%y32_p2 = pto.vmul %x32_p2, %scale_p2, PAT_ALL_B32 +%y32_p3 = pto.vmul %x32_p3, %scale_p3, PAT_ALL_B32 + +%y8_p0 = pto.vcvt %y32_p0, PAT_ALL_B32 + {part = "P0", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p1 = pto.vcvt %y32_p1, PAT_ALL_B32 + {part = "P1", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p2 = pto.vcvt %y32_p2, PAT_ALL_B32 + {part = "P2", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p3 = pto.vcvt %y32_p3, PAT_ALL_B32 + {part = "P3", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> + +%y8_01 = pto.vor %y8_p0, %y8_p1, PAT_ALL_B8 +%y8_23 = pto.vor %y8_p2, %y8_p3, PAT_ALL_B8 +%y8_0 = pto.vor %y8_01, %y8_23, PAT_ALL_B8 + +pto.vsts %y8_0, %out[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + out[off + i] = truncf(extf(base[off + i]) * scale_s) +``` + +### 3.4 `group_reduce` S=8 f32 + 
+VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg<64xf32, #pto.vmi.layout> +%mask : !pto.vmi.mask<64xpred, #pto.vmi.layout> +%sum : !pto.vmi.vreg<64xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%mask_chunk = pto.pge_b32 "PAT_ALL" + +%x_chunk = pto.vlds %base[%tile_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%sum_block = pto.vcgadd %x_chunk, %mask_chunk + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Lowering result for one chunk, per the `visa.txt` VCGADD contract: + +```text +%sum_block lane 0 = reduce %x lanes 0..7 +%sum_block lane 1 = reduce %x lanes 8..15 +... +%sum_block lane 7 = reduce %x lanes 56..63 +all non-slot lanes are non-semantic +``` + +Layout result: + +```text +G = N / 8 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..7]) +``` + +### 3.5 `group_reduce` S=16 f32, load-fused split + +The facts used by this lowering are checked against the current repo: + +```text +pto.vldsx2 supports "BDINTLV". +pto.vstsx2 supports only "INTLV_B8" / "INTLV_B16" / "INTLV_B32". +visa.txt says VCGADD writes one 32B-block result continuously to destination +LSBs; the current repository golden tests follow lanes 0..7 for f32. 
+``` + +There are three complete consumers for this layout today: + +```text +load -> group_reduce -> group_store(sum) +load -> group_reduce -> elementwise compute on group-slot values + -> group_store +load -> group_reduce -> group_broadcast -> elementwise compute + -> group_reduce -> group_store +``` + +#### 3.5.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> + +%sum : !pto.vmi.vreg> +``` + +For each 8-row tile: + +```text +row r = 16xf32 = row_r.lo8, row_r.hi8 +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes 0..7 = row0.lo8 +%lo lanes 8..15 = row1.lo8 +... +%lo lanes 56..63 = row7.lo8 + +%hi lanes 0..7 = row0.hi8 +%hi lanes 8..15 = row1.hi8 +... +%hi lanes 56..63 = row7.hi8 + +%lo_sum = pto.vcgadd %lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +`BDINTLV` here denotes the ISA `#bdintlv` block-based interleaving load mode: +it loads `2 * VL` bytes and sends even 32B blocks to the first destination +register and odd 32B blocks to the second destination register. For f32, +one 32B block is `8xf32`, matching `block_elems = 8`. 
+ +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Layout result: + +```text +G = N / 16 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 + +%sum_block lane 0 = reduce row0 lanes 0..15 +%sum_block lane 1 = reduce row1 lanes 0..15 +... +%sum_block lane 7 = reduce row7 lanes 0..15 +``` + +No VMI value exposes `%lo_sum` or `%hi_sum`. They are internal VPTO values. + +Memory result: + +```text +sum_out[group_tile_off + 0] = reduce row0 lanes 0..15 +sum_out[group_tile_off + 1] = reduce row1 lanes 0..15 +... +sum_out[group_tile_off + 7] = reduce row7 lanes 0..15 +``` + +This endpoint is fully specified: the only sparse value is `%sum`; `group_store` +stores the low 8 slot lanes with an ordinary prefix store. + +#### 3.5.2 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +%b = pto.vmi.group_broadcast %sum {num_groups = N / 16} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = N / 16} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> +%sum : !pto.vmi.vreg> +%b : !pto.vmi.vreg> +%y : !pto.vmi.vreg> +%ysum : !pto.vmi.vreg> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs 
%lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// This is the materialization of pto.vmi.group_broadcast. The group sums are +// in %sum_block lanes 0..7; vselr expands each sum to the 8 lanes of the +// corresponding row half. The following vmul/vcgadd consume an ordinary dense +// physical vector. +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = pto.vmul %x_lo, %b_rows, %all_b32 + : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %all_b32 + : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Final per-row reduction and store. +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %ysum_block, %out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +This trace processes 8 logical rows at once. `num_groups = N / 16` means each +logical group is one `16xf32` row, and one full f32 VPTO tile covers 8 such +groups: + +```text +64 f32 lanes per physical part = 8 rows * 8 f32 lanes per half-row +``` + +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Physical lane result for the tile: + +```text +%x_lo lanes 0..7 = row0[0..7] +%x_lo lanes 8..15 = row1[0..7] +... +%x_lo lanes 56..63 = row7[0..7] + +%x_hi lanes 0..7 = row0[8..15] +%x_hi lanes 8..15 = row1[8..15] +... +%x_hi lanes 56..63 = row7[8..15] + +%sum_block lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%b_rows lanes 0..7 = reduce(row0[0..15]) +%b_rows lanes 8..15 = reduce(row1[0..15]) +... +%b_rows lanes 56..63 = reduce(row7[0..15]) + +For each row `r` in this 8-row tile: + +%y_lo lanes r*8 .. 
r*8+7 = + row_r[0..7] * reduce(row_r[0..15]) + +%y_hi lanes r*8 .. r*8+7 = + row_r[8..15] * reduce(row_r[0..15]) + +Concretely: +%y_lo lanes 0..7 = row0[0..7] * reduce(row0[0..15]) +%y_lo lanes 8..15 = row1[0..7] * reduce(row1[0..15]) +... +%y_lo lanes 56..63 = row7[0..7] * reduce(row7[0..15]) + +%y_hi lanes 0..7 = row0[8..15] * reduce(row0[0..15]) +%y_hi lanes 8..15 = row1[8..15] * reduce(row1[0..15]) +... +%y_hi lanes 56..63 = row7[8..15] * reduce(row7[0..15]) + +%ysum_block lanes 0..7 = + reduce(%y row0), reduce(%y row1), ..., reduce(%y row7) +``` + +Memory result: + +```text +out[group_tile_off + r] = + reduce_i((row_r[i] * reduce_j(row_r[j])) for i in 0..15) + = reduce(row_r[0..15]) * reduce(row_r[0..15]) +for r = 0..7 +``` + +If a later consumer requires row-major contiguous order, `vmi-to-vpto` must +materialize: + +```text +deinterleaved=2, block_elems=8 -> contiguous +``` + +This materialization cannot be implemented with `vstsx2 INTLV_B32`, because +that instruction interleaves individual b32 elements, not 32B row halves. Until +a concrete block-interleave register materialization or store op is selected, +row-major store of this layout must be rejected with: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store requires materializing + #pto.vmi.layout to contiguous, but no + VPTO block-interleave materialization/store plan is registered. +``` + +#### 3.5.3 Reduce Result, Elementwise, Store + +This case computes a per-row reduction, applies an elementwise operation to the +reduced values themselves, and stores one result per group. There is no +`group_broadcast` in this flow because the elementwise op is not applied to the +original `8x16xf32` matrix elements. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x for reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +For this endpoint, the RHS is a packed per-group vector: + +```text +rhs_base[rhs_off + r] = rhs(row r), for r = 0..7 +``` + +Layout assignment must treat `group_slot_load` as a group-slot producer: one +f32 value per group is placed in the live slot lanes. It must not use +`group_load`, which loads `group_size` data elements per group instead of one +per-group scalar. + +The elementwise op runs only on the live group-slot lanes: + +```text +%sum lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%rhs lanes 0..7 = + rhs(row0), rhs(row1), ..., rhs(row7) + +%outv lanes 0..7 = + %sum lanes 0..7 + %rhs lanes 0..7 + +lanes 8..63 remain dead/zero and are masked off by PAT_VL8. +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// Reduction path: use BDINTLV to feed two VCG reductions. +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +// Packed RHS group-slot load. 
%rhs_tile_base points to rhs_base[rhs_off]. +// One 32B block contains 8 f32 RHS values and materializes lanes 0..7; all +// other lanes are dead/zero. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Elementwise compute on group-slot values. Only lanes 0..7 are live. +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[group_tile_off + r] = s + rhs[r] +``` + +### 3.6 `group_reduce` S=32 f32, 4-way split + +This case covers one `8x32xf32` tile. Each logical row is 128B, so it must be +split into four 32B partial rows before `vcgadd` can reduce it efficiently. + +The canonical layout for the input is: + +```text +%x : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +With `deinterleaved = 4`, physical part `p` contains the columns whose logical +column index `c` satisfies `c mod 4 = p`: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[4], row_r[8], ..., row_r[28] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[5], row_r[9], ..., row_r[29] + +%x_p2 lanes r*8 .. r*8+7 = + row_r[2], row_r[6], row_r[10], ..., row_r[30] + +%x_p3 lanes r*8 .. r*8+7 = + row_r[3], row_r[7], row_r[11], ..., row_r[31] +``` + +Each physical part now has exactly 8 f32 values per row, so one `vcgadd` per +part computes one partial sum per row. The four partial sums are then added +under `PAT_VL8`. + +The full contiguous-to-4-way materialization for one tile should fuse the first +deinterleave level into the load. `vldsx2 DINTLV_B32` loads `2 * VL` bytes and +splits even/odd f32 elements into two physical vectors. Two such loads cover +the `8x32xf32` tile, and a second register `vdintlv` level splits even columns +into `mod4 = 0/2` and odd columns into `mod4 = 1/3`. 
+ +This setup documentation is repeated inside every complete 32-wide endpoint +below. + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +Each endpoint below inlines this materialization before the first consumer of +`%x_p0..%x_p3`. + +#### 3.6.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..31]) +``` + +#### 3.6.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = 
pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +// Packed RHS group-slot load. %rhs_tile_base points to rhs_base[rhs_off]. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..31]) + rhs[r] +``` + +#### 3.6.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> 
!pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// group_broadcast materialized for each deinterleaved=4 physical part. +%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%ys01 = pto.vadd %ys0, %ys1, %sum_mask : !pto.vreg<64xf32> +%ys23 = 
pto.vadd %ys2, %ys3, %sum_mask : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..31]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..31) + = s * s +``` + +### 3.7 `group_reduce` S=64 f32, row-local reduction + +This case covers one `8x64xf32` tile. Each logical row is exactly 256B, so the +input does not need a deinterleaved layout: + +```text +row r = 64xf32 = one !pto.vreg<64xf32> +``` + +The reduction is two-stage but row-local: + +```text +vcgadd(row_r) -> 8 partial sums in lanes 0..7 +vcadd(PAT_VL8) -> one row sum in lane 0 +``` + +The result layout is therefore not `slots = 8`. It is: + +```text +#pto.vmi.layout +``` + +Physical slot mapping for this tile: + +```text +slot_block(r) = r +slot_lane(r) = 0 + +%sum0 lane 0 = reduce row0 lanes 0..63 +%sum1 lane 0 = reduce row1 lanes 0..63 +... +%sum7 lane 0 = reduce row7 lanes 0..63 +``` + +Trying to canonicalize this result to `slots = 8` would require packing lane 0 +from eight different physical vregs into lanes 0..7 of one vreg. This document +does not use that plan. `slots = 1` is the canonical layout for S=64 row-local +group reductions. 
+ +#### 3.7.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %sum0, %sum_out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum1, %sum_out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum2, %sum_out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum3, %sum_out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum4, %sum_out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum5, %sum_out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum6, %sum_out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum7, %sum_out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..63]) +``` + +#### 3.7.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = 
pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = 
pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%rhs0 = pto.vsldb %rhs_ptr_0, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs1 = pto.vsldb %rhs_ptr_1, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs2 = pto.vsldb %rhs_ptr_2, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs3 = pto.vsldb %rhs_ptr_3, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs4 = pto.vsldb %rhs_ptr_4, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs5 = pto.vsldb %rhs_ptr_5, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs6 = pto.vsldb %rhs_ptr_6, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs7 = pto.vsldb %rhs_ptr_7, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%out0 = pto.vadd %sum0, %rhs0, %one_b32 : !pto.vreg<64xf32> +%out1 = pto.vadd %sum1, %rhs1, %one_b32 : !pto.vreg<64xf32> +%out2 = pto.vadd %sum2, %rhs2, %one_b32 : !pto.vreg<64xf32> +%out3 = pto.vadd %sum3, %rhs3, %one_b32 : !pto.vreg<64xf32> +%out4 = pto.vadd %sum4, %rhs4, %one_b32 : !pto.vreg<64xf32> +%out5 = pto.vadd %sum5, %rhs5, %one_b32 : !pto.vreg<64xf32> +%out6 = pto.vadd %sum6, %rhs6, %one_b32 : !pto.vreg<64xf32> +%out7 = pto.vadd %sum7, %rhs7, %one_b32 : !pto.vreg<64xf32> + +pto.vsts %out0, %out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out1, %out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} : 
!pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out2, %out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out3, %out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out4, %out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out5, %out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out6, %out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out7, %out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..63]) + rhs[r] +``` + +#### 3.7.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// The compiler emits this row-local block once for each r in 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// This vdup is the lowering of pto.vmi.group_broadcast for slots=1. 
+%b_r = pto.vdup %sum_r, %all_b32 {position = "LOWEST"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%y_r = pto.vmul %x_r, %b_r, %all_b32 : !pto.vreg<64xf32> + +%yp_r = pto.vcgadd %y_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_r = pto.vcadd %yp_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %ysum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +The row-local block above is not a runtime loop requirement. It is the repeated +VPTO shape for row offsets `%row_off_0` through `%row_off_7` and store offsets +`%group_tile_off_0` through `%group_tile_off_7`. + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..63]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..63) + = s * s +``` + +### 3.8 `group_reduce -> truncf -> group_broadcast -> store` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +%b16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +pto.vmi.store %b16, %out[%off] +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum32 : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum16 : semantic value only; not materialized as a group-slot VPTO value +%b32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +This case is supported by commuting `truncf` after `group_broadcast`: + +```text +group_broadcast(truncf(group_reduce(x))) + == truncf(group_broadcast(group_reduce(x))) +``` + +This avoids materializing a group-slot f16 value. The only cast emitted is the +existing dense `f32 deinterleaved=2 -> contiguous f16` truncation. 
+ +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum32_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// This vselr is the VPTO lowering of pto.vmi.group_broadcast. The later store +// only writes lanes as-is; it does not duplicate group-slot values. +%b32_rows = pto.vselr %sum32_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// The broadcasted f32 value is dense deinterleaved=2. +// Both parity parts carry the same per-row broadcast values. +%b16_even = pto.vcvt %b32_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%b16_odd = pto.vcvt %b32_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%b16 = pto.vor %b16_even, %b16_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %b16, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s32 = reduce(row_r[0..15]) + s16 = truncf(s32) + out[r * 16 + 0 .. 
r * 16 + 15] = splat(s16) +``` + +### 3.9 Illegal Dense Consumer Of Group Slots + +VMI input: + +```text +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = G} +pto.vmi.store %sum32, %out[%off] +``` + +Assigned layouts before the illegal consumer: + +```text +%sum32 : group_slots(G,K) +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store cannot consume #pto.vmi.layout + as a dense vector. Use pto.vmi.group_store, pto.vmi.group_broadcast, or an + explicit group-pack op. +``` + +It must not instead be accepted and lowered as: + +```text +dense store materializes group slots implicitly +``` + +That behavior would silently reinterpret a sparse group-slot value as a dense +vector. + +### 3.10 Non-Load Producer Feeding S=32 `group_reduce` + +This case proves that layout assignment is consumer-driven. The producer of the +S=32 input is an elementwise op, not a load. The S=32 `group_reduce` still +requires the elementwise result to be `deinterleaved = 4`, and that requirement +must propagate backward through the elementwise op to both operands. 
+ +VMI input: + +```text +%a = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.addf %a, %bias +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %bias, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full `8x32xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%a_even_0, %a_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_even_1, %a_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%bias_p0 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p1 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p2 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p3 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%x_p0 = pto.vadd %a_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> +%x_p1 = pto.vadd %a_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> +%x_p2 = pto.vadd %a_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> +%x_p3 = pto.vadd %a_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : 
!pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce_i(base[row_r, i] + bias_s for i = 0..31) +``` + +### 3.11 Partial Tail Groups + +Tail handling must be separated by the physical input layout. Row-local S=64 +can avoid inactive rows entirely. Load-fused S=16/S=32 cannot safely do that +with the current `vldsx2` materialization unless the source is known to be +full-tile readable. + +#### 3.11.1 S=64 Active Row Tail + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<384xf32> -> !pto.vmi.vreg<384xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local block for r = 0..5 only. No load or store is emitted for +// rows 6 and 7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_tile_off + r] = reduce(row_r[0..63]) +``` + +#### 3.11.2 S=32 Tail Without Full-Tile Read Contract + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layout requested by the consumer: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +Required diagnostic when the source does not carry a full-tile-readable +contract: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with group size 32 and num_groups tail 6 requires + materializing #pto.vmi.layout. The registered fast plan + uses vldsx2 DINTLV_B32 over a full 8-row tile. This source is not marked + full-tile-readable, and the stable gather tail plan is not implemented. +``` + +If a future option enables the stable gather tail plan, the same VMI input may +lower by gathering only the active lanes. Until that plan is registered, the +converter must not silently issue the full-tile `vldsx2` loads. + +### 3.12 Control-Flow Join Before `group_reduce` + +The layout carried by a value must survive block arguments. In MLIR converter +terms, the logical VMI value lowered through control flow becomes a tuple of +physical VPTO values with one tuple type per assigned layout. 
+ +VMI input: + +```text +%x = scf.if %cond -> !pto.vmi.vreg<256xf32> { + %a = pto.vmi.load %a_base[%a_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %a : !pto.vmi.vreg<256xf32> +} else { + %b = pto.vmi.load %b_base[%b_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %b : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %b, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = + scf.if %cond + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %a_even_0, %a_odd_0 = pto.vldsx2 %a_base[%a_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_even_1, %a_odd_1 = pto.vldsx2 %a_base[%a_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %a_p0, %a_p1, %a_p2, %a_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + %b_even_0, %b_odd_0 = pto.vldsx2 %b_base[%b_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_even_1, %b_odd_1 = pto.vldsx2 %b_base[%b_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p0, %b_p2 = pto.vdintlv %b_even_0, %b_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p1, %b_p3 = pto.vdintlv %b_odd_0, %b_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %b_p0, %b_p1, %b_p2, %b_p3 
+ : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +``` + +The consumer after the join is the same S=32 reduction plan as section 3.6: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + selected_row = cond ? a_row_r : b_row_r + out[group_tile_off + r] = reduce(selected_row[0..31]) +``` + +If the two branches cannot be assigned the same layout and no materialization +plan exists before `scf.yield`, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + scf.yield joins incompatible VMI layouts for !pto.vmi.vreg<256xf32>. + Expected #pto.vmi.layout on every incoming value. +``` + +### 3.13 Direct Group-Slot `f32 -> f16` Cast + +This case is intentionally illegal for the current S=16/S=32 packed +group-slot layout. It prevents the compiler from treating a width-changing +`vcvt` as if it preserved low-lane group slots. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts before the illegal cast: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf cannot lower from + #pto.vmi.layout f32 to f16 because no + slot-preserving width-changing VPTO plan is registered. f32->f16 vcvt writes + even/odd sub-lanes, not lanes 0..7. Use group_broadcast before truncf, or + keep the group_store element type as f32. +``` + +This does not contradict section 3.8. Section 3.8 is legal because the cast is +commuted after `group_broadcast`, where the value is dense again. + +### 3.14 Unsupported Group Size + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<96xf32> -> !pto.vmi.vreg<96xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based plans use +32B groups, i.e. 8 f32 elements per row fragment: + +```text +S = 8 -> one VCGADD block per group +S = 16 -> two 8-lane row fragments, add partial sums +S = 32 -> four 8-lane row fragments, add partial sums +S = 64 -> one full 256B row, VCGADD then VCADD +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with f32 group size 12 has no registered VPTO + layout plan. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. + A scalar/gather fallback or a rewrite to logical group size 16 with an + explicit per-group mask is required. 
+``` + +### 3.15 Compact S=12 Written As Logical S=16 + +If the program wants to use the S=16 lowering for data with 12 semantic f32 +elements per group, the IR must distinguish two sizes: + +```text +logical group size used by VMI ops: 16 +active elements per group: 12 +``` + +The mask is not a prefix mask over the whole vector. It is a per-group mask: + +```text +mask lane i is active iff (i % 16) < 12 +``` + +The group load surface carries the physical source stride as an SSA operand: + +```text +%x = pto.vmi.group_load %base[%off], %source_group_stride + {num_groups = G, group_size = S} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +`source_group_stride` is in elements, not bytes. It is an operand because it may +come from a dynamic leading dimension, a subview, or a runtime tile descriptor. +Static strides use a constant index operand and can be canonicalized later. +`group_size` remains an attribute in this design because it selects the logical +load layout. `active_elems_per_group` belongs to the mask producer, not to the +load. + +Grouped masks use a paired `pto.vmi.create_group_mask` op. It is intentionally +separate from ordinary prefix `pto.vmi.create_mask` so the IR makes group +semantics explicit next to `pto.vmi.group_load` / `pto.vmi.group_reduce_*`: + +```text +%mask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = G, group_size = S} + : index -> !pto.vmi.mask<(G*S)xpred> +``` + +Semantics: + +```text +lane i is active iff (i % S) < active_elems_per_group +``` + +Ordinary `pto.vmi.create_mask %active_lanes` keeps the prefix-mask meaning: + +```text +lane i is active iff i < active_lanes +``` + +#### 3.15.1 Existing Design Works If Source Row Stride Is 16 + +If memory already has a 16-f32 row stride, the user can write a logical S=16 +tile and mask off the last four lanes of every group. 
+ +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. r*8+3 = row_r[8..11] +%hi lanes r*8+4 .. 
r*8+7 = row_r[12..15] // inactive by mask + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..11]) +``` + +Design requirement added by this case: VMI mask lowering must support +group-periodic masks by generating the predicate from lane indices. It must not +rewrite this mask to `PAT_M4`: VISA defines `M4` as multiples of 4, not the +first four lanes of each 8-lane block. + +```text +lane = vci(0) +row = lane >> 3 +col = lane - (row << 3) +mask = col < 4 +``` + +#### 3.15.2 Source Row Stride Greater Than 16 + +For now, support the non-compact case where each physical row has at least 16 +f32 slots and the row stride is greater than 16. The fast strided-block path +requires the row stride to be a multiple of one 32B block: + +```text +source_group_stride % 8 == 0 +``` + +The example below uses `source_group_stride = 24`. 
Each row has 12 semantic +values, 4 masked-but-readable slots, and 8 extra skipped slots: + +```text +row_r[0..11] semantic +row_r[12..15] readable but inactive for the S=16 logical group +row_r[16..23] outside the logical group +``` + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts are the same as section 3.15.1: + +```text +%x, %mask: + #pto.vmi.layout +%sum: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +// source_group_stride = 24 f32 = 3 * 32B blocks. +%stride_blocks = %c3_i16 + +%base_lo = %base + tile_off +%base_hi = %base + tile_off + 8 + +%lo = pto.vsldb %base_lo, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%hi = pto.vsldb %base_hi, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. 
r*8+7 = row_r[8..15] + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 24 + 0 .. tile_off + r * 24 + 11]) +``` + +If `source_group_stride > 16` but is not a multiple of 8 f32 elements, this +strided-block path is not legal because `vsldb` block addresses are 32B based. +That case remains unsupported until a gather materialization is selected. + +#### 3.15.3 Compact Source Row Stride 12 + +Compact storage is explicitly out of scope for the first implementation: + +```text +row0[0..11], row1[0..11], row2[0..11], ... +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + logical group size 16 with active_elems_per_group 12 and + source_group_stride 12 requires compact-row gather materialization. This + plan is not part of the initial VMI layout lowering. +``` + +### 3.16 `group_slot_load` Layout Contract + +`group_slot_load` is separate from `group_load`. + +```text +group_load: + loads group_size data elements per group and produces dense grouped data. + +group_slot_load: + loads one scalar value per group and produces sparse group slots. +``` + +Surface form: + +```text +%v = pto.vmi.group_slot_load %base[%off], %source_group_stride + {num_groups = G} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +Semantics: + +```text +semantic group slot g = base[off + g * source_group_stride] +``` + +The result logical lane count `N` remains the surrounding VMI value shape. Only +the `G` group slots are semantic. 
Layout assignment chooses the sparse physical +placement requested by the consumer: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +#### 3.16.1 Packed `group_slot_load`, `slots = 8` + +VMI input: + +```text +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%slot_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// source_group_stride = 1, so one 32B block contains all 8 scalar group slots. +%rhs_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for g = 0..7: + out[group_off + g] = rhs_base[rhs_off + g] +``` + +If `source_group_stride != 1`, this packed `slots = 8` plan requires a +strided/gather group-slot load materializer. Until that plan is registered, +`group_slot_load` with `slots = 8` and non-unit stride must diagnose instead of +silently using full-group `group_load`. + +#### 3.16.2 Row-Local `group_slot_load`, `slots = 1` + +VMI input: + +```text +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. Each result value carries one semantic slot +// in lane 0, matching the S=64 row-local group_reduce result layout. 
+%rhs_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_off_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = rhs_base[rhs_off + r] +``` + +### 3.17 `group_broadcast` Physical Arity Alias + +This case fixes a lowering invariant: a layout determines physical arity. A +`deinterleaved = 2` result has two physical bundle entries even when both +entries can reuse the same VPTO SSA value. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// Physical bundle binding for %b, not emitted VPTO ops: +// physical entry 0 = %b_rows +// physical entry 1 = %b_rows +// The 
layout still has two physical entries; they alias the same SSA value +// because every even/odd logical lane pair contains the same broadcast value. + +%h_even = pto.vcvt %b_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %b_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +### 3.18 One Value With Dense And Group-Reduce Consumers + +This case forces layout assignment to handle a solvable use-site conflict. One +consumer requires an S=32 group-reduce layout; another consumer requires dense +row-major store. This is not semantically illegal. It must be solved by +use-site materialization or producer rematerialization when a registered plan +exists. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x, %copy_out[%off] +``` + +Assigned layouts: + +```text +%x for group_reduce: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x for dense store: + requires #pto.vmi.layout +``` + +If `%x` is cheap to rematerialize, layout assignment may clone the producer for +the dense store. Otherwise, if the registry has a `deinterleaved = 4 -> +contiguous` materialization plan, layout assignment may keep `%x` in +`deinterleaved = 4` and insert `ensure_layout` before the dense store. 
+ +VPTO lowering result: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Dense store materialization for the second consumer. 
+%even0, %even1 = pto.vintlv %x_p0, %x_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%odd0, %odd1 = pto.vintlv %x_p1, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%d0, %d1 = pto.vintlv %even0, %odd0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%d2, %d3 = pto.vintlv %even1, %odd1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %d0, %copy_out[%off_0], %all_b32 {dist = "NORM_B32"} +pto.vsts %d1, %copy_out[%off_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %d2, %copy_out[%off_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %d3, %copy_out[%off_192], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..31]) + +for i = 0..255: + copy_out[off + i] = base[off + i] +``` + +If the `deinterleaved = 4 -> contiguous` plan is not registered, the required +diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + value %x is required as #pto.vmi.layout by + pto.vmi.group_reduce_addf and as #pto.vmi.layout by + pto.vmi.store, but no registered materialization plan exists at the store + use site. +``` + +### 3.19 S=16 Reduce `block_elems` Plan Selection + +S=16 f32 group reduction has two legal dense input layouts: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +`block_elems = 1` is the element-parity layout required by f32->f16 `truncf`. +It is also a valid S=16 reduction layout: each physical part contains eight +values per row, so `VCGADD` can reduce each part and `VADD` can combine the two +partial sums. + +`block_elems = 8` is still useful when the producer is a block load plan such +as `BDINTLV` or `vsldb` over 32B row fragments. Layout assignment must select +between these plans by producer/consumer cost. It must not hard-code S=16 +reduce to `block_elems = 8`. 
+ +#### 3.19.1 Continuous S=16 Reduce And Truncf, `block_elems = 1` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical lane map: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[2], row_r[4], ..., row_r[14] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[3], row_r[5], ..., row_r[15] +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_p0, %x_p1 = pto.vldsx2 %base[%tile_off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..15]) + +for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + 
+#### 3.19.2 Block-Load Producer Fixed To `block_elems = 8` + +This is the real conflict case. The value is fixed to `block_elems = 8` +because the producer is a registered block-load plan. A later `truncf` +requires element-parity `block_elems = 1`. + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts before the conflicting `truncf` use: + +```text +%x from strided block group_load: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +The reduction path is legal and uses the same `vsldb` block-load shape as +section 3.15.2. The `truncf` path is legal only if one of these plans exists: + +```text +1. rematerialize the original memory producer as block_elems=1 +2. materialize block_elems=8 -> block_elems=1 in registers +3. use an explicitly enabled scratch/reload fallback +``` + +If no such plan is registered, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout, but the source value is + fixed to #pto.vmi.layout by the selected + strided group_load plan. Register a rematerialization or preserving + materialization plan, or avoid consuming this block-loaded value with truncf. +``` + +### 3.20 `group_slots` Control-Flow Join + +`group_slots` values must be allowed to cross control flow. The join type is a +sparse physical tuple, not a dense vector. 
+ +VMI input: + +```text +%sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + %b = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + scf.yield %b : !pto.vmi.vreg<128xf32> +} +%bias = pto.vmi.group_slot_load %bias_base[%bias_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%outv = pto.vmi.addf %sum, %bias +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%a, %b, %sum, %bias, %outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%sum_block = scf.if %cond -> !pto.vreg<64xf32> { + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + + %x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %a_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + scf.yield %a_block : !pto.vreg<64xf32> +} else { + %one_block = pto.pge_b32 "PAT_VL1" + %b_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + scf.yield %b_block : !pto.vreg<64xf32> +} + +%one_block = pto.pge_b32 "PAT_VL1" +%slot_mask = pto.pge_b32 "PAT_VL8" +%bias_block = pto.vsldb %bias_base[%bias_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%out_block = pto.vadd %sum_block, %bias_block, %slot_mask + : !pto.vreg<64xf32> + +pto.vsts %out_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : 
!pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + lhs = cond ? reduce(row_r[0..15]) : rhs_base[rhs_off + r] + out[group_off + r] = lhs + bias_base[bias_off + r] +``` + +### 3.21 S=32 Tail With Full-Tile-Readable Source + +This is the positive counterpart to section 3.11.2. Tail participation is +still expressed by masks, but the source additionally promises that reading the +rounded-up 8-row physical tile is memory-safe. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] {full_tile_readable} + : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<192xpred, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// Full-tile-readable allows the load plan to read the rounded-up 8-row tile. +// Only rows 0..5 are semantically active. 
+%data_mask = pto.pge_b32 "PAT_VL48" // 6 rows * 8 lanes per physical part +%sum_mask = pto.pge_b32 "PAT_VL6" + +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_off + r] = reduce(row_r[0..31]) +``` + +Rows 6 and 7 may be physically loaded because of `full_tile_readable`, but +their lanes are not active in `%data_mask`, and their group slots are not stored +because `%sum_mask` is `PAT_VL6`. + +### 3.22 `scf.for` Loop-Carried Layout + +Loop-carried VMI values require a layout fixed point. The iter_arg, body block +argument, yield operand, loop result, and later consumer must all agree on one +layout, or `vmi-layout-assignment` must insert a materialization at a legal +dominating use site. 
+ +VMI input: + +```text +%init = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%acc = scf.for %i = %c0 to %steps step %c1 + iter_args(%arg = %init) -> !pto.vmi.vreg<256xf32> { + %bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> + %next = pto.vmi.addf %arg, %bias + scf.yield %next : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %acc, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%init, %arg, %bias, %next, %acc: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%init_even_0, %init_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_even_1, %init_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%init_p0, %init_p2 = pto.vdintlv %init_even_0, %init_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_p1, %init_p3 = pto.vdintlv %init_odd_0, %init_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%acc_p0, %acc_p1, %acc_p2, %acc_p3 = + scf.for %i = %c0 to %steps step %c1 + iter_args(%arg_p0 = %init_p0, %arg_p1 = %init_p1, + %arg_p2 = %init_p2, %arg_p3 = %init_p3) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %all_b32 = pto.pge_b32 "PAT_ALL" + %bias_p0 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p1 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p2 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p3 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + + %next_p0 = pto.vadd %arg_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> + %next_p1 = pto.vadd %arg_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> + 
%next_p2 = pto.vadd %arg_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> + %next_p3 = pto.vadd %arg_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + scf.yield %next_p0, %next_p1, %next_p2, %next_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%s0 = pto.vcgadd %acc_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %acc_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %acc_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %acc_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + for c = 0..31: + acc[row_r, c] = base[row_r, c] + steps * bias_s + out[group_off + r] = reduce(acc[row_r, 0..31]) +``` + +### 3.23 `group_broadcast` With Multiple Dense Consumers + +One `group_slots` value may feed multiple `group_broadcast` uses with different +dense result layout requirements. Layout assignment should rematerialize the +broadcast per use instead of forcing one result layout onto all consumers. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + +%b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b_for_mul +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b_for_cast +pto.vmi.store %h, %dense_out[%off] +``` + +Assigned layouts: + +```text +%x, %b_for_mul, %y: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_cast: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Use 1: broadcast for the S=16 block_elems=8 multiply path. Both row halves +// use the same per-row broadcast vector. 
+%b_rows_for_mul = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%y_lo = pto.vmul %x_lo, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> +pto.vsts %ysum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Use 2: rematerialize broadcast for the f32->f16 parity cast path. The +// deinterleaved=2 physical bundle has two entries that alias this SSA value. +%b_rows_for_cast = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%h_even = pto.vcvt %b_rows_for_cast, %all_b32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %b_rows_for_cast, %all_b32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 : !pto.vreg<128xf16> +pto.vsts %h0, %dense_out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) + dense_out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +### 3.24 Mask With Elementwise, Select, And Store + +This case separates compute masking from memory effects. A masked elementwise +operation with passthrough semantics can be represented as ordinary compute +plus `select`; a masked store uses the mask only on the store effect. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%rhs = pto.vmi.load %rhs_base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c48 + : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.addf %x, %rhs +%passthrough = pto.vmi.select %mask, %sum, %x +pto.vmi.store %passthrough, %dense_out[%off] +pto.vmi.masked_store %sum, %masked_out[%off], %mask +``` + +Assigned layouts: + +```text +%x, %rhs, %sum, %passthrough: + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<64xpred, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%m = pto.pge_b32 "PAT_VL48" + +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%rhs0 = pto.vlds %rhs_base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%sum0 = pto.vadd %x0, %rhs0, %all_b32 : !pto.vreg<64xf32> + +%pass0 = pto.vsel %sum0, %x0, %m + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %pass0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +pto.vsts %sum0, %masked_out[%off], %m {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..63: + if i < 48: + dense_out[off + i] = base[off + i] + rhs_base[off + i] + masked_out[off + i] = base[off + i] + rhs_base[off + i] + else: + dense_out[off + i] = base[off + i] + masked_out[off + i] is unchanged +``` + +### 3.25 Function Boundary Layout Specialization + +Function boundaries cannot rely on hidden layout side tables. Either the +function is internal and layout-specialized by `vmi-layout-assignment`, or a +public/external VMI boundary must diagnose until a stable VMI ABI is defined. 
+ +#### 3.25.1 Internal Function Specialized To Consumer Layout + +VMI input: + +```text +func.func private @producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} + +func.func @caller(%base: !pto.ptr, %off: index, %out: !pto.ptr) { + %x = call @producer(%base, %off) + : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + return +} +``` + +Assigned layouts: + +```text +@producer result: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x in @caller: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @producer(...) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + return %x_p0, %x_p1, %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> +} + +func.func @caller(...) { + %x_p0, %x_p1, %x_p2, %x_p3 = call @producer(...) + : (...) 
-> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +Memory result: + +```text +for r = 0..7: + out[off + r] = reduce(row_r[0..31]) +``` + +#### 3.25.2 Public Or External VMI Boundary + +VMI input: + +```text +func.func @public_producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> attributes {public} { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} +``` + +Required diagnostic for the initial design: + +```text +VMI-LAYOUT-CONTRACT: + public or external function boundary returns !pto.vmi.vreg<256xf32> without a + stable VMI layout ABI. Mark the function internal for layout specialization, + inline it before vmi-layout-assignment, or define an explicit ABI layout. +``` diff --git a/docs/isa/micro-isa/10-reduction-ops.md b/docs/isa/micro-isa/10-reduction-ops.md index ecae818f2c..2129f91ce0 100644 --- a/docs/isa/micro-isa/10-reduction-ops.md +++ b/docs/isa/micro-isa/10-reduction-ops.md @@ -206,7 +206,9 @@ VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] - **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** i16-i32, f16, f32 -- **semantics:** Sum within each VLane.
8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). +- **semantics:** Sum within each 32-byte VLane. The 8 VLane results are written + contiguously to the low lanes of the destination vector. For f32, results are + at indices 0, 1, 2, 3, 4, 5, 6, 7. ```c int K = N / 8; // elements per VLane @@ -214,17 +216,17 @@ for (int g = 0; g < 8; g++) { T sum = 0; for (int i = 0; i < K; i++) sum += src[g*K + i]; - dst[g*K] = sum; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = sum; } -// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +for (int i = 8; i < N; i++) + dst[i] = 0; +// For f32: results at dst[0], dst[1], ..., dst[7]. ``` - **inputs:** `%input` is the source vector and `%mask` selects participating lanes. - **outputs:** `%result` contains one sum per 32-byte VLane group, written - contiguously into the low slot of each group. + contiguously to the low lanes of the destination vector. - **constraints and limitations:** This is a per-32-byte VLane-group reduction. Inactive lanes are treated as zero. 
@@ -242,10 +244,10 @@ for (int g = 0; g < 8; g++) { T mx = -INF; for (int i = 0; i < K; i++) if (src[g*K + i] > mx) mx = src[g*K + i]; - dst[g*K] = mx; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mx; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -268,10 +270,10 @@ for (int g = 0; g < 8; g++) { T mn = INF; for (int i = 0; i < K; i++) if (src[g*K + i] < mn) mn = src[g*K + i]; - dst[g*K] = mn; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mn; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -320,7 +322,7 @@ for (int i = 1; i < N; i++) // Row-wise sum using vcgadd (for 8-row tile) %row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 +// Results at indices 0, 1, 2, 3, 4, 5, 6, 7 // Full vector sum for normalization %total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> From 50bffab32c2ed957bd3f072e1eeedb58cd99ff1a Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:19:45 +0800 Subject: [PATCH 04/31] Add VMI layout assignment lowering coverage --- README.md | 5 + .../vmi-layout-assignment-implementation.md | 1469 ++++++++ .../vmi-layout-assignment-lowering-design.md | 625 ++++ docs/designs/vmi-layout-lowering-cases.md | 2318 ++++++++++++- include/PTO/IR/VMIAttrs.td | 10 +- include/PTO/IR/VMIOps.td | 28 +- lib/PTO/IR/PTO.cpp | 24 + lib/PTO/IR/VMI.cpp | 548 ++- lib/PTO/Transforms/VMILayoutAssignment.cpp | 806 ++++- lib/PTO/Transforms/VMIToVPTO.cpp | 2940 +++++++++++------ .../lit/vmi/vmi_create_group_mask_invalid.pto | 20 + ...assignment_broadcast_dense_group_users.pto | 75 + ...yout_assignment_call_argument_boundary.pto | 74 + ...ayout_assignment_create_group_mask_s16.pto | 54 + 
..._layout_assignment_dense_f16_f32_store.pto | 77 + ...ment_dense_group_reduce_multi_consumer.pto | 58 + ...gnment_dense_store_group_slots_invalid.pto | 32 + ..._layout_assignment_f32_f8_store_reduce.pto | 65 + .../vmi_layout_assignment_f8_compute_f8.pto | 61 + ...ignment_group_broadcast_multi_consumer.pto | 83 + ...yout_assignment_group_broadcast_slots8.pto | 27 + .../vmi/vmi_layout_assignment_group_load.pto | 27 + ...nment_group_load_block8_truncf_invalid.pto | 42 + ...roup_load_s16_compact_stride12_invalid.pto | 31 + ...assignment_group_load_s16_stride_store.pto | 50 + ...roup_load_s16_unaligned_stride_invalid.pto | 31 + ...group_load_s32_stride_broadcast_reduce.pto | 71 + ...assignment_group_load_s32_stride_store.pto | 51 + ...roup_load_s32_unaligned_stride_invalid.pto | 31 + ...ut_assignment_group_reduce_s12_invalid.pto | 29 + ...yout_assignment_group_reduce_s16_store.pto | 53 + ...roup_reduce_s16_truncf_broadcast_store.pto | 59 + ...ment_group_reduce_s32_broadcast_reduce.pto | 66 + ...nment_group_reduce_s32_multitile_store.pto | 53 + ...yout_assignment_group_reduce_s32_store.pto | 52 + ...gnment_group_reduce_s32_tail_full_tile.pto | 85 + ...p_reduce_s32_tail_no_full_tile_invalid.pto | 33 + ...vmi_layout_assignment_group_reduce_s64.pto | 29 + ...ment_group_reduce_s64_broadcast_reduce.pto | 57 + ...assignment_group_reduce_s64_tail_store.pto | 42 + ...out_assignment_group_reduce_s64_truncf.pto | 51 + ..._layout_assignment_group_reduce_slots8.pto | 29 + ...t_assignment_group_reduce_slots8_store.pto | 44 + .../vmi_layout_assignment_group_slot_load.pto | 58 + ...assignment_group_slot_load_dual_layout.pto | 76 + ...lot_load_slots1_dynamic_stride_invalid.pto | 24 + ...t_load_slots1_unaligned_stride_invalid.pto | 25 + ..._layout_assignment_group_slots_cf_join.pto | 59 + ...i_layout_assignment_group_slots_fanout.pto | 68 + ..._layout_assignment_group_slots_scf_for.pto | 79 + ...group_store_slots1_unit_stride_invalid.pto | 32 + 
...ignment_mask_granularity_f32_f16_store.pto | 61 + ...mi_layout_assignment_mask_select_store.pto | 64 + ...signment_masked_load_dense_group_users.pto | 66 + ..._assignment_masked_load_group_tail_s32.pto | 39 + ..._layout_assignment_non_load_s32_reduce.pto | 62 + ...ment_packed_group_slots_truncf_invalid.pto | 35 + ...yout_assignment_widen_f16_store_reduce.pto | 64 + .../vmi/vmi_layout_group_slots_invalid.pto | 18 + .../vmi/vmi_load_full_read_elems_invalid.pto | 20 + test/lit/vmi/vmi_op_verifier_basic.pto | 7 + ...i_ptoas_call_boundary_vecscope_invalid.pto | 35 + .../vmi_to_vpto_group_broadcast_slots8.pto | 43 + ..._broadcast_slots8_missing_plan_invalid.pto | 29 + ...o_vpto_group_load_missing_plan_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_group_ops.pto | 3 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 45 + ..._group_reduce_s64_missing_plan_invalid.pto | 30 + .../vmi/vmi_to_vpto_group_reduce_slots8.pto | 34 + ...oup_reduce_slots8_missing_plan_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 74 + ...o_group_slot_load_missing_plan_invalid.pto | 27 + ...group_slot_load_nonunit_slots8_invalid.pto | 25 + .../vmi_to_vpto_group_slot_truncf_slots1.pto | 39 + ...lot_truncf_slots1_missing_plan_invalid.pto | 28 + ...pto_group_store_slots8_nonunit_invalid.pto | 26 + test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 4 +- test/lit/vmi/vmi_type_attr_parse.pto | 15 +- .../broadcast-dense-group-users/compare.py | 40 + .../vmi/broadcast-dense-group-users/golden.py | 47 + .../broadcast-dense-group-users/kernel.pto | 68 + .../broadcast-dense-group-users/launch.cpp | 33 + .../vmi/broadcast-dense-group-users/main.cpp | 97 + .../broadcast-dense-group-users/ptoas.flags | 1 + .../compare.py | 32 + .../golden.py | 50 + .../kernel.pto | 57 + .../launch.cpp | 35 + .../main.cpp | 94 + .../ptoas.flags | 1 + .../vmi/f32-to-f8-store-reduce/compare.py | 49 + .../vmi/f32-to-f8-store-reduce/golden.py | 55 + .../vmi/f32-to-f8-store-reduce/kernel.pto | 62 + 
.../vmi/f32-to-f8-store-reduce/launch.cpp | 41 + .../cases/vmi/f32-to-f8-store-reduce/main.cpp | 94 + .../vmi/f32-to-f8-store-reduce/ptoas.flags | 1 + test/vpto/cases/vmi/f8-compute-f8/compare.py | 27 + test/vpto/cases/vmi/f8-compute-f8/golden.py | 40 + test/vpto/cases/vmi/f8-compute-f8/kernel.pto | 55 + test/vpto/cases/vmi/f8-compute-f8/launch.cpp | 40 + test/vpto/cases/vmi/f8-compute-f8/main.cpp | 76 + test/vpto/cases/vmi/f8-compute-f8/ptoas.flags | 1 + .../group-broadcast-multi-consumer/compare.py | 44 + .../group-broadcast-multi-consumer/golden.py | 54 + .../group-broadcast-multi-consumer/kernel.pto | 69 + .../group-broadcast-multi-consumer/launch.cpp | 42 + .../group-broadcast-multi-consumer/main.cpp | 92 + .../ptoas.flags | 1 + .../group-load-s16-stride-store/compare.py | 27 + .../vmi/group-load-s16-stride-store/golden.py | 48 + .../group-load-s16-stride-store/kernel.pto | 51 + .../group-load-s16-stride-store/launch.cpp | 32 + .../vmi/group-load-s16-stride-store/main.cpp | 80 + .../group-load-s16-stride-store/ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 49 + .../kernel.pto | 59 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../group-load-s32-stride-store/compare.py | 27 + .../vmi/group-load-s32-stride-store/golden.py | 48 + .../group-load-s32-stride-store/kernel.pto | 51 + .../group-load-s32-stride-store/launch.cpp | 32 + .../vmi/group-load-s32-stride-store/main.cpp | 80 + .../group-load-s32-stride-store/ptoas.flags | 1 + .../vmi/group-reduce-basic-store/compare.py | 42 + .../vmi/group-reduce-basic-store/golden.py | 50 + .../vmi/group-reduce-basic-store/kernel.pto | 92 + .../vmi/group-reduce-basic-store/launch.cpp | 40 + .../vmi/group-reduce-basic-store/main.cpp | 123 + .../vmi/group-reduce-basic-store/ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 57 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 48 + .../kernel.pto | 63 + .../launch.cpp | 
34 + .../main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 46 + .../kernel.pto | 55 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 49 + .../kernel.pto | 55 + .../launch.cpp | 34 + .../main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 54 + .../launch.cpp | 42 + .../main.cpp | 80 + .../ptoas.flags | 1 + .../compare.py | 27 + .../group-reduce-s32-add-bias-store/golden.py | 48 + .../kernel.pto | 54 + .../launch.cpp | 33 + .../group-reduce-s32-add-bias-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 57 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../group-reduce-s32-cf-join-store/compare.py | 27 + .../group-reduce-s32-cf-join-store/golden.py | 47 + .../group-reduce-s32-cf-join-store/kernel.pto | 63 + .../group-reduce-s32-cf-join-store/launch.cpp | 33 + .../group-reduce-s32-cf-join-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 47 + .../kernel.pto | 49 + .../launch.cpp | 33 + .../group-reduce-s32-multitile-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 28 + .../golden.py | 49 + .../kernel.pto | 53 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 50 + .../kernel.pto | 61 + .../launch.cpp | 34 + .../main.cpp | 83 + .../ptoas.flags | 1 + .../compare.py | 36 + .../group-reduce-s64-slot-add-store/golden.py | 51 + .../kernel.pto | 64 + .../launch.cpp | 35 + .../group-reduce-s64-slot-add-store/main.cpp | 94 + .../ptoas.flags | 1 + .../group-reduce-s64-tail-store/compare.py | 30 + .../vmi/group-reduce-s64-tail-store/golden.py | 46 + .../group-reduce-s64-tail-store/kernel.pto | 52 + .../group-reduce-s64-tail-store/launch.cpp | 32 + .../vmi/group-reduce-s64-tail-store/main.cpp | 81 + .../group-reduce-s64-tail-store/ptoas.flags | 1 + 
.../group-reduce-s64-truncf-store/compare.py | 32 + .../group-reduce-s64-truncf-store/golden.py | 47 + .../group-reduce-s64-truncf-store/kernel.pto | 54 + .../group-reduce-s64-truncf-store/launch.cpp | 40 + .../group-reduce-s64-truncf-store/main.cpp | 79 + .../group-reduce-s64-truncf-store/ptoas.flags | 1 + .../group-reduce-slot-add-store/compare.py | 41 + .../vmi/group-reduce-slot-add-store/golden.py | 57 + .../group-reduce-slot-add-store/kernel.pto | 86 + .../group-reduce-slot-add-store/launch.cpp | 38 + .../vmi/group-reduce-slot-add-store/main.cpp | 113 + .../group-reduce-slot-add-store/ptoas.flags | 1 + .../vmi/group-slots-cf-join-store/compare.py | 38 + .../vmi/group-slots-cf-join-store/golden.py | 53 + .../vmi/group-slots-cf-join-store/kernel.pto | 97 + .../vmi/group-slots-cf-join-store/launch.cpp | 44 + .../vmi/group-slots-cf-join-store/main.cpp | 102 + .../vmi/group-slots-cf-join-store/ptoas.flags | 1 + .../compare.py | 38 + .../golden.py | 53 + .../kernel.pto | 71 + .../launch.cpp | 44 + .../main.cpp | 93 + .../ptoas.flags | 1 + .../vmi/group-slots-scf-for-store/compare.py | 36 + .../vmi/group-slots-scf-for-store/golden.py | 44 + .../vmi/group-slots-scf-for-store/kernel.pto | 68 + .../vmi/group-slots-scf-for-store/launch.cpp | 33 + .../vmi/group-slots-scf-for-store/main.cpp | 95 + .../vmi/group-slots-scf-for-store/ptoas.flags | 1 + .../mask-granularity-f32-f16-store/compare.py | 52 + .../mask-granularity-f32-f16-store/golden.py | 49 + .../mask-granularity-f32-f16-store/kernel.pto | 60 + .../mask-granularity-f32-f16-store/launch.cpp | 43 + .../mask-granularity-f32-f16-store/main.cpp | 91 + .../ptoas.flags | 1 + .../cases/vmi/mask-select-store/compare.py | 32 + .../cases/vmi/mask-select-store/golden.py | 51 + .../cases/vmi/mask-select-store/kernel.pto | 71 + .../cases/vmi/mask-select-store/launch.cpp | 42 + .../vpto/cases/vmi/mask-select-store/main.cpp | 99 + .../cases/vmi/mask-select-store/ptoas.flags | 1 + .../masked-load-dense-group-users/compare.py | 40 
+ .../masked-load-dense-group-users/golden.py | 46 + .../masked-load-dense-group-users/kernel.pto | 61 + .../masked-load-dense-group-users/launch.cpp | 33 + .../masked-load-dense-group-users/main.cpp | 97 + .../masked-load-dense-group-users/ptoas.flags | 1 + .../vmi/scf-for-loop-carried-store/compare.py | 27 + .../vmi/scf-for-loop-carried-store/golden.py | 41 + .../vmi/scf-for-loop-carried-store/kernel.pto | 53 + .../vmi/scf-for-loop-carried-store/launch.cpp | 32 + .../vmi/scf-for-loop-carried-store/main.cpp | 78 + .../scf-for-loop-carried-store/ptoas.flags | 1 + .../widen-f16-to-f32-store-reduce/compare.py | 38 + .../widen-f16-to-f32-store-reduce/golden.py | 50 + .../widen-f16-to-f32-store-reduce/kernel.pto | 67 + .../widen-f16-to-f32-store-reduce/launch.cpp | 42 + .../widen-f16-to-f32-store-reduce/main.cpp | 92 + .../widen-f16-to-f32-store-reduce/ptoas.flags | 1 + 270 files changed, 18976 insertions(+), 1444 deletions(-) create mode 100644 docs/designs/vmi-layout-assignment-implementation.md create mode 100644 docs/designs/vmi-layout-assignment-lowering-design.md create mode 100644 test/lit/vmi/vmi_create_group_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto create mode 100644 
test/lit/vmi/vmi_layout_assignment_group_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load.pto create 
mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_select_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_load_full_read_elems_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto create mode 100644 
test/lit/vmi/vmi_to_vpto_group_slot_load.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/compare.py create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/golden.py create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags create mode 100644 test/vpto/cases/vmi/f8-compute-f8/compare.py create mode 100644 test/vpto/cases/vmi/f8-compute-f8/golden.py 
create mode 100644 test/vpto/cases/vmi/f8-compute-f8/kernel.pto create mode 100644 test/vpto/cases/vmi/f8-compute-f8/launch.cpp create mode 100644 test/vpto/cases/vmi/f8-compute-f8/main.cpp create mode 100644 test/vpto/cases/vmi/f8-compute-f8/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto create mode 100644 
test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py create mode 100644 
test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags create mode 100644 
test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/main.cpp create mode 100644 
test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp create mode 100644 
test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags create mode 100644 
test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/mask-select-store/compare.py create mode 100644 test/vpto/cases/vmi/mask-select-store/golden.py create mode 100644 test/vpto/cases/vmi/mask-select-store/kernel.pto create mode 100644 test/vpto/cases/vmi/mask-select-store/launch.cpp create mode 100644 test/vpto/cases/vmi/mask-select-store/main.cpp create mode 100644 test/vpto/cases/vmi/mask-select-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/compare.py create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/golden.py create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py create mode 100644 
test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags diff --git a/README.md b/README.md index 0d8399783f..b3a547ab04 100644 --- a/README.md +++ b/README.md @@ -206,6 +206,11 @@ ptoas test/lit/pto/empty_func.pto --pto-arch=a5 -o outputfile.cpp # 指定构建 Level(level3 会禁用 PlanMemory/InsertSync) ptoas test/lit/pto/empty_func.pto --pto-level=level3 -o outputfile.cpp +# 启用实验性 VMI -> VPTO 语义 pipeline +# 该模式要求 --pto-backend=vpto,或输入 IR 中带 pto.backend = "vpto" +# public function signature 不能直接暴露 !pto.vmi.* 类型 +ptoas test/lit/vmi/vmi_ptoas_cli_pipeline.pto --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto -o - + # 查看当前 ptoas release 版本号 ptoas --version diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md new file mode 100644 index 0000000000..e6b39dd984 --- /dev/null +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -0,0 +1,1469 @@ +# VMI Layout Assignment Implementation Plan + +本文是 `vmi-layout-assignment` 和 `vmi-to-vpto` 的实现计划。它配套 +`vmi-layout-assignment-lowering-design.md`,并以 +`vmi-layout-lowering-cases.md` 为测试和验收来源。 + +不使用旧 `vmi-dialect-design.md` 作为设计输入。 + +## 1. 
Pipeline + +Recommended pass pipeline: + +```text +pto-validate-vmi-surface + -> vmi-layout-assignment + -> pto-validate-vmi-layout + -> vmi-to-vpto + -> canonicalize/cse + -> existing VPTO lowering/codegen +``` + +Pass responsibilities: + +```text +pto-validate-vmi-surface: + verify surface VMI has no physical VPTO layout dependency + reject public/external VMI ABI unless explicitly enabled + +vmi-layout-assignment: + solve value layouts + choose selected lowering plans + insert ensure/rematerialization helpers + make internal function boundary layouts explicit + rewrite VMI types with layout attrs + +pto-validate-vmi-layout: + verify every VMI data/mask value has layout + verify every context-sensitive op has selected_plan + verify helper ops have registered materialization plans + +vmi-to-vpto: + use OneToN type conversion + lower only from explicit layout/plan information + emit VPTO or precise unsupported diagnostic +``` + +## 2. Files To Add Or Update + +Expected implementation files: + +```text +include/PTO/IR/VMITypes.td +include/PTO/IR/VMIOps.td +include/PTO/IR/VMIAttrs.td +lib/PTO/IR/VMI.cpp + +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/ValidateVMI.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +lib/PTO/Transforms/VMILayoutPlanRegistry.cpp + +test/lit/vmi/vmi_layout_assignment_*.pto +test/lit/vmi/vmi_to_vpto_*.pto +test/vpto/cases/vmi/*/ +``` + +Exact names may follow project conventions, but the layering should remain: + +```text +IR definitions + -> validation + -> assignment + -> OneToN lowering + -> lit and sim tests +``` + +## 3. 
IR Types And Attributes + +### 3.1 Layout Attribute + +Represent layout as a closed attribute family: + +```text +#pto.vmi.layout<contiguous> +#pto.vmi.layout<deinterleaved = F, block_elems = B> +#pto.vmi.layout<num_groups = G, slots = K> +``` + +C++ form: + +```c++ +enum class VMILayoutKind { + Contiguous, + Deinterleaved, + GroupSlots, +}; + +struct VMILayoutKey { + VMILayoutKind kind; + int64_t deinterleaveFactor = 1; + int64_t blockElems = 1; + int64_t numGroups = 0; + int64_t slots = 0; +}; +``` + +Verifier rules: + +```text +contiguous: + no extra parameters + +deinterleaved: + F > 1 + B > 0 + direct full-chunk plans require N % (F * B) == 0 + +group_slots: + G > 0 + K > 0 + G % K == 0 + K fits in one physical vreg for element type +``` + +Parser compatibility during migration: + +```text +#pto.vmi.layout<num_groups = G> +``` + +is accepted as a legacy spelling for the pre-design implicit group layout. New +`vmi-layout-assignment` output must not rely on that implicit form. It must +print one of: + +```text +#pto.vmi.layout<num_groups = G, slots = 8> +#pto.vmi.layout<num_groups = G, slots = 1> +``` + +so `vmi-to-vpto` can lower from the assigned type without reconstructing group +slot placement from producer or consumer context. + +### 3.2 VMI Types + +Surface: + +```text +!pto.vmi.vreg<NxT> +!pto.vmi.mask<N> +``` + +Layout-assigned: + +```text +!pto.vmi.vreg<NxT, #pto.vmi.layout<...>> +!pto.vmi.mask<N, #pto.vmi.layout<...>> +``` + +Surface VMI types are legal before assignment. Layout-assigned VMI types are +required after assignment. + +### 3.3 Selected Plan Attribute + +Every context-sensitive op gets a selected plan attr after assignment. 
The +initial implementation may use a stable string attr: + +```text +vmi.selected_plan = "s16_reduce_parity" +``` + +Once the plan registry syntax is stable, this can become a dedicated plan attr: + +```text +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +``` + +Ops that are uniquely determined by layout may omit this attr, but the rule +should be conservative. If future maintainers could reasonably ask "why this +lowering?", assignment should write a plan. + +## 4. VMI Surface Ops Required By Cases + +Initial op set from the case catalog: + +```text +load +group_load +group_slot_load +store +masked_store + +create_mask +create_group_mask + +extf +truncf +addf +mulf +select +broadcast + +group_reduce_addf +group_broadcast +group_store + +ensure_layout // internal +ensure_mask_layout // internal +ensure_mask_granularity // internal +``` + +Important semantic split: + +```text +load: + optional full_read_elems=N is a memory-safety contract for pointer sources. + It states that source[offset : offset + N) may be physically read even if the + VMI logical result has fewer active lanes. + +group_load: + loads group_size data elements per group + +group_slot_load: + loads one scalar per group and produces group_slots +``` + +## 5. Plan Registry + +Create one registry object shared by assignment and lowering. 
+ +```c++ +class VMILayoutPlanRegistry { +public: + SmallVector<VMILayoutPlan> getProducerPlans(Operation *op); + SmallVector<VMILayoutPlan> getConsumerPlans(OpOperand &use); + SmallVector<VMILayoutPlan> getTransferPlans(Operation *op); + FailureOr<VMILayoutPlan> getMaterializationPlan(Type valueType, + VMILayoutKey from, + VMILayoutKey to); + bool isCheaplyRematerializable(Operation *op); + bool hasTargetCapability(PlanID plan) const; +}; +``` + +Plan record: + +```c++ +struct VMILayoutPlan { + PlanID id; + SmallVector<VMILayoutKey> operandLayouts; + SmallVector<VMILayoutKey> resultLayouts; + int64_t cost; + bool requiresSelectedPlanAttr; + bool requiresFullTileReadable; + bool mayReadInactivePhysicalLanes; + DiagnosticBuilder (*explainFailure)(...); +}; +``` + +The registry must be target-aware but deterministic. It should not read global +mutable state. Pass options configure fallback availability: + +```text +enableScratchFallback +enableGatherFallback +enablePublicVMIABI +diagnosticVerbosity +``` + +## 6. Layout Assignment Data Model + +### 6.1 Solver State + +```c++ +struct ValueLayoutState { + Value value; + Type logicalType; + SmallVector<VMILayoutKey> candidates; + std::optional<VMILayoutKey> chosen; + SmallVector<UseRequest> useRequests; +}; + +struct UseRequest { + OpOperand *operand; + VMILayoutKey requestedLayout; + PlanID requestingPlan; + bool hard; +}; + +struct OpPlanState { + Operation *op; + SmallVector<VMILayoutPlan> candidates; + std::optional<PlanID> chosen; +}; +``` + +### 6.2 Collection Phase + +Walk the module and collect: + +```text +1. every VMI value +2. every VMI block argument +3. every VMI function argument/result +4. every VMI op with candidate plans +5. every branch/yield/call/return edge carrying VMI +``` + +Build SCCs over: + +```text +dataflow uses +region yields +loop iter_args +function call graph for private/internal functions +``` + +Public/external VMI function boundaries are rejected unless +`enablePublicVMIABI` is explicitly supported. + +Block arguments are first-class layout variables. 
Assignment must write the +chosen layout into the block argument type or specialized function signature. +`vmi-to-vpto` must never recover a block argument layout by walking to an +incoming branch, yield, or call operand. + +### 6.3 Constraint Generation + +Examples: + +```text +truncf f32->f16: + source request deinterleaved=2, block_elems=1 + result contiguous + +group_reduce S=16: + source candidate deinterleaved=2, block_elems=1 + source candidate deinterleaved=2, block_elems=8 + result group_slots(G, slots=8) + +group_reduce S=32: + source candidate deinterleaved=4, block_elems=1 + source candidate deinterleaved=4, block_elems=8 + result group_slots(G, slots=8) + +group_reduce S=64: + source request contiguous + result group_slots(G, slots=1) + +group_broadcast: + source request group_slots(G,K) + result candidate comes from each dense consumer request + op is rematerializable per use + +ordinary dense add/mul/select: + operands/results same dense layout + +group-slot add/mul: + operands/results same group_slots(G,K) + +ordinary store: + dense source required + group_slots source is illegal + +group_store: + source request group_slots(G,K) +``` + +Consumer-driven adoption is limited to producers that are layout-transparent or +can produce the requested memory layout directly: + +```text +direct layout producer: + load, tile_read + +layout-transparent producer: + broadcast, constant, iota + add/sub/mul/fma/div/min/max/neg/abs/sqrt/exp/ln/relu + integer bitwise/shift/not + select, bitcast +``` + +For a non-load layout-transparent producer, only non-contiguous consumer +requests may be adopted by the producer equivalence class. Contiguous requests +from ordinary stores are handled by use-site `ensure_layout` or +rematerialization instead. 
This prevents a dense store from overwriting a +natural `deinterleaved` cast layout while still allowing: + +```text +load -> broadcast -> addf -> S=32 group_reduce +``` + +to assign the whole producer chain as +`deinterleaved = 4, block_elems = 8` before `vmi-to-vpto`. + +Memory legality constraints: + +```text +S=32 tail fast load: + requires full_tile_readable + otherwise require gather fallback or diagnose + +compact S=12 logical S=16: + requires compact-row gather materialization + diagnose if gather fallback is disabled/missing +``` + +### 6.4 Solving And Rewriting + +Algorithm: + +```text +1. Pick candidate plan sets for every op. +2. Propagate hard constraints through SCCs. +3. Resolve transfer-equivalent dense values. +4. Choose multi-plan ops by cost: + - S=16 parity vs block8 + - load memory-fused vs load+materialize + - group_slot_load slots=8 vs slots=1 +5. For conflicting uses: + - rematerialize cheap producer where legal + - otherwise insert ensure_layout at use + - otherwise diagnose +6. Rewrite VMI result/block/function types with chosen layouts. +7. Attach selected_plan attrs where required. +8. Insert helper ops with source/result layout attrs. +``` + +Rewrite invariants: + +```text +No VMI data/mask value after assignment has a null layout. +No context-sensitive VMI op after assignment lacks selected_plan. +Every ensure_* helper has a registered materialization plan. +Every function/call signature carrying VMI is specialized or diagnosed. +``` + +## 7. OneToN Type Conversion + +`vmi-to-vpto` should use OneToN conversion for VMI values. 
+ +Conversion rules: + +```text +contiguous: + ceil(N / lanesPerVReg(T)) physical vregs + +deinterleaved=F: + F * ceil((N / F) / lanesPerVReg(T)) physical vregs + ordering: part-major, then chunk + +group_slots(G,K): + ceil(G / K) physical vregs + each vreg has logical slot lanes 0..K-1 live +``` + +Mask conversion: + +```text +mask layout follows data layout +mask granularity is selected from consumer element width: + f32/i32 -> b32 + f16/i16 -> b16 + f8/i8 -> b8 +``` + +If one logical mask is used by multiple widths, assignment inserts +`ensure_mask_granularity` or rematerializes the mask producer. + +## 8. VMI-to-VPTO Pattern Rules + +Each pattern uses: + +```text +op +operand/result layouts +selected_plan +adaptor physical values +``` + +Each pattern rejects: + +```text +missing selected_plan for context-sensitive op +layout not matching selected_plan +missing target capability +unexpected group_slots dense consumer +``` + +Target selected-plan matrix: + +```text +load, selected_plan=dense_load_norm: + result layout contiguous + emits pto.vlds / pto.vsts NORM paths + covers dense store users and S=64 row-local reduce input + +load, selected_plan=load_dintlv2: + result layout deinterleaved=2, block_elems=1 + emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization + covers f32->f16, S=16 parity reduce, f16->f32 widened values + +load, selected_plan=load_dintlv4: + result layout deinterleaved=4, block_elems=1 + emits two vldsx2 DINTLV_B32 plus vdintlv + covers f32->f8, S=32 dintlv4 reduce + +group_load, selected_plan=s16_group_load_block8_unit_stride: + result layout deinterleaved=2, block_elems=8 + emits vldsx2/BDINTLV for 8 rows of 16xf32 + covers compact logical S=16 when source_group_stride == 16 + +group_load, selected_plan=s16_group_load_block8_stride: + result layout deinterleaved=2, block_elems=8 + emits two vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + +group_load, selected_plan=s32_group_load_block8_stride: + result 
layout deinterleaved=4, block_elems=8 + emits four vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + +group_load, selected_plan=group_load_contiguous_chunks: + result layout contiguous + emits one vlds per physical group chunk using row_stride address arithmetic + covers the currently implemented full-chunk row-local group_load path + +group_reduce_addf, selected_plan=s8_reduce_contiguous: + consumes contiguous f32 with group size 8 + produces group_slots(G, slots=8) + emits one vcgadd + +group_reduce_addf, selected_plan=s16_reduce_parity: + consumes deinterleaved=2, block_elems=1 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_addf, selected_plan=s16_reduce_block8: + consumes deinterleaved=2, block_elems=8 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_addf, selected_plan=s32_reduce_dintlv4: + consumes deinterleaved=4, block_elems=1 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_addf, selected_plan=s32_reduce_block8_stride: + consumes deinterleaved=4, block_elems=8 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_addf, selected_plan=s64_reduce_row_local: + consumes contiguous f32 with group size 64 + produces group_slots(G, slots=1) + target lowering emits per-row vcgadd plus vcadd; the current prototype uses + the existing row-local VCADD/VADD/VSEL sequence while preserving the same + group_slots(G, slots=1) value contract + +group_slot_load, selected_plan=group_slot_load_slots8_unit_stride: + result group_slots(G, slots=8) + requires source_group_stride == 1 + emits one packed vsldb load + +group_slot_load, selected_plan=group_slot_load_slots1_row_local: + result group_slots(G, slots=1) + supports aligned non-unit source_group_stride + requires constant positive source_group_stride divisible by 256 / elementBits + emits one lane-0 vsldb per group + 
+group_broadcast, selected_plan=group_broadcast_slots8_vselr: + source group_slots(G, slots=8) + result dense layout selected per use + emits vselr using assigned result layout + +group_broadcast, selected_plan=group_broadcast_slots1_vselr: + source group_slots(G, slots=1) + result dense layout selected per use + emits vdup/vselr row-local materialization + +truncf, selected_plan=group_slot_cast_slots1_f32_to_f16: + source/result group_slots(G, slots=1) + emits one lane-0 vcvt per group slot block + rejects packed slots=8 unless another plan is registered +``` + +The target matrix is the implementation contract. The staged status below +records how much of that contract the current prototype has already enforced. + +Current staged implementation status: + +```text +group_slot_load: + vmi-to-vpto requires vmi.selected_plan and checks it against + #pto.vmi.layout. + +group_reduce_addf: + explicit slots=8 VCGADD lowering requires + vmi.selected_plan = "s8_reduce_contiguous". Legacy bare num_groups and + generic VCADD lowering still need the plan-registry migration. + S=16 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout, and + vmi.selected_plan = "s16_reduce_block8"; vmi-to-vpto checks that plan and + lowers through two VCGADDs plus a PAT_VL8 VADD per packed result block. + S=32 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout, and + vmi.selected_plan = "s32_reduce_block8_stride"; vmi-to-vpto checks that + plan and lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed + result block. + S=64 row-local assignment now emits + vmi.selected_plan = "s64_reduce_row_local" and has focused + layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic + VCADD row-local path also requires and checks that selected_plan. Other + legacy bare num_groups generic VCADD paths still need the plan-registry + migration. 
+ +group_broadcast: + explicit slots=8/1 source layouts require + vmi.selected_plan = "group_broadcast_slots8_vselr" or + "group_broadcast_slots1_vselr". Deinterleaved block-fragment results use + the result layout block_elems as the local vselr selection group, so + `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each + 32B row fragment. VSELR index vectors are materialized per physical result + chunk. For small-group results, layout assignment has already fixed the + result layout, and vmi-to-vpto computes: + `firstGroup = first logical group covered by this result chunk`, + `sourceChunk = firstGroup / slots`, and + `baseGroupSlot = firstGroup % slots`. The generated index vector selects + `baseGroupSlot .. baseGroupSlot + groupsPerResultChunk - 1`; it must not be + reused across result chunks. Legacy bare num_groups still needs the + plan-registry migration. + +group_load: + contiguous full-chunk path emits and checks + vmi.selected_plan = "group_load_contiguous_chunks". S=16/S=32 + block-aligned strided loads emit and check + vmi.selected_plan = "s16_group_load_block8_stride" or + "s32_group_load_block8_stride", assign + #pto.vmi.layout, and lower to one + vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride + vldsx2/BDINTLV plan remains a design target. S=16/S=32 group_load with a + non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by + vmi-layout-assignment because the stable gather fallback is not implemented. + +truncf group-slot cast: + layout assignment and vmi-to-vpto support and check + vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" for + group_slots(G, slots=1) f32 -> f16. The reduce->truncf->group_store + slots=1 flow has focused lit coverage and no longer relies on vmi-to-vpto + inspecting the truncf producer. + +group_store: + row-local group_slots(G, slots=1) lowering is implemented as one lane-0 + vsts per group and is covered by the reduce->truncf->group_store lit case. 
+ The current plan is accepted only when row_stride is a constant positive + multiple of the 32B store alignment in destination elements: 8 for f32, + 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only + the first row-local store is 32B-aligned; later `group_off + r` stores are + 4B apart. A future pack-to-slots=8 or unaligned-store plan is required before + contiguous `%c1` slots=1 group_store can be accepted. + Packed group_slots(G, slots=8) group_store is implemented only when + num_groups is a multiple of 8 and row_stride is constant 1; it emits one + PAT_VL8 store per packed slot block. Non-unit packed group stores remain a + design target unless a strided packed-lane store plan is selected explicitly. +``` + +Examples: + +```text +group_reduce_addf, selected_plan=s16_reduce_parity: + consume deinterleaved=2, block_elems=1 + emit two VCGADDs and one VADD + +group_reduce_addf, selected_plan=s16_reduce_block8: + consume deinterleaved=2, block_elems=8 + emit two VCGADDs and one VADD + +group_reduce_addf, selected_plan=s32_reduce_dintlv4: + consume deinterleaved=4 + emit four VCGADDs and a reduction tree + +group_broadcast: + consume group_slots + emit VSELR or VDUP depending on slots and target dense layout + +group_slot_load slots=8: + emit one packed block load for unit stride + +group_slot_load slots=1: + emit row-local lane-0 loads for constant positive 32B-aligned strides +``` + +## 9. Validation Passes + +### 9.1 Surface Validation + +Before assignment: + +```text +VMI types may omit layout. +VPTO physical op must not consume VMI values. +Public/external VMI function ABI rejected unless enabled. +Unsupported vector-to-scalar extract rejected. +``` + +### 9.2 Layout Validation + +After assignment: + +```text +Every VMI value has layout. +Every VMI mask has layout and granularity plan. +Every context-sensitive op has selected_plan. +Every selected_plan matches operand/result layouts. +Every ensure_* helper has a materialization plan.
+Every control-flow edge has matching VMI layouts. +``` + +### 9.3 `vmi-to-vpto` Context Read Audit + +`vmi-to-vpto` may still read defining ops in narrowly scoped cases that do not +select a layout or plan: + +```text +allowed: + arith.constant for the current op's scalar operands + create_mask/create_group_mask internals when lowering that mask op itself + ensure_mask_layout / ensure_mask_granularity stripping for static mask facts + memref.subview only to improve an already-failed non-identity memref + diagnostic + +not allowed: + walking from a consumer to a producer to decide a selected_plan + walking from a consumer to a mask producer to decide whether a plan is legal + inspecting users to choose a result layout or materialization + recovering full_tile_readable from surrounding MTE/caller context +``` + +Current audit result: + +```text +3.44 partial S=32 create_group_mask: + decision moved to vmi-layout-assignment. vmi-to-vpto no longer walks from + group_reduce_addf to the mask defining op to reject the plan. + +masked_load: + direct lowering is load + vsel. It does not inspect the mask producer to + choose a different load form; memory safety is provided by full physical + chunks, shaped memref proof, or load full_read_elems. + +memref.subview: + mentioned only after identity lane-to-address planning fails. It is not used + to recover a hidden base/stride lowering. +``` + +## 10. Diagnostics + +Implement diagnostics with stable prefixes: + +```text +VMI-LAYOUT-CONTRACT +VMI-UNSUPPORTED-PLAN +VMI-MISSING-CAPABILITY +VMI-PUBLIC-ABI +VMI-MASK-GRANULARITY +VMI-CONTROL-FLOW-LAYOUT +``` + +Minimum diagnostic payload: + +```text +op name +logical type +actual layout +requested layout +selected/missing plan +recommended rewrite or option +``` + +Example: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout<deinterleaved = 2, block_elems = 1>, but the source value is + fixed to #pto.vmi.layout<deinterleaved = 2, block_elems = 8> by the selected + strided group_load plan.
Register a rematerialization or preserving + materialization plan, or avoid consuming this block-loaded value with truncf. +``` + +## 11. Test And Simulator Acceptance + +Each numbered endpoint in `vmi-layout-lowering-cases.md` should become: + +```text +1. a layout-assignment lit test +2. a vmi-to-vpto lit test +3. a simulator case when the VPTO sequence is supported by the current backend +4. a diagnostic lit test when the case is explicitly unsupported +``` + +Repository locations: + +```text +test/lit/vmi/ +test/vpto/cases/vmi/ +``` + +The current repository uses descriptive flat lit names rather than +case-numbered subdirectories. New tests should follow the existing prefixes: + +```text +vmi_layout_assignment_<case_name>.pto +vmi_to_vpto_<case_name>.pto +<case-name>/kernel.pto +``` + +The case number should still be recoverable from the coverage table in this +document and from the corresponding section in `vmi-layout-lowering-cases.md`. + +### 11.1 Layout Assignment Checks + +Each positive layout-assignment test must check: + +```text +assigned data layouts +assigned mask layouts +selected_plan attrs +inserted ensure_layout/rematerialized producers +control-flow/function signature specialization +``` + +Negative tests check diagnostic text. + +### 11.2 VMI-to-VPTO Checks + +Each positive vmi-to-vpto test must check: + +```text +no pto.vmi ops remain +VPTO op sequence matches the case lowering +physical value arity and ordering are correct +mask granularity is correct +stores preserve observable logical memory order +``` + +### 11.3 Simulator Checks + +Simulator cases should compare final memory against the memory result written in +the case catalog.
+ +Current broad runtime sweep: + +```text +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-39 CASE_PREFIX='vmi/' JOBS=4 \ + test/vpto/scripts/run_host_vpto_validation_parallel.sh + +PASS=39 FAIL=0 +summary: .tmp/vmi-runtime-batch-39/parallel-summary.tsv +log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ + .tmp/vmi-runtime-batch-39.log +result: no matches +``` + +The `find: Permission denied` messages printed while discovering CANN simulator +paths are environment noise and are not treated as simulator failures. + +Required groups: + +```text +dense conversion: + 3.1, 3.2, 3.3, 3.31, 3.32 + +group reduce: + 3.4, 3.5.1, 3.5.2, 3.5.3 + 3.6.1, 3.6.2, 3.6.3 + 3.7.1, 3.7.2, 3.7.3 + 3.7.4 diagnostic + +layout/rematerialization: + 3.8, 3.10, 3.17, 3.18, 3.19.1, 3.22, 3.23, 3.31, + 3.32, 3.33, 3.34, 3.35, 3.36, 3.38, 3.40, 3.41 + +mask/tail: + 3.11.1, 3.15.1, 3.15.2, 3.21, 3.24, 3.26, 3.29, + 3.30, 3.44 + +strided/group-slot memory: + 3.27, 3.28, 3.37, 3.39 + +function/control-flow: + 3.12, 3.20, 3.22, 3.25.1, 3.42, 3.43 +``` + +Aggregate catalog headings are covered through their endpoint subcases: + +```text +3.11 partial tail groups: + 3.11.1 positive S=64 active-row tail + 3.11.2 diagnostic S=32 tail without full_tile_readable + +3.15 compact S=12 written as logical S=16: + 3.15.1 positive source row stride 16 + 3.15.2 positive source row stride greater than 16 + 3.15.3 diagnostic compact source row stride 12 + +3.16 group_slot_load layout contract: + 3.16.1 packed slots=8 positive and non-unit-stride diagnostic + 3.16.2 row-local slots=1 positive plus dynamic/unaligned diagnostics + +3.25 function boundary layout specialization: + 3.25.1 private/internal boundary lit coverage, runtime backend gap + 3.25.2 public/external boundary diagnostics +``` + +Current checked-in coverage for 3.3 dense f8->f32->compute->f8: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto + +runtime SIM: + test/vpto/cases/vmi/f8-compute-f8 +``` + +Current checked-in 
coverage for 3.1/3.2 dense f16/f32 conversion stores: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce + test/vpto/cases/vmi/quant-f32-to-f16-tail +``` + +Current checked-in coverage for basic packed group_reduce -> group_store paths +for 3.4, 3.5.1, and 3.6.1: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-basic-store +``` + +Current checked-in coverage for S=16 group broadcast continuation: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store +``` + +Current checked-in coverage for 3.35 group_slots fanout to direct group_store +and group_broadcast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-fanout-store-broadcast +``` + +Current checked-in coverage for 3.8 `group_reduce -> truncf -> +group_broadcast -> dense store` and 3.17 `group_broadcast` feeding a +deinterleaved consumer: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store +``` + +Current checked-in coverage for 3.18 one dense value with dense and +group-reduce consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/dense-group-reduce-multi-consumer +``` + +Current checked-in coverage for 3.10 non-load producer feeding S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-add-bias-store +``` + +Current checked-in 
coverage for 3.23 group_broadcast with multiple dense +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/group-broadcast-multi-consumer +``` + +Current checked-in coverage for S=32 contiguous group broadcast continuation: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store +``` + +Current checked-in coverage for 3.21 S=32 tail with a statically safe +full-read source: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store + This case has `ptoas.flags` with `--enable-vmi`, because the partial pointer + load must run through layout assignment before VPTO/LLVM emission. +``` + +Current checked-in runtime coverage for 3.12 control-flow join before S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_cf_branch.pto + test/lit/vmi/vmi_to_vpto_cf_branch.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-cf-join-store +``` + +Current checked-in runtime coverage for 3.20 `group_slots` control-flow join: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-cf-join-store +``` + +Current checked-in runtime coverage for 3.22 `scf.for` loop-carried VMI layout: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_scf_for.pto + test/lit/vmi/vmi_to_vpto_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/scf-for-loop-carried-store +``` + +Current checked-in runtime coverage for 3.42 `group_slots` `scf.for` +loop-carried accumulator: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-scf-for-store +``` + +Current checked-in lit coverage for 3.43 internal function argument boundary 
+materialization: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + +runtime SIM: + blocked by the current private vector callee backend path; see known + implementation gaps below +``` + +Current checked-in coverage for packed group-slot RHS elementwise continuations +for 3.5.3 and 3.6.2: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-slot-add-store +``` + +Current checked-in coverage for S=64 row-local group broadcast continuation +with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store +``` + +Current checked-in coverage for S=64 active-row tail with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-tail-store +``` + +The companion negative lit case for contiguous `%c1` slots=1 group_store is: + +```text +test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +``` + +Current checked-in coverage for S=64 row-local group-slot RHS elementwise +continuation with aligned source_group_stride and aligned output row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-slot-add-store +``` + +Current checked-in coverage for 3.34 S=64 `slots = 1` group-slot f32->f16 cast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-truncf-store +``` + +The companion negative lit cases for dynamic or unaligned `%c2` slots=1 +group_slot_load, and non-unit `slots = 8` group_slot_load, are: + +```text +test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto 
+test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +``` + +Current checked-in coverage for the strided block-load cases: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-load-s16-stride-store + test/vpto/cases/vmi/group-load-s32-stride-store + test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce +``` + +Current checked-in coverage for grouped mask S=16 tail: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto + test/lit/vmi/vmi_create_group_mask_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store +``` + +Current checked-in coverage for 3.24 mask/select/masked-store semantics: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_select_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-select-store +``` + +Current checked-in coverage for 3.29 one semantic mask with f32 and f16 +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-granularity-f32-f16-store +``` + +Current checked-in coverage for 3.31 f16->f32 feeding dense store and S=16 +reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce +``` + +Current checked-in coverage for 3.32 f32 feeding f8 store 
and S=32 reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/f32-to-f8-store-reduce +``` + +Current checked-in coverage for multi-tile group-slot arity: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-multitile-store +``` + +Current checked-in coverage for 3.40 scalar broadcast feeding dense and grouped +users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/broadcast-dense-group-users +``` + +Current checked-in coverage for 3.41 non-rematerializable `masked_load` feeding +dense and grouped users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-dense-group-users +``` + +Diagnostic-only cases: + +```text +3.9 dense store of group slots +3.11.2 S=32 tail without full_tile_readable +3.7.4 S=64 slots=1 group_store with unit output stride +3.13 packed group-slot f32 -> f16 cast +3.14 unsupported group size +3.15.3 compact source row stride 12 +3.16.1 group_slot_load slots=8 non-unit stride +3.16.2 group_slot_load slots=1 dynamic or unaligned stride +3.27 S=32 source_group_stride not divisible by 8 f32 elements +3.19.2 block_elems=8 value consumed by truncf without materialization plan +3.25.1 full ptoas emission for private VMI callees that return VPTO vector values +3.25.2 public/external VMI boundary +3.30 unsafe masked_load tail without stable masked/gather fallback +3.44 masked_load grouped tail with S=32 partial create_group_mask +``` + +Current checked-in diagnostic coverage for 3.9/3.13/3.14: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto + test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +``` + 
+Current checked-in diagnostic coverage for the remaining non-SIM diagnostic +entries: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto + test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto + test/lit/vmi/vmi_ptoas_public_abi_invalid.pto + test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto + test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto + test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto + test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +``` + +Known implementation gaps before all catalog cases can become runtime SIM +coverage: + +```text +dynamic grouped masks: + pto.vmi.create_group_mask exists and supports constant + active_elems_per_group. Dynamic active_elems_per_group is not implemented + yet. Do not replace grouped masks with prefix create_mask; that would change + the semantics. + +S=32 partial grouped masks: + 3.44 `masked_load` grouped tail with `active_elems_per_group < 32` is + diagnostic-only for the current S=32 block8 reduce path, and the diagnostic + is emitted by `vmi-layout-assignment` before a selected plan is written. A + runtime probe of the previously allowed lowering did not preserve the logical + 25-lane row sum. A second probe with `active_elems_per_group = 25` produced + row 0 `golden=-3.6290324` but `output=-3.6592741`, and the row-wise error + grew monotonically. This combination must stay unsupported until the + deinterleaved grouped-mask materialization is fixed and validated by SIM. 
+ +remaining function runtime coverage: + 3.25.1 internal function boundary specialization has layout-assignment and + vmi-to-vpto lit coverage, but full ptoas emission still fails after + physicalization because today's inferred pto.vecscope is resultless and VPTO + vector-scope values cannot escape through a function return. Runtime coverage + requires either a resultful vecscope/VPTO vector ABI or an explicit inlining + policy before vecscope inference. + + 3.43 internal function argument boundary materialization has + layout-assignment and vmi-to-vpto lit coverage. Full ptoas emission for a + private void vector callee currently reaches the Bisheng device backend and + fails on the physicalized callee with: + + fatal error: error in backend: Do not know how to split the result of this operator! + + Runtime coverage requires either inlining private vector callees before the + device backend path or adding backend support for the physical VPTO vector + function ABI. This is a runtime/backend gap, not a license for `vmi-to-vpto` + to infer layouts from caller/callee context. + +memory-proof runtime coverage: + 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses + `pto.vmi.load {full_read_elems = 256}` on a UB pointer source. The attr is + the explicit safe-read proof consumed by `vmi-to-vpto`; no surrounding MTE, + caller/body context, or producer/user scan is inspected to justify the + rounded-up physical reads. +``` + +## 12. 
Implementation Slices + +### Slice 1: IR Skeleton And Verifiers + +```text +layout attrs +vmi.vreg/vmi.mask types +surface op definitions +selected_plan attr +surface/layout validators +``` + +### Slice 2: Straight-Line Dense Assignment/Lowering + +```text +3.1 f16->f32->store +3.2 f32->f16->store +3.3 f8->f32->compute->f8 +``` + +### Slice 3: Group Slots And Reductions + +```text +3.4 S=8 +3.5 S=16 parity/block8 +3.6 S=32 +3.7 S=64 +group_slot_load +group_broadcast +group_store +``` + +### Slice 4: Layout Conflicts And Materialization + +```text +3.8 cast commute through group_broadcast +3.18 dense/group-reduce multi-consumer +3.19 block_elems plan selection +3.23 group_broadcast multi-consumer +3.32 f32 feeding f8 store and S=32 reduce +3.33 S=16/S=32 reduce multi-consumer rematerialization +3.34 slots=1 group-slot f32->f16 cast +3.35 group_slots fanout to group_store and group_broadcast +3.36 group_slot_load rematerialized for slots=8/slots=1 +3.38 multi-tile group_slots arity +3.40 scalar broadcast rematerialized for dense/grouped users +3.41 non-rematerializable value with ensure_layout +``` + +### Slice 5: Masks, Tail, And Memory Legality + +```text +create_mask +create_group_mask +masked_store +safe full-read proof +compact/gather diagnostics +mask granularity per use +group_load stride greater than group size +group_slot_load slots=1 aligned non-unit stride plus dynamic/unaligned diagnostic +group_store slots=1 non-unit output stride +strided group_load feeding broadcast and a second reduce +masked_load grouped tail feeding S=32 reduce +``` + +### Slice 6: Control Flow And Functions + +```text +scf.if +scf.for +group_slots across control flow +group_slots loop-carried accumulator +internal function specialization +internal function argument boundary materialization +public ABI diagnostic +``` + +## 13. Completion Checklist + +The implementation is not complete until: + +```text +1. every case has a layout-assignment test +2. 
every positive case has a vmi-to-vpto test +3. every simulator-supported case has a sim validation +4. every unsupported case has a diagnostic test +5. vmi-to-vpto contains no producer/user context inference +6. missing selected_plan on context-sensitive ops is a hard failure +7. release docs are updated only after the design stabilizes +``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md new file mode 100644 index 0000000000..710ab267a7 --- /dev/null +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -0,0 +1,625 @@ +# VMI Layout Assignment And Lowering Design + +本文是新的 VMI layout assignment / lowering 设计文档。它只以 +`docs/designs/vmi-layout-lowering-cases.md` 为 source of truth,不继承旧 +`vmi-dialect-design.md` 的 layout 设计,以避免旧上下文污染。 + +目标: + +```text +VMI surface IR + -> vmi-layout-assignment + -> layout-assigned VMI IR + -> vmi-to-vpto + -> VPTO IR +``` + +核心验收约束: + +```text +vmi-to-vpto 不允许通过上下文猜 lowering。 + +任何需要 producer/consumer/control-flow/memory/mask 上下文才能决定的事, +必须在 vmi-layout-assignment 阶段变成显式 IR 信息: + +1. vmi.vreg/vmi.mask 的 layout +2. op 的 selected lowering plan +3. use-site ensure_layout / ensure_mask_layout +4. rematerialized producer +5. target capability diagnostic +``` + +## 1. 
Source Case Coverage + +设计必须覆盖 case catalog 中的端到端场景: + +```text +dense cast: + f16 -> f32 -> store + f32 -> f16 -> store + f8 -> f32 -> compute -> f8 + f16 -> f32 shared by dense store and S=16 reduce + f32 shared by f8 store and S=32 reduce + +group reduce: + S=8, S=16, S=32, S=64 + reduce -> group_store + reduce -> group_slot_load/elemwise -> group_store + reduce -> group_broadcast -> elemwise -> reduce -> store + one group_slots result fanning out to group_store and group_broadcast + grouped tail -> broadcast -> reduce -> store + +layout conflict: + one value with dense and group-reduce consumers + one value with S=16 and S=32 group-reduce consumers + one scalar broadcast rematerialized for dense and grouped users + one non-rematerializable value materialized with use-site ensure_layout + one scalar group-slot source rematerialized as slots=8 and slots=1 + S=16 block_elems=1/8 plan selection + dense consumer of group_slots diagnostic + packed group-slot width-changing cast diagnostic + S=64 slots=1 group-slot width-changing cast + +control flow: + scf.if before group_reduce + group_slots across scf.if + scf.for loop-carried layout fixed point + group_slots as scf.for loop-carried accumulator + internal function boundary specialization + internal function argument boundary materialization + public/external VMI ABI diagnostic + +mask and tail: + prefix mask + group-periodic mask + masked_load tail with explicit passthrough instead of padding + masked_load grouped tail feeding group_reduce + masked select/store + one semantic mask used by multiple predicate granularities + S=32 tail with and without full_tile_readable + compact S=12 diagnostic + +strided memory: + group_load source stride greater than logical group size + strided group_load feeding broadcast and a second group_reduce + group_slot_load slots=1 with non-unit source stride + group_store slots=1 with non-unit output stride +``` + +## 2. 
Layout Domain + +Layout is a property of a layout-assigned VMI value, not a property inferred by +the final lowering pattern. + +### 2.1 Dense Layouts + +```text +#pto.vmi.layout<contiguous> +#pto.vmi.layout<deinterleaved = F, block_elems = B> +``` + +`block_elems` defaults to `1`: + +```text +#pto.vmi.layout<deinterleaved = F> + == #pto.vmi.layout<deinterleaved = F, block_elems = 1> +``` + +Dense layouts preserve one semantic value for every logical lane. + +Lane map for `deinterleaved = F, block_elems = B`: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +Important consequence: + +```text +deinterleaved=2, block_elems=1 +deinterleaved=2, block_elems=8 +``` + +are different layouts. They cannot be treated as compatible merely because `F` is the +same. + +### 2.2 Sparse Group-Slot Layouts + +```text +#pto.vmi.layout<group_slots = G, slots = K> +``` + +Only `G` lanes have semantic values: + +```text +slot_block(g) = g / K +slot_lane(g) = g % K +``` + +All non-slot lanes are undefined and may only be read by group-aware operations. +Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. + +`K` is selected by the lowering plan: + +```text +S=8/16/32 packed VCG result -> slots=8 +S=64 row-local result -> slots=1 +``` + +## 3. Lowering Context Must Become Assignment Output + +`vmi-to-vpto` may inspect only: + +```text +1. op name and explicit op attrs +2. converted operand/result types with layout +3. selected plan attrs written by layout assignment +4. inserted helper ops +5. target capability registry +``` + +It must not: + +```text +1. walk to defining op to infer layout +2. inspect all users to choose a lowering path +3. infer memory legality from a later mask +4. decide S=16 block_elems=1 vs block_elems=8 locally +5. decide whether group_broadcast should be materialized for one or many users +6. specialize function signatures during vmi-to-vpto +``` + +Any of those decisions belongs to `vmi-layout-assignment`. + +## 4.
Explicit Assignment Products + +After `vmi-layout-assignment`, every VMI data and mask value must be in one of +these states: + +```text +layout-assigned type: + !pto.vmi.vreg<..., #pto.vmi.layout<...>> + !pto.vmi.mask<..., #pto.vmi.layout<...>> + +or explicit helper: + pto.vmi.ensure_layout + pto.vmi.ensure_mask_layout + pto.vmi.ensure_mask_granularity +``` + +Every context-sensitive op must also have a selected plan if layout alone does +not uniquely identify the lowering: + +```text +vmi.selected_plan = "dense_load_norm" +vmi.selected_plan = "load_dintlv2" +vmi.selected_plan = "load_dintlv4" +vmi.selected_plan = "group_load_contiguous_chunks" +vmi.selected_plan = "s16_group_load_block8_unit_stride" +vmi.selected_plan = "s16_group_load_block8_stride" +vmi.selected_plan = "s32_group_load_block8_stride" +vmi.selected_plan = "s8_reduce_contiguous" +vmi.selected_plan = "s16_reduce_parity" +vmi.selected_plan = "s16_reduce_block8" +vmi.selected_plan = "s32_reduce_dintlv4" +vmi.selected_plan = "s32_reduce_block8_stride" +vmi.selected_plan = "s64_reduce_row_local" +vmi.selected_plan = "group_slot_load_slots8_unit_stride" +vmi.selected_plan = "group_slot_load_slots1_row_local" +vmi.selected_plan = "group_broadcast_slots8_vselr" +vmi.selected_plan = "group_broadcast_slots1_vselr" +vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +``` + +The spelling above is illustrative; the implementation may use an enum attr. The +invariant is not illustrative: if a lowering decision is not uniquely implied +by op + assigned operand/result layouts + explicit attrs, assignment must write +a selected plan. + +## 5. Plan Registry + +The compiler owns a target-aware plan registry. Layout assignment queries this +registry; vmi-to-vpto verifies and consumes the chosen plan.
+ +### 5.1 Plan Kinds + +```text +ProducerPlan: + op can produce result layout L + example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv + +ConsumerPlan: + op can consume operand layout L + example: group_reduce S=32 consumes deinterleaved=4 + +TransferPlan: + op ties operand/result layouts + example: addf requires same dense layout for operands/result + +MaterializationPlan: + layout A -> layout B without changing logical value + example: deinterleaved=4 -> contiguous by vintlv tree + +RematerializationPlan: + cheap producer can be cloned for a use-site layout + example: broadcast/create_mask/group_broadcast + +DiagnosticPlan: + known unsupported semantic/capability boundary + example: compact S=12 requires gather materialization +``` + +### 5.2 Dense Plans From Cases + +```text +f16 -> f32: + source contiguous f16 + result deinterleaved=2, block_elems=1 + +f8 -> f32: + source contiguous f8 + result deinterleaved=4, block_elems=1 + +f32 -> f16: + source deinterleaved=2, block_elems=1 + result contiguous f16 + +f32 -> f8: + source deinterleaved=4, block_elems=1 + result contiguous f8 + +elementwise dense: + all dense operands/results share the same layout + +broadcast scalar: + rematerializable to any dense layout requested by the consumer + +load: + may be rematerialized per use when two consumers request incompatible dense + layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 +``` + +### 5.3 Group Plans From Cases + +```text +group_reduce f32 S=8: + input contiguous + result group_slots(G, slots=8) + +group_reduce f32 S=16: + legal input layout A: deinterleaved=2, block_elems=1 + legal input layout B: deinterleaved=2, block_elems=8 + result group_slots(G, slots=8) + +group_reduce f32 S=32: + legal input layout A: deinterleaved=4, block_elems=1 + legal input layout B: deinterleaved=4, block_elems=8 + result group_slots(G, slots=8) + +group_reduce f32 S=64: + input contiguous + result group_slots(G, slots=1) + +group_slot_load: + result 
group_slots(G, slots=8) for packed slots + result group_slots(G, slots=1) for row-local slots + +group_broadcast: + source group_slots(G,K) + result is dense layout requested by each consumer + rematerialize per use instead of forcing one result layout + +group_store: + source group_slots(G,K) + +group_slot_cast f32 -> f16: + slots=1 row-local source/result is legal with + group_slot_cast_slots1_f32_to_f16 + slots=8 packed source is illegal unless a packed slot-preserving plan is + registered +``` + +### 5.4 Tail And Memory Safety Plans + +Mask semantics and memory legality are separate: + +```text +mask: + decides which logical lanes participate in compute/store semantics + +full_tile_readable: + decides whether a rounded-up physical load is allowed to read inactive lanes +``` + +The full-tile-readable proof must be explicit. It may be carried by a +statically shaped memref source, or by `pto.vmi.load {full_read_elems = N}` for +pointer sources. `vmi-to-vpto` consumes only this proof carrier; it does not +inspect surrounding MTE copies, producer bodies, callers, or later consumers to +decide whether inactive physical lanes are safe to read. + +Example: + +```text +S=32 tail num_groups=6: + without full_tile_readable: + fast DINTLV_B32 full-tile load is illegal + + with full_tile_readable: + full 8-row physical tile may be loaded + compute mask is PAT_VL48 per physical part + group store mask is PAT_VL6 + +S=16 grouped tail active_elems_per_group=12: + low 8-lane row half uses PAT_ALL + high 8-lane row half uses lane_mod_8 < 4 + the same split applies before and after group_broadcast + +one mask used by f32 and f16 consumers: + f32 use materializes a b32 predicate + f16 use materializes a b16 predicate + vmi-to-vpto consumes the assigned per-use mask materialization +``` + +## 6. Layout Assignment Algorithm + +`vmi-layout-assignment` is module-level. It must see function/call/control-flow +connections before choosing layouts. 
+ +### 6.1 Variables + +Create a layout variable for: + +```text +1. every VMI OpResult +2. every VMI BlockArgument +3. every function argument/result that is allowed to carry VMI +4. every VMI mask value +``` + +Create a use-site request for: + +```text +1. every operand use that requires a specific layout +2. every control-flow yield/branch/call/return edge +3. every memory operation that requires a memory legality plan +``` + +### 6.2 Constraints + +Hard constraints: + +```text +group_slots cannot feed ordinary dense consumers +direct group-slot width-changing cast requires a slot-preserving plan +public/external VMI function boundary requires a stable ABI or diagnostic +S=32 fast tail load requires full_tile_readable or gather fallback +``` + +`slots = 1` row-local cast may satisfy the slot-preserving plan requirement. +Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast +or unpack/materialization plan is registered. + +Equivalence constraints: + +```text +dense add/mul/select: + operands/results use same dense layout unless an explicit materialization is + inserted at a use site + +scf.if/scf.for: + region yield operands and block arguments must have the same assigned layout + as the region result/iter_arg +``` + +Candidate constraints: + +```text +S=16 group_reduce: + choose block_elems=1 or block_elems=8 by cost and explicit assignment constraints + +one dense value feeding S=16 and S=32 group_reduce: + rematerialize a cheap producer per consumer layout, or insert an explicit + materialization plan; the final lowering pass must not pick one layout after + seeing both users + +load/group_load: + choose memory plan and result layout together + +group_broadcast: + rematerialize per dense consumer layout +``` + +### 6.3 Solving + +Recommended solving order: + +```text +1. Build function/control-flow SCCs. +2. Collect candidate plans for every op. +3. Propagate hard required layouts from consumers. +4. 
Propagate producer natural layouts where they are unique. +5. Resolve multi-plan ops by cost. +6. Insert use-site materialization where a value has multiple incompatible uses. +7. Rematerialize cheap producers instead of materializing when cheaper. +8. Specialize internal function signatures. +9. Emit diagnostics for unsatisfied hard constraints. +10. Rewrite VMI types and selected plan attrs. +``` + +Tie-breaking must be deterministic. Suggested priority: + +```text +1. Avoid unsupported plans. +2. Prefer rematerializing cheap producers over register materialization. +3. Prefer layouts accepted by all consumers without conversion. +4. Prefer memory-fused layout plans over load + register rearrange. +5. Prefer fewer VPTO instructions. +6. Prefer contiguous only when cost ties and no consumer requests a special layout. +``` + +## 7. Control Flow And Functions + +### 7.1 `scf.if` + +All branch yields for one result must agree on one assigned layout. If they do +not, assignment inserts materialization before `scf.yield` where possible. +The `scf.if` result type after assignment carries that layout, so +`vmi-to-vpto` does not need to inspect either branch body. + +### 7.2 `scf.for` + +Loop-carried VMI values are fixed-point variables: + +```text +initial iter_arg layout +body block argument layout +yield operand layout +loop result layout +``` + +must converge to one layout. If a body consumer needs another layout, it is a +use-site request inside the loop body. +The loop body block argument has no defining op. Its layout is therefore part +of the block argument type after assignment, not information reconstructed from +the initial value or previous iteration during lowering. + +### 7.3 Calls + +Internal/private VMI function boundaries must make layout choices explicit in +the assigned IR. The baseline implementation keeps function arguments in a +contiguous VMI ABI and inserts callee-entry `ensure_layout` helpers when the +callee body needs another layout. 
A later private-function optimization may +specialize signatures directly: + +```text +func @producer() -> !vmi.vreg<256xf32, deinterleaved=4> +``` + +which is then physicalized by `vmi-to-vpto` into multiple VPTO function results. + +Public/external VMI function boundaries are rejected until a stable VMI ABI is +defined. + +## 8. vmi-to-vpto Contract + +`vmi-to-vpto` receives layout-assigned VMI. It performs no global reasoning. + +For each op, the pattern: + +```text +1. reads operand/result layouts +2. reads selected_plan if required +3. asks TypeConverter for ordered physical values +4. emits the registered VPTO recipe +5. fails if the selected plan is missing or target capability is absent +``` + +The pattern must not: + +```text +1. inspect all users to decide result layout +2. inspect defining ops to decide source layout +3. choose between S=16 block_elems=1 and block_elems=8 +4. decide whether a load is full_tile_readable +5. decide function signature specialization +``` + +Allowed local reads are deliberately narrower: + +```text +arith.constant defining op: + allowed only to materialize an operand of the current op, such as + create_mask active_lanes or a constant memory offset + +current VMI op body/attrs: + allowed for op-local semantics, such as create_group_mask + active_elems_per_group when lowering the create_group_mask op itself + +helper materialization chain: + allowed only to strip ensure_mask_layout / ensure_mask_granularity for + static predicate analysis that does not choose a different layout or plan + +diagnostic embellishment: + allowed only to improve an already-failed capability message, such as naming + memref.subview after identity lane-to-address planning has failed +``` + +Anything else is a layout-assignment responsibility. In particular, an +unsupported producer/consumer combination must be rejected before assignment +writes a selected plan.
Section 3.44 is the model: partial S=32 grouped masks +are diagnosed in `vmi-layout-assignment`, not by `vmi-to-vpto` walking from +`group_reduce_addf` to the mask producer. + +## 9. Physical Value Ordering + +The OneToN lowering order is fixed. + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=F: + part0_chunk0, part0_chunk1, ..., + part1_chunk0, part1_chunk1, ..., + ... + part(F-1)_chunk0, ... + +group_slots(G,K): + slot_block0, slot_block1, ... +``` + +Two physical bundle entries may alias the same VPTO SSA value when the selected +plan proves they have the same contents, such as group_broadcast feeding both +parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; +aliasing is not a different layout. + +## 10. Diagnostics + +Diagnostics are part of the design. They must name: + +```text +1. the VMI op +2. source logical type +3. assigned source layout +4. requested layout +5. missing plan or disabled fallback +6. suggested rewrite when available +``` + +Examples: + +```text +dense store of group_slots: + use group_store, group_broadcast, or explicit group-pack + +packed group-slot f32->f16: + group_broadcast before truncf, or keep group_store as f32 + +S=32 tail without full_tile_readable: + mark source full_tile_readable or enable stable gather fallback + +S=32 group_load with unaligned source_group_stride: + choose a stride divisible by 8 f32 elements or enable stable gather fallback + +public VMI function boundary: + make function internal, inline before assignment, or define ABI layout +``` + +## 11. Design Completion Criteria + +The design is complete only when: + +```text +1. every case in vmi-layout-lowering-cases.md maps to registered plans +2. every selected plan can be emitted without looking at producer/user context +3. every unsupported case has a precise capability diagnostic +4. every control-flow/function boundary either specializes layout or diagnoses +5. 
every mask has explicit data layout and predicate granularity +6. every case has an end-to-end test and simulator validation +``` diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 807baf841e..b111397fc9 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -154,20 +154,22 @@ the immediately following complete endpoints. 3.6.2 group_reduce S=32 -> elemwise(rhs) -> group_store complete 3.6.3 group_reduce S=32 -> broadcast -> compute -> reduce -> store complete -3.7.1 group_reduce S=64 -> group_store complete -3.7.2 group_reduce S=64 -> elemwise(rhs) -> group_store complete +3.7.1 group_reduce S=64 -> aligned group_store complete +3.7.2 group_reduce S=64 -> elemwise(rhs) -> aligned group_store + complete 3.7.3 group_reduce S=64 -> broadcast -> compute -> reduce -> store complete +3.7.4 group_reduce S=64 -> unit-stride group_store illegal diagnostic 3.8 group_reduce -> truncf -> broadcast -> dense store complete 3.9 dense store of group slots illegal diagnostic 3.10 non-load producer feeding S=32 group_reduce complete 3.11 partial tail groups complete/diagnostic 3.12 control-flow join before group_reduce complete -3.13 direct group-slot f32 -> f16 cast illegal diagnostic +3.13 packed group-slot f32 -> f16 cast illegal diagnostic 3.14 unsupported group size illegal diagnostic 3.15 compact S=12 written as logical S=16 complete/design 3.16 group_slot_load layout contract complete -3.17 group_broadcast physical arity alias complete +3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization 3.19 S=16 reduce block_elems plan selection complete/diagnostic 3.20 group_slots control-flow join complete @@ -176,6 +178,25 @@ the immediately following complete endpoints. 
3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete 3.25 function boundary layout specialization complete/design +3.26 S=16 grouped tail through broadcast/reduce/store complete +3.27 S=32 group_load with stride greater than group size complete +3.28 group_slot_load slots=1 aligned non-unit stride complete +3.29 one semantic mask with f32 and f16 consumers complete +3.30 masked_load tail without padding complete/diagnostic +3.31 f16->f32 feeding dense store and S=16 reduce complete +3.32 f32 feeding f8 store and S=32 reduce complete +3.33 one dense value feeding S=16 and S=32 reduces complete/materialization +3.34 S=64 group-slot result f32->f16 cast complete +3.35 group_slots fanout to group_store and broadcast complete/design +3.36 same scalar source materialized as slots=8/slots=1 complete/design +3.37 S=64 group_store with non-unit output stride complete/design +3.38 multi-tile S=32 group_reduce complete +3.39 strided S=32 group_load through broadcast/reduce complete +3.40 scalar broadcast feeding dense and grouped users complete/materialization +3.41 non-rematerializable value with incompatible users complete/materialization +3.42 group_slots scf.for loop-carried accumulator complete +3.43 internal function argument boundary materialization complete/design +3.44 masked_load grouped tail feeding S=32 reduce complete/design ``` ### 3.1 `f16 -> f32 -> store` @@ -1170,7 +1191,8 @@ VMI input: %x = pto.vmi.load %base[%off] : memref<512xf32> -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} -pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %sum_out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1247,7 +1269,7 @@ Memory result: ```text for r = 0..7: - sum_out[group_tile_off + r] = reduce(row_r[0..63]) + sum_out[group_tile_off + r * 8] = reduce(row_r[0..63]) ``` #### 3.7.2 Reduce Result, 
Elementwise, Store @@ -1261,7 +1283,8 @@ VMI input: : !pto.ptr -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} %outv = pto.vmi.addf %sum, %rhs -pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %outv, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1348,7 +1371,7 @@ Memory result: ```text for r = 0..7: - out[group_tile_off + r] = reduce(row_r[0..63]) + rhs[r] + out[group_tile_off + r * 8] = reduce(row_r[0..63]) + rhs[r] ``` #### 3.7.3 Reduce, Broadcast, Elementwise, Reduce, Store @@ -1362,7 +1385,8 @@ VMI input: %b = pto.vmi.group_broadcast %sum {num_groups = 8} %y = pto.vmi.mulf %x, %b %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} -pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %ysum, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1415,11 +1439,50 @@ Memory result: ```text for r = 0..7: s = reduce(row_r[0..63]) - out[group_tile_off + r] = + out[group_tile_off + r * 8] = reduce_i(row_r[i] * s for i = 0..63) = s * s ``` +#### 3.7.4 Unit-Stride Store Is Not A Valid Lowering Yet + +The row-local S=64 result uses one physical vreg per group with the semantic +value in lane 0: + +```text +%sum_r lane 0 = reduce(row_r[0..63]) +``` + +The current VPTO lowering for `slots = 1` group_store emits one lane-0 `vsts` +per group. Therefore unit-stride f32 output would issue stores at: + +```text +group_off + 0, group_off + 1, group_off + 2, ... +``` + +Only the first address is necessarily 32B-aligned. The remaining f32 addresses +are 4B apart and are not valid for this `vsts` lowering. The compiler must not +accept this as a clean lowering until either a pack-to-slots=8 plan or an +unaligned-store plan is selected. 
+ +VMI input: + +```text +%c1 = arith.constant 1 : index +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_store with #pto.vmi.layout lowers + as one lane-0 vsts per group and requires constant positive row_stride + divisible by 8 f32 elements for 32B store alignment. Packed or unaligned + contiguous store lowering is not implemented. +``` + ### 3.8 `group_reduce -> truncf -> group_broadcast -> store` VMI input: @@ -1441,7 +1504,8 @@ Assigned layouts: %sum32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> %sum16 : semantic value only; not materialized as a group-slot VPTO value -%b32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b32_dense : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b32_split : !pto.vmi.vreg<128xf32, #pto.vmi.layout> %b16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> ``` @@ -1452,8 +1516,13 @@ group_broadcast(truncf(group_reduce(x))) == truncf(group_broadcast(group_reduce(x))) ``` -This avoids materializing a group-slot f16 value. The only cast emitted is the -existing dense `f32 deinterleaved=2 -> contiguous f16` truncation. +This avoids materializing a group-slot f16 value. Current lowering makes the +layout transition explicit: `group_broadcast` first produces a dense contiguous +f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 +view required by dense `f32 -> f16` truncation. A future direct +`group_broadcast -> deinterleaved=2` lowering may remove that materialization, +but it must be implemented as a `group_broadcast` selected plan rather than +hidden inside `truncf` lowering. 
VPTO lowering result for one full 8-row tile: @@ -1473,19 +1542,28 @@ VPTO lowering result for one full 8-row tile: : !pto.vreg<64xf32> %lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> -%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 - : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%broadcast_idx_lo = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx_hi = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> -// This vselr is the VPTO lowering of pto.vmi.group_broadcast. The later store -// only writes lanes as-is; it does not duplicate group-slot values. -%b32_rows = pto.vselr %sum32_block, %broadcast_idx +// These vselr ops are the VPTO lowering of pto.vmi.group_broadcast for the two +// dense contiguous f32 physical chunks. +%b32_rows_lo = pto.vselr %sum32_block, %broadcast_idx_lo : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b32_rows_hi = pto.vselr %sum32_block, %broadcast_idx_hi + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=2 materializes the two f32 parity +// inputs expected by f32 -> f16 truncation. +%b32_even_input, %b32_odd_input = pto.vdintlv %b32_rows_lo, %b32_rows_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// The broadcasted f32 value is dense deinterleaved=2. -// Both parity parts carry the same per-row broadcast values. 
-%b16_even = pto.vcvt %b32_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} +%b16_even = pto.vcvt %b32_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%b16_odd = pto.vcvt %b32_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} +%b16_odd = pto.vcvt %b32_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -1561,7 +1639,7 @@ Assigned layouts: ```text %a, %bias, %x: - !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.vreg<256xf32, #pto.vmi.layout> %sum: !pto.vmi.vreg<256xf32, #pto.vmi.layout> @@ -1629,7 +1707,8 @@ VMI input: %x = pto.vmi.load %base[%off] : memref<384xf32> -> !pto.vmi.vreg<384xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} -pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %out[%group_off], %c8 {num_groups = 6} ``` Assigned layouts: @@ -1665,7 +1744,7 @@ Memory result: ```text for r = 0..5: - out[group_tile_off + r] = reduce(row_r[0..63]) + out[group_tile_off + r * 8] = reduce(row_r[0..63]) ``` #### 3.11.2 S=32 Tail Without Full-Tile Read Contract @@ -1803,7 +1882,7 @@ VMI-LAYOUT-CONTRACT: Expected #pto.vmi.layout on every incoming value. ``` -### 3.13 Direct Group-Slot `f32 -> f16` Cast +### 3.13 Packed Group-Slot `f32 -> f16` Cast This case is intentionally illegal for the current S=16/S=32 packed group-slot layout. It prevents the compiler from treating a width-changing @@ -1922,6 +2001,10 @@ Semantics: lane i is active iff (i % S) < active_elems_per_group ``` +Current lowering support covers constant `active_elems_per_group`. Dynamic +grouped masks require a runtime lane-index predicate materializer and remain a +separate implementation item. 
+ Ordinary `pto.vmi.create_mask %active_lanes` keeps the prefix-mask meaning: ```text @@ -1960,6 +2043,10 @@ Assigned layouts: %sum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x32_for_store: + pto.vmi.ensure_layout %x32 + : #pto.vmi.layout -> #pto.vmi.layout ``` VPTO lowering result for one `8x16xf32` tile: @@ -2212,9 +2299,10 @@ silently using full-group `group_load`. VMI input: ```text -%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} : !pto.ptr, index -> !pto.vmi.vreg<512xf32> -pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layout: @@ -2231,6 +2319,8 @@ VPTO lowering result: // Emit this shape for r = 0..7. Each result value carries one semantic slot // in lane 0, matching the S=64 row-local group_reduce result layout. +// For f32, source_group_stride = 8 elements = 32B, so every lane-0 vsldb is +// aligned. %rhs_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> @@ -2242,14 +2332,32 @@ Memory result: ```text for r = 0..7: - out[group_off + r] = rhs_base[rhs_off + r] + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Current lowering rule: + +```text +slots = 1 group_slot_load uses one lane-0 vsldb per semantic group slot. +For f32, source_group_stride must be a positive constant divisible by 8 +elements. For f16 it must be divisible by 16 elements, and for f8 it must be +divisible by 32 elements. ``` -### 3.17 `group_broadcast` Physical Arity Alias +### 3.17 `group_broadcast` Feeding A Deinterleaved Consumer + +This case fixes a lowering invariant: `group_broadcast` itself does not infer a +consumer-specific deinterleaved result. It produces the layout selected by +layout assignment. 
If a later consumer requires another layout, assignment must +insert an explicit `ensure_layout`. -This case fixes a lowering invariant: a layout determines physical arity. A -`deinterleaved = 2` result has two physical bundle entries even when both -entries can reuse the same VPTO SSA value. +The current endpoint is: + +```text +group_reduce -> group_broadcast(contiguous f32) + -> ensure_layout(deinterleaved = 2) + -> truncf(contiguous f16) +``` VMI input: @@ -2272,8 +2380,12 @@ Assigned layouts: %sum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -%b: - !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b_dense: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_split = pto.vmi.ensure_layout %b_dense: + #pto.vmi.layout + -> #pto.vmi.layout %h: !pto.vmi.vreg<128xf16, #pto.vmi.layout> @@ -2295,22 +2407,26 @@ VPTO lowering result: %sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask : !pto.vreg<64xf32> -%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> -%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 - : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +// group_broadcast lowers to two contiguous f32 chunks. +%idx_lo = materialize indices [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%idx_hi = materialize indices [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> -%b_rows = pto.vselr %sum_block, %broadcast_idx +%b_lo = pto.vselr %sum_block, %idx_lo + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_hi = pto.vselr %sum_block, %idx_hi : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -// Physical bundle binding for %b, not emitted VPTO ops: -// physical entry 0 = %b_rows -// physical entry 1 = %b_rows -// The layout still has two physical entries; they alias the same SSA value -// because every even/odd logical lane pair contains the same broadcast value. +// ensure_layout contiguous -> deinterleaved=2 is explicit in assigned VMI. 
+%b_even_input, %b_odd_input = pto.vdintlv %b_lo, %b_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%h_even = pto.vcvt %b_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} +%h_even = pto.vcvt %b_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%h_odd = pto.vcvt %b_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} +%h_odd = pto.vcvt %b_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -2329,6 +2445,15 @@ for r = 0..7: out[r * 16 + 0 .. r * 16 + 15] = truncf(s) ``` +Required assignment rule: + +```text +`group_broadcast` layout is chosen before `vmi-to-vpto`. A width-changing +consumer such as `truncf` may require a deinterleaved f32 source, but that +requirement must be represented by `ensure_layout`; `truncf` lowering must not +look through the defining `group_broadcast` and choose a hidden broadcast shape. +``` + ### 3.18 One Value With Dense And Group-Reduce Consumers This case forces layout assignment to handle a solvable use-site conflict. One @@ -2663,27 +2788,40 @@ for r = 0..7: ### 3.21 S=32 Tail With Full-Tile-Readable Source This is the positive counterpart to section 3.11.2. Tail participation is -still expressed by masks, but the source additionally promises that reading the -rounded-up 8-row physical tile is memory-safe. +still expressed by masks, but the source must provide a static proof that +reading the rounded-up 8-row physical tile is memory-safe. That proof is +explicit: it can come from a statically shaped memref source, or from +`pto.vmi.load {full_read_elems = N}` on a pointer source. The pointer attr +means the memory interval starting at the load offset is safe to read for `N` +logical elements; it is not inferred from surrounding MTE copies or caller +context. 
VMI input: ```text -%x = pto.vmi.load %base[%off] {full_tile_readable} - : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} ``` +Equivalent pointer-source VMI input for runtime kernels: + +```text +%x = pto.vmi.load %base[%off] {full_read_elems = 256} + : !pto.ptr -> !pto.vmi.vreg<192xf32> +``` + Assigned layouts: ```text %x: - !pto.vmi.vreg<192xf32, #pto.vmi.layout> + !pto.vmi.vreg<192xf32, #pto.vmi.layout> %mask: - !pto.vmi.mask<192xpred, #pto.vmi.layout> + !pto.vmi.mask<192xpred, + #pto.vmi.layout> %sum: !pto.vmi.vreg<192xf32, #pto.vmi.layout> @@ -2692,28 +2830,44 @@ Assigned layouts: VPTO lowering result: ```text -// Full-tile-readable allows the load plan to read the rounded-up 8-row tile. -// Only rows 0..5 are semantically active. -%data_mask = pto.pge_b32 "PAT_VL48" // 6 rows * 8 lanes per physical part -%sum_mask = pto.pge_b32 "PAT_VL6" - -%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" - : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" - : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// A statically safe full-read proof allows the load plan to read the +// rounded-up 8-row tile. Only rows 0..5 are semantically active. 
+%x_c0 = pto.vlds %base[%tile_off_0] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c1 = pto.vlds %base[%tile_off_1] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c2 = pto.vlds %base[%tile_off_2] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c3 = pto.vlds %base[%tile_off_3] + : memref<256xf32> -> !pto.vreg<64xf32> -%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 +%x_lo01, %x_hi01 = pto.vdintlv %x_c0, %x_c1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 +%x_lo23, %x_hi23 = pto.vdintlv %x_c2, %x_c3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x_lo01, %x_lo23 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_hi01, %x_hi23 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%s0 = pto.vcgadd %x_p0, %data_mask +%data_mask0, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask1, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask2, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask3, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%sum_mask, %_ = pto.plt_b32 %c6_i32 + : i32 -> !pto.mask, i32 + +%s0 = pto.vcgadd %x_p0, %data_mask0 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s1 = pto.vcgadd %x_p1, %data_mask +%s1 = pto.vcgadd %x_p1, %data_mask1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s2 = pto.vcgadd %x_p2, %data_mask +%s2 = pto.vcgadd %x_p2, %data_mask2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s3 = pto.vcgadd %x_p3, %data_mask +%s3 = pto.vcgadd %x_p3, %data_mask3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> @@ -2731,9 +2885,9 @@ for r = 0..5: out[group_off + r] = reduce(row_r[0..31]) ``` -Rows 6 and 7 may be physically loaded because of `full_tile_readable`, but -their lanes are not active in `%data_mask`, and 
their group slots are not stored -because `%sum_mask` is `PAT_VL6`. +Rows 6 and 7 may be physically loaded because of the safe full-read proof, but +their lanes are not active in `%data_mask*`, and their group slots are not +stored because `%sum_mask` is produced by `plt_b32 %c6_i32`. ### 3.22 `scf.for` Loop-Carried Layout @@ -2851,23 +3005,49 @@ pto.vmi.group_store %ysum, %sum_out[%group_off], %c1 {num_groups = 8} pto.vmi.store %h, %dense_out[%off] ``` -Assigned layouts: +Assigned layouts in the current implementation: ```text -%x, %b_for_mul, %y: +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: !pto.vmi.vreg<128xf32, #pto.vmi.layout> %sum, %ysum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b_for_mul, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + %b_for_cast: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_cast_split: !pto.vmi.vreg<128xf32, #pto.vmi.layout> %h: !pto.vmi.vreg<128xf16, #pto.vmi.layout> ``` +The important invariant is not that both dense consumers choose the same dense +layout. It is that each use has an explicit layout boundary: + +```text +%x_for_reduce = pto.vmi.ensure_layout %x +%y_for_reduce = pto.vmi.ensure_layout %y +%b_for_cast_split = pto.vmi.ensure_layout %b_for_cast +``` + +If a future `group_broadcast -> deinterleaved` selected plan is added, layout +assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but +the choice must still be visible in the assigned IR and selected plan. + VPTO lowering result: ```text @@ -2887,12 +3067,17 @@ VPTO lowering result: %broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> -// Use 1: broadcast for the S=16 block_elems=8 multiply path. Both row halves -// use the same per-row broadcast vector. -%b_rows_for_mul = pto.vselr %sum_block, %broadcast_idx +// Use 1: broadcast for the multiply path. 
Current lowering materializes two +// contiguous f32 chunks, multiplies them with the original contiguous chunks, +// then deinterleaves the product for the second group_reduce. +%b_rows_for_mul_0 = pto.vselr %sum_block, %broadcast_idx_0 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -%y_lo = pto.vmul %x_lo, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> -%y_hi = pto.vmul %x_hi, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%b_rows_for_mul_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%y0 = pto.vmul %x0, %b_rows_for_mul_0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b_rows_for_mul_1, %all_b32 : !pto.vreg<64xf32> +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> %y_lo_sum = pto.vcgadd %y_lo, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %y_hi_sum = pto.vcgadd %y_hi, %all_b32 @@ -2902,14 +3087,17 @@ VPTO lowering result: pto.vsts %ysum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -// Use 2: rematerialize broadcast for the f32->f16 parity cast path. The -// deinterleaved=2 physical bundle has two entries that alias this SSA value. -%b_rows_for_cast = pto.vselr %sum_block, %broadcast_idx +// Use 2: rematerialize broadcast for the f32->f16 parity cast path. 
+%b_rows_for_cast_0 = pto.vselr %sum_block, %broadcast_idx_0 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -%h_even = pto.vcvt %b_rows_for_cast, %all_b32 +%b_rows_for_cast_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%cast_lo, %cast_hi = pto.vdintlv %b_rows_for_cast_0, %b_rows_for_cast_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%h_even = pto.vcvt %cast_lo, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%h_odd = pto.vcvt %b_rows_for_cast, %all_b32 +%h_odd = pto.vcvt %cast_hi, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -2962,7 +3150,7 @@ VPTO lowering result: ```text %all_b32 = pto.pge_b32 "PAT_ALL" -%m = pto.pge_b32 "PAT_VL48" +%m, %_ = pto.plt_b32 %c48_i32 : i32 -> !pto.mask, i32 %x0 = pto.vlds %base[%off] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> @@ -3101,3 +3289,1967 @@ VMI-LAYOUT-CONTRACT: stable VMI layout ABI. Mark the function internal for layout specialization, inline it before vmi-layout-assignment, or define an explicit ABI layout. ``` + +### 3.26 S=16 Grouped Tail Through Broadcast, Reduce, Store + +This case extends section 3.15.1 from `reduce -> group_store` to the full +grouped compute path. It is needed because `create_group_mask` must remain a +group-periodic mask after a `group_broadcast`; it cannot collapse to a prefix +mask or an all-true mask. 
+ +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %all_b32 + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %all_b32, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%broadcast_idx = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = 
pto.vmul %x_lo, %b_rows, %all_b32 : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %hi4_mask : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..11]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..11) + = s * s +``` + +Required assignment rule: + +```text +%mask is a grouped mask with S=16 and active_elems_per_group=12. +For the low half, the physical predicate is PAT_ALL. +For the high half, the physical predicate is lane_mod_8 < 4. +The same split must be reused for both group_reduce operations. +``` + +### 3.27 S=32 `group_load` With Stride Greater Than Group Size + +This case is the S=32 counterpart to section 3.15.2. The logical group is +`32xf32`, but rows in memory have a larger stride. The fast plan is legal only +when the stride is a multiple of one 32B f32 block. 
+ +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +// source_group_stride = 40 f32 = 5 * 32B blocks. +%stride_blocks = %c5_i16 + +%frag0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%frag0 lanes r*8 .. r*8+7 = row_r[0..7] +%frag1 lanes r*8 .. r*8+7 = row_r[8..15] +%frag2 lanes r*8 .. r*8+7 = row_r[16..23] +%frag3 lanes r*8 .. 
r*8+7 = row_r[24..31] + +%s0 = pto.vcgadd %frag0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %frag1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %frag2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %frag3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 40 + 0 .. tile_off + r * 40 + 31]) +``` + +Required diagnostic when the stride is not block-aligned: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_load group_size 32 with source_group_stride not divisible by + 8 f32 elements cannot use the registered vsldb strided-block plan. Enable a + stable gather plan or choose a block-aligned source_group_stride. +``` + +Required assignment rule: + +```text +This producer selects the S=32 block-fragment plan: + #pto.vmi.layout + +It must not be unified with the contiguous-load S=32 plan from section 3.6: + #pto.vmi.layout + +Both layouts are legal inputs to group_reduce_addf S=32, but they require +different producer materialization plans. +``` + +### 3.28 `group_slot_load` `slots = 1` With Aligned Non-Unit Stride + +Section 3.16.1 diagnoses non-unit stride for the packed `slots = 8` plan. The +row-local `slots = 1` plan supports non-unit stride only when each one-lane +load can be issued as an aligned `vsldb`. In the current lowering this means +the stride is a positive compile-time constant and is divisible by the 32B +alignment expressed in source elements. 
+ +VMI input: + +```text +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. The address expression is scalar/index +// arithmetic outside the vector register layout. For f32, %c8 is 32B. +%addr_r = %rhs_base + %rhs_off + r * 8 +%rhs_r = pto.vsldb %addr_r, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Required assignment rule: + +```text +If a non-unit-stride group_slot_load has only slots=1 consumers and its stride +is a positive constant divisible by the element count of 32B, select +group_slot_load_slots1_row_local. Do not diagnose it using the slots=8 +unit-stride restriction. +``` + +Required diagnostic: + +```text +%c2 = arith.constant 2 : index +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + slots=1 group_slot_load currently lowers as one lane-0 vsldb per group and + requires constant positive source_group_stride divisible by 8 elements for + 32B load alignment; packed or unaligned scalar load lowering is not + implemented. 
+``` + +Dynamic stride has the same status until a stable gather or scalarized packed +load plan is designed: + +```text +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %runtime_stride + {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + requires constant positive source_group_stride divisible by 8 elements. +``` + +### 3.29 One Semantic Mask With f32 And f16 Consumers + +One VMI mask may feed consumers with different physical predicate +granularities. Layout assignment must keep a single semantic mask value, but +materialize per-use physical masks once each consumer's element type is known. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xpred> +pto.vmi.masked_store %x, %out32[%off], %mask +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.masked_store %h, %out16[%off], %mask +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%x_for_cast: + pto.vmi.ensure_layout %x + : #pto.vmi.layout -> #pto.vmi.layout + +%mask_for_h_store: + pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical mask materialization: + +```text +use at masked_store %x: + predicate granularity b32, PAT_VL96, layout contiguous + +use at vcvt %x -> %h: + predicate granularity b32, PAT_ALL. The cast may compute inactive lanes + because the following masked_store controls the external memory effect. 
+ +use at masked_store %h: + predicate granularity b16, PAT_VL96, layout contiguous +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%mask32_0 = pto.pge_b32 "PAT_ALL" +%mask32_1 = pto.pge_b32 "PAT_VL32" + +%x0 = pto.vlds %base[%off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +pto.vsts %x0, %out32[%off], %mask32_0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x1, %out32[%off_plus_64], %mask32_1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%x_p0, %x_p1 = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pset_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +%mask_b16, %scalar_out = pto.plt_b16 %c96_i32 + : i32 -> !pto.mask, i32 +pto.vsts %h0, %out16[%off], %mask_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..95: + out32[off + i] = base[off + i] + out16[off + i] = truncf(base[off + i]) + +for i = 96..127: + out32[off + i] is unchanged + out16[off + i] is unchanged +``` + +Required assignment rule: + +```text +`vmi-to-vpto` must not decide mask granularity by inspecting users. It consumes +the per-use typed mask materialization inserted by vmi-layout-assignment. For +a rematerializable `create_mask`, assignment may clone it as b32/b16 masks. For +a non-rematerializable mask producer, assignment must insert +`ensure_mask_granularity` or diagnose if no materialization plan is registered. 
+``` + +### 3.30 `masked_load` Tail Without Padding + +This case is the replacement for `vector.transfer_read` padding semantics in the +initial VMI surface. Tail lanes are expressed by a mask and a passthrough value; +there is no implicit padding constant in the load. The direct lowering is legal +only when every physical chunk read by `vlds` is memory-safe. + +VMI input: + +```text +%c100 = arith.constant 100 : index +%mask = pto.vmi.create_mask %c100 : index -> !pto.vmi.mask<100xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<100xf32> +%x = pto.vmi.masked_load %base[%c0], %mask, %zero + : memref<128xf32>, !pto.vmi.mask<100xpred>, !pto.vmi.vreg<100xf32> + -> !pto.vmi.vreg<100xf32> +pto.vmi.store %x, %out[%c0] +``` + +Assigned layouts: + +```text +%mask: + !pto.vmi.mask<100xb32, #pto.vmi.layout> + +%zero, %x: + !pto.vmi.vreg<100xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%m0 = pto.pge_b32 "PAT_ALL" +%m1, %_ = pto.plt_b32 %c36_i32 : i32 -> !pto.mask, i32 + +%zero0 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%c0] + : memref<128xf32> -> !pto.vreg<64xf32> +%x0 = pto.vsel %l0, %zero0, %m0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%l1 = pto.vlds %base[%c64] + : memref<128xf32> -> !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %m1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %x0, %out[%c0], %m0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %x1, %out[%c64], %m1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for i = 0..99: + out[i] = base[i] + +for i = 100..127: + out[i] is unchanged +``` + +Required diagnostic when the source cannot prove a safe full-read footprint: + +```text +VMI-UNSUPPORTED: + pto.vmi.masked_load direct lowering requires a supported memory source, + contiguous 
result/passthru/mask layouts, and either full physical chunks or a + statically safe full-read footprint. Use a memref with enough static extent, + enable the future stable masked/gather load plan, or make the logical vector a + full physical chunk. +``` + +Required assignment rule: + +```text +`masked_load` requests contiguous result, passthru, and mask layouts. Padding +is not a layout decision; it is the explicit passthrough operand selected by the +user. +``` + +### 3.31 `f16 -> f32` Feeding Dense Store And S=16 Reduce + +This case proves that the `deinterleaved = 2` layout produced by widening +`f16 -> f32` is not just a store layout. It must also be a legal S=16 grouped +reduction input. Layout assignment must not force the reduce consumer to +`block_elems = 8` and then rematerialize the widened value. + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assigned layouts: + +```text +%x16: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%x32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x16_0 = pto.vlds %base[%off] + : memref<128xf16> -> !pto.vreg<128xf16> +%x32_p0 = pto.vcvt %x16_0, %all_b16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, %all_b16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%dense0, %dense1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %dense0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %dense1, %dense_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(extf(base[off + r * 16 + 0 .. off + r * 16 + 15])) + +for i = 0..127: + dense_out[off + i] = extf(base[off + i]) +``` + +Required assignment rule: + +```text +When S=16 group_reduce consumes an existing `deinterleaved = 2` dense value, +the reduce plan must accept `block_elems = 1`. `block_elems = 8` is only a +producer-driven fast plan for block-fragment loads, not the semantic +requirement of S=16 reduction. +``` + +### 3.32 `f32` Feeding f8 Store And S=32 Reduce + +This is the `f32 -> f8` counterpart to section 3.31. A 256-lane f32 value can +serve both `truncf -> f8` and S=32 group reduction with the same +`deinterleaved = 4, block_elems = 1` layout. The value must not be forced to a +block-fragment `block_elems = 8` layout unless its producer requires that plan. 
+ +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8> +pto.vmi.store %x8, %out8[%off] +``` + +Assigned layouts: + +```text +%x32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x0 = pto.vlds %base[%off] : memref<256xf32> -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] : memref<256xf32> -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%off_plus_128] : memref<256xf32> -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%off_plus_192] : memref<256xf32> -> !pto.vreg<64xf32> + +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, 
%sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%x8_p0 = pto.vcvt %x_p0, %all_b32 {part = "P0", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p1 = pto.vcvt %x_p1, %all_b32 {part = "P1", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p2 = pto.vcvt %x_p2, %all_b32 {part = "P2", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p3 = pto.vcvt %x_p3, %all_b32 {part = "P3", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> + +%x8_01 = pto.vor %x8_p0, %x8_p1, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_23 = pto.vor %x8_p2, %x8_p3, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_0 = pto.vor %x8_01, %x8_23, PAT_ALL_B8 + : !pto.vreg<256xf8> + +pto.vsts %x8_0, %out8[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, memref<256xf8>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) + +for i = 0..255: + out8[off + i] = truncf(base[off + i]) +``` + +Required assignment rule: + +```text +The common layout selected for `%x32` is +`#pto.vmi.layout`. This satisfies both +`truncf f32 -> f8` and S=32 `group_reduce_addf`. A later strided block-load +producer may introduce `block_elems = 8`, but that is a different case and +requires an explicit materialization/rematerialization decision. +``` + +### 3.33 One Dense Value Feeding S=16 And S=32 Reduces + +This case is a pure layout-assignment conflict. The same logical +`256xf32` value is consumed by two legal reductions, but their efficient input +layouts are different: + +```text +S=16 reduce over 16 groups: + #pto.vmi.layout + +S=32 reduce over 8 groups: + #pto.vmi.layout +``` + +The program is semantically legal. Layout assignment must solve it by cloning +or rematerializing the cheap load for one use, or by inserting an explicit +registered materialization plan. 
`vmi-to-vpto` must not inspect both users and +choose one locally. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%mask16 = pto.vmi.create_group_mask %c16 {num_groups = 16, group_size = 16} + : index -> !pto.vmi.mask<256xpred> +%sum16 = pto.vmi.group_reduce_addf %x, %mask16 {num_groups = 16} +pto.vmi.group_store %sum16, %out16[%group_off16], %c1 {num_groups = 16} + +%mask32 = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask32 {num_groups = 8} +pto.vmi.group_store %sum32, %out32[%group_off32], %c1 {num_groups = 8} +``` + +Assigned layouts after rematerializing the load: + +```text +%x_s16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask16: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask32: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum8_mask = pto.pge_b32 "PAT_VL8" + +// Rematerialized S=16 use. The first vldsx2 covers rows 0..7, the second +// covers rows 8..15. Each pair is deinterleaved by element parity. 
+%s16_p0, %s16_p1 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_p2, %s16_p3 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s16_0 = pto.vcgadd %s16_p0, %all_b32 : !pto.vreg<64xf32> +%s16_1 = pto.vcgadd %s16_p1, %all_b32 : !pto.vreg<64xf32> +%s16_2 = pto.vcgadd %s16_p2, %all_b32 : !pto.vreg<64xf32> +%s16_3 = pto.vcgadd %s16_p3, %all_b32 : !pto.vreg<64xf32> + +%sum16_lo = pto.vadd %s16_0, %s16_1, %sum8_mask + : !pto.vreg<64xf32> +%sum16_hi = pto.vadd %s16_2, %s16_3, %sum8_mask + : !pto.vreg<64xf32> + +pto.vsts %sum16_lo, %out16[%group_off16], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum16_hi, %out16[%group_off16_plus_8], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Rematerialized S=32 use. Two DINTLV loads plus one register deinterleave +// level produce mod-4 columns for rows 0..7. +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s32_0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s32_1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s32_2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s32_3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> + +%s32_01 = pto.vadd %s32_0, %s32_1, %sum8_mask : !pto.vreg<64xf32> +%s32_23 = pto.vadd %s32_2, %s32_3, %sum8_mask : !pto.vreg<64xf32> +%sum32_block = pto.vadd %s32_01, %s32_23, %sum8_mask : !pto.vreg<64xf32> + +pto.vsts %sum32_block, 
%out32[%group_off32], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out16[group_off16 + r] = + reduce(base[off + r * 16 + 0 .. off + r * 16 + 15]) + +for r = 0..7: + out32[group_off32 + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +If a cheap producer such as load can produce both requested layouts, clone or +rematerialize it at the use sites and assign each clone independently. If the +producer is not rematerializable and no deinterleaved=2 <-> deinterleaved=4 +materialization plan is registered, emit a layout-contract diagnostic naming +both consumers and both required layouts. +``` + +### 3.34 S=64 Group-Slot Result `f32 -> f16` Cast + +Section 3.13 rejects direct width-changing cast for packed `slots = 8` +group-slot values. This case is the positive counterpart for row-local +`slots = 1`: each group result is already lane 0 of its own physical vreg, so a +slot-preserving cast can lower one row-local result at a time. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 8, group_size = 64} + : index -> !pto.vmi.mask<512xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<512xf16, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%one_b16 = pto.pge_b16 "PAT_VL1" + +// The compiler emits this row-local sequence for r = 0..7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum32_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Only lane 0 is semantic. EVEN keeps f32 lane 0 in f16 lane 0; all other +// lanes are non-semantic for group_slots(num_groups=8, slots=1). +%sum16_r = pto.vcvt %sum32_r, %one_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +pto.vsts %sum16_r, %out[%group_tile_off_r], %one_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + truncf(reduce(base[off + r * 64 + 0 .. off + r * 64 + 63])) +``` + +Required assignment rule: + +```text +Group-slot casts are layout-specific. `slots = 1` may use a slot-preserving +row-local cast because each semantic scalar is lane 0 of its own physical vreg. +This does not legalize packed `slots = 8` casts from section 3.13. +``` + +### 3.35 `group_slots` Fanout To `group_store` And `group_broadcast` + +This case fixes the fanout rule for sparse values. A `group_slots` value may +feed multiple group-aware consumers directly. Layout assignment must not +materialize it as dense just because one later use broadcasts it. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask_for_reduce: + !pto.vmi.mask<128xb32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%x0 = pto.vlds %base[%tile_off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%tile_off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +// ensure_layout for the first group_reduce. +%x_lo, %x_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + +// First sparse consumer: store the group slots without changing layout. +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Second sparse consumer: materialize only this use as dense grouped data. 
+%broadcast_idx0 = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx1 = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> +%b0 = pto.vselr %sum_block, %broadcast_idx0 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b1 = pto.vselr %sum_block, %broadcast_idx1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y0 = pto.vmul %x0, %b0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b1, %all_b32 : !pto.vreg<64xf32> + +// ensure_layout for the second group_reduce. +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = s + out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) +``` + +Required assignment rule: + +```text +`%sum` keeps one assigned layout: + #pto.vmi.layout + +`group_store` consumes that sparse layout directly. +`group_broadcast` is a use-site materialization to a dense layout. It must not +rewrite the defining `group_reduce` result or the sibling `group_store` use. +``` + +### 3.36 Same Scalar Source Materialized As `slots = 8` And `slots = 1` + +The same memory scalar stream may be used by both packed S=16 group-slot +compute and row-local S=64 group-slot compute. The two uses require different +logical vector shapes and different sparse layouts, so the source must be +rematerialized as two VMI values. 
There is no single `group_slots` layout that +serves both uses. + +VMI input: + +```text +%rhs16 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.load %base16[%off16] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8} +%out16v = pto.vmi.addf %sum16, %rhs16 +pto.vmi.group_store %out16v, %out16[%group_off16], %c1 {num_groups = 8} + +%rhs64 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> +%x64 = pto.vmi.load %base64[%off64] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum64 = pto.vmi.group_reduce_addf %x64, %mask64 {num_groups = 8} +%out64v = pto.vmi.addf %sum64, %rhs64 +pto.vmi.group_store %out64v, %out64[%group_off64], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%rhs16, %sum16, %out16v: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x16, %mask16: + #pto.vmi.layout + +%rhs64, %sum64, %out64v: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%x64, %mask64: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +// Packed S=16 RHS: one 32B scalar block in lanes 0..7. +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%rhs16_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// S=16 reduction is the section 3.5.1 shape. 
+%x16_lo, %x16_hi = pto.vldsx2 %base16[%tile_off16], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_lo = pto.vcgadd %x16_lo, PAT_ALL_B32 : !pto.vreg<64xf32> +%s16_hi = pto.vcgadd %x16_hi, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum16_block = pto.vadd %s16_lo, %s16_hi, %slot8 : !pto.vreg<64xf32> +%out16_block = pto.vadd %sum16_block, %rhs16_block, %slot8 + : !pto.vreg<64xf32> +pto.vsts %out16_block, %out16[%group_off16], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Row-local S=64 RHS: rematerialize the same scalar stream into one lane-0 +// value per physical row-local result. +%rhs64_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Emit this row-local reduction/add/store shape for r = 0..7. +%x64_r = pto.vlds %base64[%row_off64_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p64_r = pto.vcgadd %x64_r, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum64_r = pto.vcadd %p64_r, PAT_VL8_B32 : !pto.vreg<64xf32> +%out64_r = pto.vadd %sum64_r, %rhs64_r, %one_b32 : !pto.vreg<64xf32> +pto.vsts %out64_r, %out64[%group_off64_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out16[group_off16 + r] = reduce(base16[row_r, 0..15]) + rhs_base[rhs_off + r] + out64[group_off64 + r] = reduce(base64[row_r, 0..63]) + rhs_base[rhs_off + r] +``` + +Required assignment rule: + +```text +`group_slot_load` is cheaply rematerializable. If two use sites request +different `group_slots` layouts, clone/rematerialize the load per use. Do not +invent a common layout or make `vmi-to-vpto` inspect both users. +``` + +### 3.37 S=64 `group_store` With Non-Unit Output Stride + +Packed `slots = 8` stores currently require unit output stride. Row-local +`slots = 1` does not have that restriction because each group scalar is stored +by a separate lane-0 store. 
+ +VMI input: + +```text +%row_stride = arith.index_cast %ld : i64 to index +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %row_stride {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local sequence for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 : !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 : !pto.vreg<64xf32> + +%dst_r = %out + %group_off + r * %row_stride +pto.vsts %sum_r, %dst_r, %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * row_stride] = reduce(row_r[0..63]) +``` + +Required assignment rule: + +```text +If `group_store` has non-unit row_stride and the source can legally use +`slots = 1`, assignment may select `slots = 1` to keep the store legal. If the +source is fixed to `slots = 8`, the current target plan must diagnose unless a +strided packed store materializer is registered. +``` + +### 3.38 Multi-Tile S=32 `group_reduce` + +The S=32 plan is not only a one-tile special case. For more than eight groups, +layout assignment keeps the same layout and `vmi-to-vpto` emits the same +8-row tile recipe for each physical tile. 
+ +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 16, group_size = 32} + : index -> !pto.vmi.mask<512xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 16} +``` + +Assigned layouts: + +```text +%x, %mask: + !pto.vmi.vreg<512xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// Emit this shape for tile t = 0 and tile t = 1. +// Each tile covers eight 32-f32 rows. +%tile_base_t = %base + %off + t * 256 +%tile_out_t = %out + %group_off + t * 8 + +%x_even_0_t, %x_odd_0_t = pto.vldsx2 %tile_base_t[%c0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1_t, %x_odd_1_t = pto.vldsx2 %tile_base_t[%c128], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0_t, %x_p2_t = pto.vdintlv %x_even_0_t, %x_even_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1_t, %x_p3_t = pto.vdintlv %x_odd_0_t, %x_odd_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0_t = pto.vcgadd %x_p0_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s1_t = pto.vcgadd %x_p1_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s2_t = pto.vcgadd %x_p2_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s3_t = pto.vcgadd %x_p3_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s01_t = pto.vadd %s0_t, %s1_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%s23_t = pto.vadd %s2_t, %s3_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%sum_block_t = pto.vadd %s01_t, %s23_t, PAT_VL8_B32 + : !pto.vreg<64xf32> + +pto.vsts %sum_block_t, %tile_out_t, PAT_VL8_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. 
off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For `group_slots(num_groups = 16, slots = 8)`, the physical arity is +`num_groups / slots = 2`. The type conversion must expose two packed result +blocks in group order. `group_store` stores both blocks with offsets +`group_off + 0` and `group_off + 8`. +``` + +### 3.39 Strided S=32 `group_load` Through Broadcast And Second Reduce + +Section 3.27 covers strided S=32 `group_load -> group_reduce -> group_store`. +This case adds the missing dense continuation. The important layout fact is +that a strided block load naturally produces +`deinterleaved = 4, block_elems = 8`; `group_broadcast` must materialize the +broadcast into that same block-fragment layout when the broadcast feeds +elementwise compute and another S=32 group reduction. + +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask, %b, %y: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%stride_blocks = %c5_i16 // 40 f32 = 5 * 32B blocks. 
+ +%x_p0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Materialize the same per-row scalar into every 32B row fragment. The four +// bundle entries have the same lane contents, but the result layout remains +// deinterleaved=4, block_elems=8 because the consumer `%y = mulf %x, %b` +// operates on the block-fragment layout. 
+%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 : !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 : !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 : !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 : !pto.vreg<64xf32> +%ys01 = pto.vadd %ys0, %ys1, %slot8 : !pto.vreg<64xf32> +%ys23 = pto.vadd %ys2, %ys3, %slot8 : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(base[off + r * 40 + 0 .. off + r * 40 + 31]) + out[group_off + r] = + reduce_i(base[off + r * 40 + i] * s for i = 0..31) +``` + +Required assignment rule: + +```text +`block_elems` is part of dense layout compatibility. A broadcast result feeding +an elementwise op with `%x : deinterleaved=4, block_elems=8` must also be +assigned `deinterleaved=4, block_elems=8`. Reusing a +`deinterleaved=4, block_elems=1` broadcast would be a layout mismatch even +though both have four physical parts. +``` + +### 3.40 Scalar Broadcast Feeding Dense And Grouped Users + +This case fixes the rule for ordinary scalar broadcasts. A scalar broadcast is +not born with a physical layout. 
Layout assignment may either rematerialize it +per use, or assign the transfer-equivalent producer chain to the non-contiguous +layout requested by the grouped consumer and insert an explicit materialization +at the dense store use. The latter is the concrete plan below. + +VMI input: + +```text +%scale = pto.vmi.broadcast %scale_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%copy = pto.vmi.addf %x, %scale +pto.vmi.store %copy, %copy_out[%off] + +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%prod = pto.vmi.mulf %x, %scale +%sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %scale, %copy, %prod: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%copy_dense = pto.vmi.ensure_layout %copy: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// The shared load is assigned deinterleaved=4, block_elems=8 because the +// grouped consumer dominates the useful compute layout. 
+%x0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +// Dense store use: compute in deinterleaved=4, then ensure_layout materializes +// the contiguous memory order for the external effect. 
+%copy_p0 = pto.vadd %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%copy_p1 = pto.vadd %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%copy_p2 = pto.vadd %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%copy_p3 = pto.vadd %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%c01_lo, %c01_hi = pto.vintlv %copy_p0, %copy_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%c23_lo, %c23_hi = pto.vintlv %copy_p1, %copy_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy0, %copy1 = pto.vintlv %c01_lo, %c23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy2, %copy3 = pto.vintlv %c01_hi, %c23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %copy0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy2, %copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// Grouped use: reuse the same deinterleaved operands directly. 
+%prod_p0 = pto.vmul %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%prod_p1 = pto.vmul %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%prod_p2 = pto.vmul %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%prod_p3 = pto.vmul %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %prod_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %prod_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %prod_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %prod_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + scale_s + +for r = 0..7: + sum_out[group_off + r] = + reduce_i(base[off + r * 32 + i] * scale_s for i = 0..31) +``` + +Required assignment rule: + +```text +`broadcast` is layout-transparent and cheaply rematerializable, but assignment +does not have to force a separate contiguous broadcast just because a dense +store exists. It may choose a common deinterleaved compute layout for +transfer-equivalent elementwise ops and insert `ensure_layout` at the dense +store. The required invariant is that this choice is explicit in the assigned +IR; `vmi-to-vpto` must not infer it by inspecting both users. +``` + +### 3.41 Non-Rematerializable Value With Incompatible Users + +This is the non-cheap counterpart to section 3.18. A `masked_load` has explicit +mask and passthrough semantics, so layout assignment should not clone it as a +normal cheap load unless the registry explicitly marks that clone legal. The +conflict is solved by inserting `ensure_layout` at one use site. 
+ +VMI input: + +```text +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + +pto.vmi.store %x, %copy_out[%off] + +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %zero for masked_load/store: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask for masked_load/store: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%zero0 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero2 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero3 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%l1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%l2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%l3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x0 = pto.vsel %l0, %zero0, %all_b32 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %all_b32 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %zero2, %all_b32 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %zero3, %all_b32 : !pto.vreg<64xf32> + +pto.vsts %x0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %x1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %x2, 
%copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %x3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// ensure_layout contiguous -> deinterleaved=4 at the reduce use. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For non-rematerializable producers, assignment must insert a registered +use-site materialization plan, such as contiguous -> deinterleaved=4. If no +plan exists, it must diagnose at assignment time. `vmi-to-vpto` must not clone +the masked_load or choose a materialization after seeing both users. +``` + +### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator + +Section 3.22 covers dense loop-carried values. Sparse group-slot values need a +separate case because the loop-carried block argument has no dense lane +semantics outside the live group slots. 
+ +VMI input: + +```text +%acc0 = pto.vmi.group_slot_load %init[%group_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + +%acc = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg = %acc0) -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.group_load %base[%tile_off_k], %c16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + %next = pto.vmi.addf %arg, %sum + scf.yield %next : !pto.vmi.vreg<128xf32> +} + +pto.vmi.group_store %acc, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%acc0, %arg, %sum, %next, %acc: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%acc0_block = pto.vsldb %init[%group_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%acc_block = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg_block = %acc0_block) -> !pto.vreg<64xf32> { + %lo, %hi = pto.vldsx2 %base[%tile_off_k], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %lo, %all_b32 : !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %hi, %all_b32 : !pto.vreg<64xf32> + %sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + %next_block = pto.vadd %arg_block, %sum_block, %slot8 : !pto.vreg<64xf32> + scf.yield %next_block : !pto.vreg<64xf32> +} + +pto.vsts %acc_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + init[group_off + r] + + sum_k reduce(base[tile_k, row_r, 0..15]) +``` + +Required assignment 
rule: + +```text +Loop-carried `group_slots` values are valid. The iter_arg, body block +argument, yield operand, loop result, and final group_store operand all carry +the same `group_slots(num_groups=8, slots=8)` layout. Ordinary dense consumers +inside the loop still require an explicit `group_broadcast` or diagnostic. +``` + +### 3.43 Internal Function Argument Boundary Materialization + +Section 3.25 covers a private function returning a VMI value. A callee argument +is the other direction of the same ABI problem: the callee body may require a +layout that is different from the layout naturally produced at a call site. + +The current implementation keeps the internal function VMI signature +contiguous and makes the callee-entry materialization explicit with +`ensure_layout` / `ensure_mask_layout`. This is less aggressive than +specializing the VMI function signature to `deinterleaved = 4`, but it preserves +the same invariant: after layout assignment, `vmi-to-vpto` lowers only from +explicit type and helper information and does not inspect the callee body while +lowering a call. 
+ +VMI input: + +```text +func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, %group_off: index) { + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} + return +} + +func.func @caller(%base: !pto.ptr, %off: index, + %out: !pto.ptr, %group_off: index) { + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %group_off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return +} +``` + +Assigned layouts: + +```text +@consume argument %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +@consume argument %mask: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +inside @consume: + %x_split = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + + %mask_split = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +@caller %x and %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @consume(%x_p0: !pto.vreg<64xf32>, + %x_p1: !pto.vreg<64xf32>, + %x_p2: !pto.vreg<64xf32>, + %x_p3: !pto.vreg<64xf32>, + %m0: !pto.mask, + %m1: !pto.mask, + %m2: !pto.mask, + %m3: !pto.mask, + %out: !pto.ptr, + %group_off: index) { + // Callee-entry lowering of ensure_layout contiguous -> deinterleaved=4, + // block_elems=8. 
+ %x01_lo, %x01_hi = pto.vdintlv %x_p0, %x_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x23_lo, %x23_hi = pto.vdintlv %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d0, %x_d2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d1, %x_d3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %m01_lo, %m01_hi = pto.pdintlv_b32 %m0, %m1 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m23_lo, %m23_hi = pto.pdintlv_b32 %m2, %m3 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d0, %m_d2 = pto.pdintlv_b32 %m01_lo, %m23_lo + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d1, %m_d3 = pto.pdintlv_b32 %m01_hi, %m23_hi + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + + %slot8 = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_d0, %m_d0 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_d1, %m_d1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_d2, %m_d2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_d3, %m_d3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return +} + +func.func @caller(...) { + // Caller keeps the load and group mask in the contiguous function ABI layout. 
+ %x0 = pto.vlds %base[%off] : !pto.ptr -> !pto.vreg<64xf32> + %x1 = pto.vlds %base[%off_plus_64] : !pto.ptr -> !pto.vreg<64xf32> + %x2 = pto.vlds %base[%off_plus_128] : !pto.ptr -> !pto.vreg<64xf32> + %x3 = pto.vlds %base[%off_plus_192] : !pto.ptr -> !pto.vreg<64xf32> + + %m0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m1 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m2 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m3 = pto.pset_b32 "PAT_ALL" : !pto.mask + + call @consume(%x0, %x1, %x2, %x3, %m0, %m1, %m2, %m3, %out, %group_off) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.mask, !pto.mask, + !pto.mask, !pto.mask, !pto.ptr, index) -> () + return +} +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +Private function boundary layout is explicit in the assigned function type and +callee-entry helpers. The current endpoint chooses a contiguous VMI function +ABI and inserts callee-entry materialization for the grouped body requirement. +`vmi-to-vpto` does not inspect the callee body while lowering the call and does +not inspect callers while lowering the callee block argument. + +Future optimization may specialize private VMI function signatures directly to +`deinterleaved = 4, block_elems = 8` when all call sites agree. That +optimization must still be expressed in the assigned VMI function type before +`vmi-to-vpto` runs. +``` + +### 3.44 `masked_load` Grouped Tail Feeding S=32 Reduce + +This case connects the explicit `masked_load` tail model from section 3.30 with +grouped reduction. The load has no padding constant hidden in the op; inactive +lanes are provided by the passthrough value and excluded from the reduction by +the same grouped mask. 
+ +VMI input: + +```text +%c25 = arith.constant 25 : index +%mask = pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%zero, %x for masked_load: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce: + pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + -> !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Current implementation result: + +```text +VMI-UNSUPPORTED: pto.vmi.group_reduce_addf s32 block8 lowering does not yet +support partial create_group_mask active_elems_per_group during layout +assignment +``` + +This must remain a layout-assignment diagnostic until the S=32 block8 +grouped-mask lowering is proven against runtime SIM. Assignment must not write +`vmi.selected_plan = "s32_reduce_block8_stride"` for this case and leave +`vmi-to-vpto` to discover the partial mask by walking the mask defining op. A +`masked_load` can be lowered contiguously and then materialized to +`deinterleaved = 4, block_elems = 8`, but the grouped reduce still needs a +physically correct `create_group_mask` for `active_elems_per_group = 25`. +Allowing the current S=32 block8 path to proceed would not preserve the logical +memory result below. 
+ +Intended VPTO lowering shape after the grouped-mask issue is fixed: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// masked_load direct lowering stays contiguous. +%m0, %m1, %m2, %m3 = materialize contiguous create_group_mask(c25, S=32) +%z0, %z1, %z2, %z3 = vdup zero +%l0 = pto.vlds %base[%off] +%l1 = pto.vlds %base[%off_plus_64] +%l2 = pto.vlds %base[%off_plus_128] +%l3 = pto.vlds %base[%off_plus_192] +%x0 = pto.vsel %l0, %z0, %m0 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %z1, %m1 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %z2, %m2 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %z3, %m3 : !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=4, block_elems=8. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + +// Correct deinterleaved grouped mask for active columns 0..24: +// part 0 covers columns 0..7 for every row: all active +// part 1 covers columns 8..15 for every row: all active +// part 2 covers columns 16..23 for every row: all active +// part 3 covers columns 24..31 for every row: one active lane per row +%mask_p0 = pto.pset_b32 "PAT_ALL" +%mask_p1 = pto.pset_b32 "PAT_ALL" +%mask_p2 = pto.pset_b32 "PAT_ALL" +%mask_p3 = materialize one lane per 8-lane row block + +%s0 = pto.vcgadd %x_p0, %mask_p0 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %mask_p1 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %mask_p2 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %mask_p3 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. 
off + r * 32 + 24]) +``` + +Required assignment rule: + +```text +`masked_load` and `group_reduce` must share the same grouped mask layout. The +passthrough value defines inactive loaded lanes, while the reduce mask defines +participation. Assignment may select a deinterleaved S=32 load plan only when +the rounded physical reads are memory-safe; otherwise it must diagnose or use a +future stable gather fallback. + +Current implementation additionally diagnoses the S=32 block8 partial grouped +mask itself. This is deliberate: the case is not implemented until the +deinterleaved grouped-mask materialization and `vcgadd` interpretation are +validated end to end by SIM. +``` diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index da8428dd23..e8c44a1454 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -16,7 +16,9 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { let summary = "VMI logical vector register layout"; let parameters = (ins StringRefParameter<"layout kind">:$kind, - "int64_t":$factor + "int64_t":$factor, + "int64_t":$blockElems, + "int64_t":$slots ); let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; @@ -24,9 +26,11 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { let extraClassDeclaration = [{ static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, - int64_t factor); + int64_t factor, + int64_t blockElems = 1); static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, - int64_t numGroups); + int64_t numGroups, + int64_t slots = 0); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 7bd7524118..80036f9946 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -76,6 +76,22 @@ def VMICreateMaskOp : VMI_Op<"create_mask"> { let 
assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; } +def VMICreateGroupMaskOp : VMI_Op<"create_group_mask"> { + let summary = "Create a VMI logical grouped predicate mask"; + let description = [{ + Creates a mask where lane i is active iff + `(i % group_size) < active_elems_per_group`. + }]; + let arguments = (ins + Index:$active_elems_per_group, + I64Attr:$num_groups, + I64Attr:$group_size + ); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_elems_per_group attr-dict `:` type($active_elems_per_group) `->` type($result)"; +} + def VMIConstantMaskOp : VMI_Op<"constant_mask"> { let summary = "VMI logical predicate mask constant"; let arguments = (ins AnyAttr:$value); @@ -437,7 +453,8 @@ def VMIBitcastOp : VMI_Op<"bitcast"> { def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical vector load"; - let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + OptionalAttr:$full_read_elems); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; @@ -452,6 +469,15 @@ def VMIGroupLoadOp : VMI_Op<"group_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load one scalar value per logical group into group slots"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$source_group_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $source_group_stride attr-dict `:` type($source) `->` type($result)"; +} + def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector load with passthrough lanes"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, diff --git a/lib/PTO/IR/PTO.cpp 
b/lib/PTO/IR/PTO.cpp index d8db178278..8fe898a76f 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/InliningUtils.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -118,6 +119,27 @@ static bool isKnownZeroOrUnitExtent(int64_t value); static bool isByteIntegerType(Type ty); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, bool allowLowPrecision = false); + +namespace { +struct PTOInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } +}; +} // namespace static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); @@ -2597,6 +2619,8 @@ void PTODialect::initialize() { #define GET_ATTRDEF_LIST #include "PTO/IR/PTOAttrs.cpp.inc" >(); + + addInterfaces(); } diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index e26982e347..ff7170044e 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. 
-// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMI.cpp - PTO VMI type and attribute support -----------------------===// //===----------------------------------------------------------------------===// @@ -36,8 +38,7 @@ static std::string formatVMIVRegType(int64_t elementCount, Type elementType, } static std::string formatVMIMaskType(int64_t elementCount, - StringRef granularity, - Attribute layout) { + StringRef granularity, Attribute layout) { std::string result; llvm::raw_string_ostream os(result); os << "!pto.vmi.mask<" << elementCount << "x" << granularity; @@ -146,6 +147,13 @@ static FailureOr getLayoutFactor(Type type) { return (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; } +static FailureOr getLayoutBlockElems(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDeinterleaved() ? 
(*layout).getBlockElems() : 1; +} + static FailureOr getPhysicalLanesPerPart(Type type) { if (auto vregType = dyn_cast(type)) return getDataLanesPerPart(vregType.getElementType()); @@ -172,26 +180,29 @@ static bool isLayoutAssigned(VMIMaskType type) { return static_cast(type.getLayoutAttr()); } -static LogicalResult verifyAllSameVRegShapeAndLayout(Operation *op, - ArrayRef types, - bool requireSameElement) { +static LogicalResult +verifyAllSameVRegShapeAndLayout(Operation *op, ArrayRef types, + bool requireSameElement) { if (types.empty()) return success(); VMIVRegType first = types.front(); - bool anyLayout = llvm::any_of(types, [](VMIVRegType type) { - return isLayoutAssigned(type); - }); + bool anyLayout = llvm::any_of( + types, [](VMIVRegType type) { return isLayoutAssigned(type); }); for (VMIVRegType type : types) { if (type.getElementCount() != first.getElementCount()) - return op->emitOpError("requires all VMI data values to have the same logical lane count"); + return op->emitOpError( + "requires all VMI data values to have the same logical lane count"); if (requireSameElement && type.getElementType() != first.getElementType()) - return op->emitOpError("requires all VMI data values to have the same element type"); + return op->emitOpError( + "requires all VMI data values to have the same element type"); if (anyLayout && !isLayoutAssigned(type)) - return op->emitOpError("requires either all or no VMI data values to carry layout"); + return op->emitOpError( + "requires either all or no VMI data values to carry layout"); if (anyLayout && type.getLayout() != first.getLayout()) - return op->emitOpError("requires all layout-assigned VMI data values to have the same layout"); + return op->emitOpError("requires all layout-assigned VMI data values to " + "have the same layout"); } return success(); } @@ -203,8 +214,7 @@ static LogicalResult verifyElementwiseVRegOp(Operation *op, VMIVRegType lhs, /*requireSameElement=*/true); } -static LogicalResult 
verifyFloatUnaryVRegOp(Operation *op, - VMIVRegType source, +static LogicalResult verifyFloatUnaryVRegOp(Operation *op, VMIVRegType source, VMIVRegType result) { if (!isVMIFloatLikeType(source.getElementType())) return op->emitOpError("requires floating-point-like VMI element type"); @@ -221,15 +231,15 @@ static LogicalResult verifyFloatTernaryVRegOp(Operation *op, VMIVRegType lhs, /*requireSameElement=*/true); } -static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( - Operation *op, ArrayRef types) { +static LogicalResult +verifyAllSameMaskShapeLayoutAndGranularity(Operation *op, + ArrayRef types) { if (types.empty()) return success(); VMIMaskType first = types.front(); - bool anyLayout = llvm::any_of(types, [](VMIMaskType type) { - return isLayoutAssigned(type); - }); + bool anyLayout = llvm::any_of( + types, [](VMIMaskType type) { return isLayoutAssigned(type); }); for (VMIMaskType type : types) { if (type.getElementCount() != first.getElementCount()) @@ -252,11 +262,13 @@ static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, VMIVRegType dataType) { if (maskType.getElementCount() != dataType.getElementCount()) - return op->emitOpError("requires mask logical lane count to match data lane count"); + return op->emitOpError( + "requires mask logical lane count to match data lane count"); if (isLayoutAssigned(maskType) || isLayoutAssigned(dataType)) { if (!isLayoutAssigned(maskType) || !isLayoutAssigned(dataType)) - return op->emitOpError("requires either both mask and data to carry layout or neither to carry layout"); + return op->emitOpError("requires either both mask and data to carry " + "layout or neither to carry layout"); if (maskType.getLayout() != dataType.getLayout()) return op->emitOpError("requires mask layout to match data layout"); } @@ -268,7 +280,8 @@ static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, int64_t maskBitWidth 
= getMaskGranularityBitWidth(maskType.getGranularity()); if (elementBitWidth != 0 && maskBitWidth != 0 && elementBitWidth != static_cast(maskBitWidth)) - return op->emitOpError("requires mask granularity to match data element width"); + return op->emitOpError( + "requires mask granularity to match data element width"); return success(); } @@ -288,9 +301,8 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, if (!memoryElementType) return success(); if (memoryElementType != dataType.getElementType()) - return op->emitOpError() - << "requires memory " << role - << " element type to match VMI data element type"; + return op->emitOpError() << "requires memory " << role + << " element type to match VMI data element type"; return success(); } @@ -309,24 +321,26 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, TypeRange physicalTypes) { FailureOr expectedArity = getVMIPhysicalArity(vmiType); if (failed(expectedArity)) - return op->emitOpError("requires a layout-assigned VMI type with computable physical arity"); + return op->emitOpError( + "requires a layout-assigned VMI type with computable physical arity"); if (static_cast(physicalTypes.size()) != *expectedArity) - return op->emitOpError() - << "requires " << *expectedArity << " physical parts, got " - << physicalTypes.size(); + return op->emitOpError() << "requires " << *expectedArity + << " physical parts, got " << physicalTypes.size(); if (auto vregType = dyn_cast(vmiType)) { FailureOr lanesPerPart = getDataLanesPerPart(vregType.getElementType()); if (failed(lanesPerPart)) - return op->emitOpError("requires data element type with known physical lane count"); + return op->emitOpError( + "requires data element type with known physical lane count"); for (Type physicalType : physicalTypes) { auto partType = dyn_cast(physicalType); if (!partType) return op->emitOpError("requires physical data parts to be !pto.vreg"); if (partType.getElementCount() != *lanesPerPart || 
partType.getElementType() != vregType.getElementType()) - return op->emitOpError("requires physical data part type to match VMI lane-map helper"); + return op->emitOpError( + "requires physical data part type to match VMI lane-map helper"); } return success(); } @@ -335,45 +349,86 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, if (!maskType) return op->emitOpError("requires VMI data or mask type"); if (maskType.isPred()) - return op->emitOpError("requires layout-assigned mask with concrete granularity"); + return op->emitOpError( + "requires layout-assigned mask with concrete granularity"); for (Type physicalType : physicalTypes) { auto partType = dyn_cast(physicalType); if (!partType) return op->emitOpError("requires physical mask parts to be !pto.mask"); if (partType.getGranularity() != maskType.getGranularity()) - return op->emitOpError("requires physical mask part granularity to match VMI mask"); + return op->emitOpError( + "requires physical mask part granularity to match VMI mask"); } return success(); } -static int64_t getLogicalLanesInPart(int64_t elementCount, int64_t factor, - int64_t part) { - if (part < 0 || part >= factor || part >= elementCount) - return 0; - return ((elementCount - 1 - part) / factor) + 1; +static std::optional +mapDenseLogicalLaneToPartIndex(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t logicalLane, + int64_t &part) { + if (logicalLane < 0 || logicalLane >= elementCount || factor <= 0 || + blockElems <= 0) + return std::nullopt; + int64_t block = logicalLane / blockElems; + int64_t inBlockLane = logicalLane % blockElems; + part = block % factor; + int64_t partBlock = block / factor; + return partBlock * blockElems + inBlockLane; +} + +static std::optional +mapDensePartIndexToLogicalLane(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part, + int64_t indexInPart) { + if (part < 0 || part >= factor || indexInPart < 0 || factor <= 0 || + blockElems <= 0) + return 
std::nullopt; + int64_t partBlock = indexInPart / blockElems; + int64_t inBlockLane = indexInPart % blockElems; + int64_t logicalBlock = partBlock * factor + part; + int64_t logicalLane = logicalBlock * blockElems + inBlockLane; + if (logicalLane >= elementCount) + return std::nullopt; + return logicalLane; +} + +static int64_t getDenseLogicalLanesInPart(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part) { + int64_t maxIndex = -1; + for (int64_t lane = 0; lane < elementCount; ++lane) { + int64_t lanePart = 0; + std::optional index = mapDenseLogicalLaneToPartIndex( + elementCount, factor, blockElems, lane, lanePart); + if (index && lanePart == part) + maxIndex = std::max(maxIndex, *index); + } + return maxIndex + 1; } } // namespace VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context) { - return VMILayoutAttr::get(context, "contiguous", 1); + return VMILayoutAttr::get(context, "contiguous", 1, 1, 0); } VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, - int64_t factor) { - return VMILayoutAttr::get(context, "deinterleaved", factor); + int64_t factor, + int64_t blockElems) { + return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0); } VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, - int64_t numGroups) { - return VMILayoutAttr::get(context, "num_groups", numGroups); + int64_t numGroups, int64_t slots) { + return VMILayoutAttr::get(context, "num_groups", numGroups, 1, slots); } Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { SMLoc loc = parser.getCurrentLocation(); StringRef kind; int64_t factor = 1; + int64_t blockElems = 1; + int64_t slots = 0; if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) return {}; @@ -383,9 +438,28 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { } else if (kind == "deinterleaved") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + if (succeeded(parser.parseOptionalComma())) { + 
StringRef field; + if (failed(parser.parseKeyword(&field)) || field != "block_elems" || + failed(parser.parseEqual()) || + failed(parser.parseInteger(blockElems))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'block_elems = '"); + return {}; + } + } } else if (kind == "num_groups") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + if (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || field != "slots" || + failed(parser.parseEqual()) || failed(parser.parseInteger(slots))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'slots = '"); + return {}; + } + } } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " @@ -397,39 +471,62 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { return {}; return parser.getChecked(loc, parser.getContext(), kind, - factor); + factor, blockElems, slots); } void VMILayoutAttr::print(AsmPrinter &printer) const { printer << "<" << getKind(); - if (isDeinterleaved() || isGroupSlots()) + if (isDeinterleaved()) { + printer << " = " << getFactor(); + if (getBlockElems() != 1) + printer << ", block_elems = " << getBlockElems(); + } else if (isGroupSlots()) { printer << " = " << getFactor(); + if (getSlots() != 0) + printer << ", slots = " << getSlots(); + } printer << ">"; } LogicalResult VMILayoutAttr::verify(function_ref emitError, - StringRef kind, int64_t factor) { + StringRef kind, int64_t factor, int64_t blockElems, + int64_t slots) { if (kind == "contiguous") { - if (factor != 1) + if (factor != 1 || blockElems != 1 || slots != 0) return emitError() - << "#pto.vmi.layout requires factor to be 1"; + << "#pto.vmi.layout requires factor, block_elems, " + "and slots to be their defaults"; return success(); } if (kind == "deinterleaved") { if (factor != 2 && factor != 4) - return emitError() - << "#pto.vmi.layout expected factor to be 2 or 4"; + return emitError() << 
"#pto.vmi.layout expected factor to be 2 or 4"; + if (blockElems <= 0) + return emitError() << "#pto.vmi.layout requires block_elems to be positive"; + if (slots != 0) + return emitError() << "#pto.vmi.layout requires slots to be omitted"; return success(); } if (kind == "num_groups") { if (factor <= 0) + return emitError() << "#pto.vmi.layout requires num_groups to be positive"; + if (blockElems != 1) + return emitError() << "#pto.vmi.layout requires block_elems to be 1"; + if (slots < 0 || (slots != 0 && factor % slots != 0)) return emitError() << "#pto.vmi.layout requires num_groups to be positive"; + << ", slots = " << slots + << "> requires slots to be positive and divide num_groups when " + "specified"; return success(); } @@ -451,8 +548,8 @@ Type VMIVRegType::parse(AsmParser &parser) { failed(parser.parseGreater())) return {}; - return parser.getChecked(loc, parser.getContext(), - shape.front(), elementType, layout); + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType, layout); } void VMIVRegType::print(AsmPrinter &printer) const { @@ -467,25 +564,25 @@ LogicalResult VMIVRegType::verify(function_ref emitError, int64_t elementCount, Type elementType, Attribute layout) { if (elementCount <= 0) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected a positive element count"; if (!isSupportedVMIElementType(elementType)) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected an integer, index, floating-point, or " "PTO low-precision element type"; if (layout && !mlir::isa(layout)) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected layout to 
be #pto.vmi.layout"; if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { if (layoutAttr.isGroupSlots() && elementCount % layoutAttr.getNumGroups() != 0) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected num_groups layout to evenly divide " "the VMI logical lane count"; } @@ -515,8 +612,8 @@ Type VMIMaskType::parse(AsmParser &parser) { failed(parser.parseGreater())) return {}; - return parser.getChecked(loc, parser.getContext(), - shape.front(), granularity, layout); + return parser.getChecked(loc, parser.getContext(), shape.front(), + granularity, layout); } void VMIMaskType::print(AsmPrinter &printer) const { @@ -530,35 +627,35 @@ LogicalResult VMIMaskType::verify(function_ref emitError, int64_t elementCount, StringRef granularity, Attribute layout) { if (elementCount <= 0) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected a positive element count"; if (!isSupportedGranularity(granularity)) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected granularity to be one of pred, b8, b16, " "b32"; if (layout && !mlir::isa(layout)) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected layout to be #pto.vmi.layout"; if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { if (layoutAttr.isGroupSlots()) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' mask type must not carry num_groups layout"; } if (granularity 
== "pred" && layout) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' pred mask must not carry layout"; if (granularity != "pred" && !layout) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' concrete mask granularity requires layout"; return success(); @@ -570,9 +667,11 @@ LogicalResult VMIConstantOp::verify() { if (!denseAttr) return emitOpError("requires dense elements constant attribute"); if (denseAttr.getElementType() != resultType.getElementType()) - return emitOpError("requires dense constant element type to match result element type"); + return emitOpError( + "requires dense constant element type to match result element type"); if (denseAttr.getNumElements() != resultType.getElementCount()) - return emitOpError("requires dense constant element count to match result logical lane count"); + return emitOpError("requires dense constant element count to match result " + "logical lane count"); return success(); } @@ -616,6 +715,22 @@ LogicalResult VMICreateMaskOp::verify() { return success(); } +LogicalResult VMICreateGroupMaskOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + int64_t groupSize = getGroupSizeAttr().getInt(); + if (numGroups <= 0) + return emitOpError("requires positive num_groups"); + if (groupSize <= 0) + return emitOpError("requires positive group_size"); + if (resultType.getElementCount() != numGroups * groupSize) + return emitOpError("requires result lane count to equal num_groups * " + "group_size"); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + LogicalResult VMIConstantMaskOp::verify() { auto resultType = 
cast(getResult().getType()); auto denseAttr = dyn_cast(getValue()); @@ -624,7 +739,8 @@ LogicalResult VMIConstantMaskOp::verify() { if (!denseAttr.getElementType().isInteger(1)) return emitOpError("requires dense mask constant element type to be i1"); if (denseAttr.getNumElements() != resultType.getElementCount()) - return emitOpError("requires dense mask constant element count to match result logical lane count"); + return emitOpError("requires dense mask constant element count to match " + "result logical lane count"); return success(); } @@ -655,8 +771,8 @@ LogicalResult VMIMaskXOrOp::verify() { LogicalResult VMIMaskNotOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); - return verifyAllSameMaskShapeLayoutAndGranularity( - getOperation(), {sourceType, resultType}); + return verifyAllSameMaskShapeLayoutAndGranularity(getOperation(), + {sourceType, resultType}); } LogicalResult VMIAddFOp::verify() { @@ -847,7 +963,8 @@ LogicalResult VMINotOp::verify() { auto resultType = cast(getResult().getType()); if (!isVMIIntegerLikeType(sourceType.getElementType())) return emitOpError("requires integer-like VMI element type"); - return verifyAllSameVRegShapeAndLayout(getOperation(), {sourceType, resultType}, + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, /*requireSameElement=*/true); } @@ -880,9 +997,9 @@ LogicalResult VMISelectOp::verify() { auto trueType = cast(getTrueValue().getType()); auto falseType = cast(getFalseValue().getType()); auto resultType = cast(getResult().getType()); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {trueType, falseType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {trueType, falseType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -903,9 +1020,9 @@ LogicalResult 
VMICompressOp::verify() { auto sourceType = cast(getSource().getType()); auto maskType = cast(getMask().getType()); auto resultType = cast(getResult().getType()); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {sourceType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, sourceType); } @@ -970,14 +1087,14 @@ LogicalResult VMIReduceAddFOp::verify() { return verifyMaskMatchesData(getOperation(), maskType, sourceType); } -template -LogicalResult verifyReduceMinMaxFOp(OpTy op) { +template LogicalResult verifyReduceMinMaxFOp(OpTy op) { auto sourceType = cast(op.getSource().getType()); auto initType = cast(op.getInit().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); if (!isVMIFloatLikeType(sourceType.getElementType())) - return op.emitOpError("requires floating-point-like VMI source element type"); + return op.emitOpError( + "requires floating-point-like VMI source element type"); if (sourceType.getElementType() != initType.getElementType() || sourceType.getElementType() != resultType.getElementType()) return op.emitOpError( @@ -991,13 +1108,9 @@ LogicalResult verifyReduceMinMaxFOp(OpTy op) { return verifyMaskMatchesData(op.getOperation(), maskType, sourceType); } -LogicalResult VMIReduceMaxFOp::verify() { - return verifyReduceMinMaxFOp(*this); -} +LogicalResult VMIReduceMaxFOp::verify() { return verifyReduceMinMaxFOp(*this); } -LogicalResult VMIReduceMinFOp::verify() { - return verifyReduceMinMaxFOp(*this); -} +LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } LogicalResult VMIGroupReduceAddFOp::verify() { auto sourceType = cast(getSource().getType()); @@ -1015,17 +1128,25 @@ LogicalResult VMIGroupReduceAddFOp::verify() { if (sourceType.getElementType() != 
resultType.getElementType()) return emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { - if (!sourceLayout.isContiguous()) + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) return emitOpError( - "requires layout-assigned source to use contiguous layout"); + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); } if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() - << "requires layout-assigned result to use " - "#pto.vmi.layout"; + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; } if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) return failure(); @@ -1044,10 +1165,9 @@ LogicalResult VMIGroupBroadcastOp::verify() { if (auto sourceLayout = sourceType.getLayoutAttr()) { if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() - << "requires layout-assigned source to use " - "#pto.vmi.layout"; + return emitOpError() << "requires layout-assigned source to use " + "#pto.vmi.layout"; } if (auto resultLayout = resultType.getLayoutAttr()) { if (resultLayout.isGroupSlots()) @@ -1062,13 +1182,16 @@ LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result logical lane 
counts to match"); + return emitOpError( + "requires source and result logical lane counts to match"); if (!isVMIFloatLikeType(sourceType.getElementType()) || !isVMIFloatLikeType(resultType.getElementType())) - return emitOpError("requires floating-point-like source and result element types"); + return emitOpError( + "requires floating-point-like source and result element types"); if (getVMIElementBitWidth(sourceType.getElementType()) >= getVMIElementBitWidth(resultType.getElementType())) - return emitOpError("requires result element type to be wider than source element type"); + return emitOpError( + "requires result element type to be wider than source element type"); return success(); } @@ -1076,13 +1199,16 @@ LogicalResult VMITruncFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result logical lane counts to match"); + return emitOpError( + "requires source and result logical lane counts to match"); if (!isVMIFloatLikeType(sourceType.getElementType()) || !isVMIFloatLikeType(resultType.getElementType())) - return emitOpError("requires floating-point-like source and result element types"); + return emitOpError( + "requires floating-point-like source and result element types"); if (getVMIElementBitWidth(sourceType.getElementType()) <= getVMIElementBitWidth(resultType.getElementType())) - return emitOpError("requires result element type to be narrower than source element type"); + return emitOpError( + "requires result element type to be narrower than source element type"); return success(); } @@ -1114,6 +1240,10 @@ LogicalResult VMIBitcastOp::verify() { } LogicalResult VMILoadOp::verify() { + if (auto fullReadElems = getFullReadElemsAttr()) { + if (fullReadElems.getInt() <= 0) + return emitOpError("requires full_read_elems to be positive"); + } return verifyMemoryElementMatches(getOperation(), 
getSource().getType(), cast(getResult().getType()), "source"); @@ -1140,6 +1270,28 @@ void VMIGroupLoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIGroupSlotLoadOp::verify() { + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + return verifyNumGroups(getOperation(), resultType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupSlotLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedLoadOp::verify() { auto maskType = cast(getMask().getType()); auto passthruType = cast(getPassthru().getType()); @@ -1147,9 +1299,9 @@ LogicalResult VMIMaskedLoadOp::verify() { if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), resultType, "source"))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1178,9 +1330,9 @@ LogicalResult VMIGatherOp::verify() { getOperation(), {indicesType, passthruType, resultType}, /*requireSameElement=*/false))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return 
verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1198,9 +1350,9 @@ LogicalResult VMIExpandLoadOp::verify() { if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), resultType, "source"))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1212,8 +1364,7 @@ void VMIExpandLoadOp::getEffects( } LogicalResult VMIStoreOp::verify() { - return verifyMemoryElementMatches(getOperation(), - getDestination().getType(), + return verifyMemoryElementMatches(getOperation(), getDestination().getType(), cast(getValue().getType()), "destination"); } @@ -1244,8 +1395,8 @@ LogicalResult VMIMaskedStoreOp::verify() { auto valueType = cast(getValue().getType()); auto maskType = cast(getMask().getType()); if (failed(verifyMemoryElementMatches(getOperation(), - getDestination().getType(), - valueType, "destination"))) + getDestination().getType(), valueType, + "destination"))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, valueType); } @@ -1261,8 +1412,8 @@ LogicalResult VMIScatterOp::verify() { auto indicesType = cast(getIndices().getType()); auto maskType = cast(getMask().getType()); if (failed(verifyMemoryElementMatches(getOperation(), - getDestination().getType(), - valueType, "destination"))) + getDestination().getType(), valueType, + "destination"))) return failure(); auto indexElementType = dyn_cast(indicesType.getElementType()); @@ -1270,9 +1421,9 @@ LogicalResult VMIScatterOp::verify() { indexElementType.isSigned()) return emitOpError("requires signless or unsigned 32-bit integer indices"); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {valueType, indicesType}, - /*requireSameElement=*/false))) 
+ if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {valueType, indicesType}, + /*requireSameElement=*/false))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, valueType); } @@ -1296,8 +1447,7 @@ void VMITileReadOp::getEffects( } LogicalResult VMITileWriteOp::verify() { - return verifyMemoryElementMatches(getOperation(), - getDestination().getType(), + return verifyMemoryElementMatches(getOperation(), getDestination().getType(), cast(getValue().getType()), "destination"); } @@ -1312,16 +1462,20 @@ LogicalResult VMIShuffleOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires result element type to match source element type"); + return emitOpError( + "requires result element type to match source element type"); if (static_cast(getIndices().size()) != resultType.getElementCount()) - return emitOpError("requires shuffle index count to match result logical lane count"); + return emitOpError( + "requires shuffle index count to match result logical lane count"); for (int64_t index : getIndices()) { if (index < 0 || index >= sourceType.getElementCount()) - return emitOpError("requires every shuffle index to select an existing source logical lane"); + return emitOpError("requires every shuffle index to select an existing " + "source logical lane"); } if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) - return emitOpError("requires either both source and result to carry layout or neither to carry layout"); + return emitOpError("requires either both source and result to carry " + "layout or neither to carry layout"); } return success(); } @@ -1332,26 +1486,32 @@ LogicalResult VMIChannelSplitOp::verify() { return emitOpError("requires at least two channel results"); auto firstResultType = 
cast(getResults().front().getType()); if (sourceType.getElementCount() != - static_cast(getResults().size()) * firstResultType.getElementCount()) - return emitOpError("requires source lane count to equal result count times per-channel lane count"); + static_cast(getResults().size()) * + firstResultType.getElementCount()) + return emitOpError("requires source lane count to equal result count times " + "per-channel lane count"); for (Value result : getResults()) { auto resultType = cast(result.getType()); if (resultType.getElementCount() != firstResultType.getElementCount() || resultType.getElementType() != sourceType.getElementType()) - return emitOpError("requires every channel result to have equal lane count and source element type"); + return emitOpError("requires every channel result to have equal lane " + "count and source element type"); } bool anyLayout = isLayoutAssigned(sourceType); for (Value result : getResults()) anyLayout |= isLayoutAssigned(cast(result.getType())); if (anyLayout) { if (!isLayoutAssigned(sourceType)) - return emitOpError("requires layout-assigned channel_split source when any channel result has layout"); + return emitOpError("requires layout-assigned channel_split source when " + "any channel result has layout"); for (Value result : getResults()) { auto resultType = cast(result.getType()); if (!isLayoutAssigned(resultType)) - return emitOpError("requires every channel_split result to carry layout when source has layout"); + return emitOpError("requires every channel_split result to carry " + "layout when source has layout"); if (!cast(resultType.getLayout()).isContiguous()) - return emitOpError("requires layout-assigned channel_split results to be contiguous"); + return emitOpError( + "requires layout-assigned channel_split results to be contiguous"); } int64_t channels = getResults().size(); if (channels == 2 || channels == 4) { @@ -1359,7 +1519,8 @@ LogicalResult VMIChannelSplitOp::verify() { auto expectedLayout = 
VMILayoutAttr::getDeinterleaved(getContext(), channels); if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) - return emitOpError("requires layout-assigned channel_split source to be contiguous or deinterleaved by result count"); + return emitOpError("requires layout-assigned channel_split source to " + "be contiguous or deinterleaved by result count"); } } return success(); @@ -1374,24 +1535,29 @@ LogicalResult VMIChannelMergeOp::verify() { auto inputType = cast(input.getType()); if (inputType.getElementCount() != firstInputType.getElementCount() || inputType.getElementType() != firstInputType.getElementType()) - return emitOpError("requires all channel inputs to have the same lane count and element type"); + return emitOpError("requires all channel inputs to have the same lane " + "count and element type"); } - if (resultType.getElementCount() != - static_cast(getInputs().size()) * firstInputType.getElementCount() || + if (resultType.getElementCount() != static_cast(getInputs().size()) * + firstInputType.getElementCount() || resultType.getElementType() != firstInputType.getElementType()) - return emitOpError("requires result lane count and element type to match merged channels"); + return emitOpError( + "requires result lane count and element type to match merged channels"); bool anyLayout = isLayoutAssigned(resultType); for (Value input : getInputs()) anyLayout |= isLayoutAssigned(cast(input.getType())); if (anyLayout) { if (!isLayoutAssigned(resultType)) - return emitOpError("requires layout-assigned channel_merge result when any channel input has layout"); + return emitOpError("requires layout-assigned channel_merge result when " + "any channel input has layout"); for (Value input : getInputs()) { auto inputType = cast(input.getType()); if (!isLayoutAssigned(inputType)) - return emitOpError("requires every channel_merge input to carry layout when result has layout"); + return emitOpError("requires every channel_merge input to carry layout " + 
"when result has layout"); if (!cast(inputType.getLayout()).isContiguous()) - return emitOpError("requires layout-assigned channel_merge inputs to be contiguous"); + return emitOpError( + "requires layout-assigned channel_merge inputs to be contiguous"); } int64_t channels = getInputs().size(); if (channels == 2 || channels == 4) { @@ -1399,7 +1565,8 @@ LogicalResult VMIChannelMergeOp::verify() { auto expectedLayout = VMILayoutAttr::getDeinterleaved(getContext(), channels); if (!resultLayout.isContiguous() && resultLayout != expectedLayout) - return emitOpError("requires layout-assigned channel_merge result to be contiguous or deinterleaved by input count"); + return emitOpError("requires layout-assigned channel_merge result to " + "be contiguous or deinterleaved by input count"); } } return success(); @@ -1410,7 +1577,8 @@ LogicalResult VMIEnsureLayoutOp::verify() { auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount() || sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires source and result to preserve VMI data shape and element type"); + return emitOpError("requires source and result to preserve VMI data shape " + "and element type"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); return success(); @@ -1421,7 +1589,8 @@ LogicalResult VMIEnsureMaskLayoutOp::verify() { auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount() || sourceType.getGranularity() != resultType.getGranularity()) - return emitOpError("requires source and result to preserve VMI mask shape and granularity"); + return emitOpError("requires source and result to preserve VMI mask shape " + "and granularity"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); return 
success(); @@ -1431,13 +1600,15 @@ LogicalResult VMIEnsureMaskGranularityOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result to preserve VMI mask lane count"); + return emitOpError( + "requires source and result to preserve VMI mask lane count"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); if (sourceType.getLayout() != resultType.getLayout()) return emitOpError("requires source and result mask layouts to match"); if (sourceType.isPred() || resultType.isPred()) - return emitOpError("requires concrete source and result mask granularities"); + return emitOpError( + "requires concrete source and result mask granularities"); return success(); } @@ -1473,14 +1644,22 @@ FailureOr mlir::pto::getMaskLanesPerPart(StringRef granularity) { FailureOr mlir::pto::getVMIPhysicalArity(Type type) { FailureOr elementCount = getVMIElementCount(type); - FailureOr factor = getLayoutFactor(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + FailureOr layout = getAssignedVMILayout(type); + if (failed(elementCount) || failed(lanesPerPart) || failed(layout)) return failure(); + if ((*layout).isGroupSlots() && (*layout).getSlots() > 0) + return divideCeilNonNegative((*layout).getNumGroups(), + (*layout).getSlots()); + + int64_t factor = (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; + int64_t blockElems = + (*layout).isDeinterleaved() ? 
(*layout).getBlockElems() : 1; int64_t arity = 0; - for (int64_t part = 0; part < *factor; ++part) { - int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + for (int64_t part = 0; part < factor; ++part) { + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, factor, blockElems, part); arity += divideCeilNonNegative(lanesInPart, *lanesPerPart); } return arity; @@ -1490,16 +1669,21 @@ FailureOr mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (logicalLane < 0 || logicalLane >= *elementCount) return failure(); - int64_t part = logicalLane % *factor; - int64_t indexInPart = logicalLane / *factor; - return VMIPhysicalLane{part, indexInPart / *lanesPerPart, - indexInPart % *lanesPerPart}; + int64_t part = 0; + std::optional indexInPart = mapDenseLogicalLaneToPartIndex( + *elementCount, *factor, *blockElems, logicalLane, part); + if (!indexInPart) + return failure(); + return VMIPhysicalLane{part, *indexInPart / *lanesPerPart, + *indexInPart % *lanesPerPart}; } FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, @@ -1507,32 +1691,38 @@ FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, int64_t lane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (part < 0 || part 
>= *factor || chunk < 0 || lane < 0 || lane >= *lanesPerPart) return failure(); int64_t indexInPart = chunk * *lanesPerPart + lane; - int64_t logicalLane = indexInPart * *factor + part; - if (logicalLane >= *elementCount) + std::optional logicalLane = mapDensePartIndexToLogicalLane( + *elementCount, *factor, *blockElems, part, indexInPart); + if (!logicalLane) return failure(); - return logicalLane; + return *logicalLane; } -FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, - int64_t chunk, int64_t lane) { +FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, int64_t chunk, + int64_t lane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || lane >= *lanesPerPart) return failure(); - int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, *factor, *blockElems, part); int64_t indexInPart = chunk * *lanesPerPart + lane; return indexInPart >= lanesInPart; } diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 27d6b806fe..c95e8772ec 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. 
-// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMILayoutAssignment.cpp - Assign VMI layouts ----------------------===// //===----------------------------------------------------------------------===// @@ -63,6 +65,8 @@ struct MaskUseRequest { std::string granularity; }; +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + static unsigned getElementBitWidth(Type type) { if (isa(type)) return 64; @@ -82,6 +86,15 @@ static StringRef getMaskGranularityForElement(Type elementType) { } } +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + return std::nullopt; +} + static bool isLane0SplatShuffle(VMIShuffleOp op) { auto sourceType = cast(op.getSource().getType()); ArrayRef indices = op.getIndices(); @@ -93,12 +106,10 @@ bool containsVMIType(Type type) { if (isa(type)) return true; if (auto functionType = dyn_cast(type)) { - return llvm::any_of(functionType.getInputs(), [](Type input) { - 
return containsVMIType(input); - }) || - llvm::any_of(functionType.getResults(), [](Type result) { - return containsVMIType(result); - }); + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); } if (auto shapedType = dyn_cast(type)) return containsVMIType(shapedType.getElementType()); @@ -191,11 +202,10 @@ struct LayoutSolver { if (!lhsNode.requestedGranularity.empty() && !rhsNode.requestedGranularity.empty() && lhsNode.requestedGranularity != rhsNode.requestedGranularity) - return op->emitError() - << kVMIDiagLayoutContractPrefix - << "conflicting mask granularities " - << lhsNode.requestedGranularity << " and " - << rhsNode.requestedGranularity; + return op->emitError() << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << lhsNode.requestedGranularity << " and " + << rhsNode.requestedGranularity; rhsNode.parent = lhsRoot; if (!lhsNode.requestedLayout) @@ -228,6 +238,123 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups); } + VMILayoutAttr getPreferredGroupSlotsLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 8) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 16) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 32) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 64) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + } + return getGroupSlotsLayout(numGroups); + } + + VMILayoutAttr getPreferredGroupReduceSourceLayout(VMIVRegType type, + int64_t numGroups) { + if 
(VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 16) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + if (groupSize == 32) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + } + return getContiguousLayout(); + } + + VMILayoutAttr getPreferredGroupSlotLoadLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 64) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + } + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + } + + VMILayoutAttr getPreferredGroupLoadResultLayout(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return getContiguousLayout(); + + if (!type.getElementType().isF32()) + return getContiguousLayout(); + + int64_t groupSize = type.getElementCount() / numGroups; + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return getContiguousLayout(); + + if (groupSize == 16) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + if (groupSize == 32) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + + return getContiguousLayout(); + } + + LogicalResult validateGroupLoadLayoutPlan(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (type.getLayoutAttr()) + return success(); + + int64_t numGroups = 
op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return success(); + if (!type.getElementType().isF32()) + return success(); + + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize != 16 && groupSize != 32) + return success(); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (rowStride && *rowStride > 0 && *rowStride % 8 == 0) + return success(); + + return op.emitError() + << kVMIDiagLayoutContractPrefix << "pto.vmi.group_load group_size " + << groupSize + << " requires constant positive row_stride divisible by 8 f32 " + "elements for the block8 stride plan; stable gather fallback is " + "not implemented"; + } + + VMILayoutAttr getPreferredGroupStoreUseLayout(Value value, + int64_t numGroups) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + VMILayoutAttr solved = getDataLayout(value); + if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && + solved.getSlots() > 0) + return solved; + if (value.getDefiningOp()) + return getPreferredGroupSlotsLayout(type, numGroups); + if (value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(type, numGroups); + return getContiguousLayout(); + } + VMILayoutAttr getDataLayout(Value value) { unsigned id = addDataValue(value); if (id == ~0u) @@ -238,6 +365,35 @@ struct LayoutSolver { return getContiguousLayout(); } + VMILayoutAttr getExplicitDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + return dataNodes[find(id)].naturalLayout; + } + + bool hasCompatibleTruncFUseForGroupReduce(Value value, int64_t groupSize) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType || !sourceType.getElementType().isF32()) + return false; + + for (OpOperand &use : value.getUses()) { + auto truncf = 
dyn_cast(use.getOwner()); + if (!truncf || use.getOperandNumber() != 0) + continue; + + auto resultType = dyn_cast(truncf.getResult().getType()); + if (!resultType) + continue; + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (groupSize == 16 && resultBits == 16) + return true; + if (groupSize == 32 && resultBits == 8) + return true; + } + return false; + } + LogicalResult requestMask(Value mask, VMILayoutAttr layout, StringRef granularity, Operation *op) { unsigned id = addMaskValue(mask); @@ -256,8 +412,8 @@ struct LayoutSolver { node.requestedGranularity != granularity) return op->emitError() << kVMIDiagLayoutContractPrefix - << "conflicting mask granularities " - << node.requestedGranularity << " and " << granularity; + << "conflicting mask granularities " << node.requestedGranularity + << " and " << granularity; node.requestedLayout = layout; node.requestedGranularity = granularity.str(); return success(); @@ -268,17 +424,54 @@ struct LayoutSolver { dataUseRequests.push_back(DataUseRequest{&operand, layout}); } - bool canAdoptConsumerRequestedLayout(Value value) { - if (!value.hasOneUse()) + bool canProducerAdoptConsumerLayout(Operation *op) { + if (!op) return false; + return isa(op); + } + + bool canAdoptConsumerRequestedLayout(Value value, + VMILayoutAttr requestedLayout) { Operation *definingOp = value.getDefiningOp(); - return definingOp && isa(definingOp); + if (!definingOp) + return false; + if (!isa(definingOp)) { + if (!requestedLayout || requestedLayout.isContiguous()) + return false; + if (!canProducerAdoptConsumerLayout(definingOp)) + return false; + } + if (value.hasOneUse()) + return true; + + unsigned matchingRequests = 0; + unsigned totalUses = 0; + for (OpOperand &use : value.getUses()) { + ++totalUses; + bool foundRequest = false; + for (DataUseRequest request : dataUseRequests) { + if (request.operand != &use) + continue; + if (request.layout != requestedLayout) + return false; + foundRequest = true; + } + if 
(!foundRequest) + return false; + ++matchingRequests; + } + return matchingRequests == totalUses; } LogicalResult applyConsumerDrivenDataLayouts() { for (DataUseRequest request : dataUseRequests) { Value value = request.operand->get(); - if (!canAdoptConsumerRequestedLayout(value)) + if (!canAdoptConsumerRequestedLayout(value, request.layout)) continue; unsigned id = addDataValue(value); if (id == ~0u) @@ -287,8 +480,7 @@ struct LayoutSolver { VMILayoutAttr existing = dataNodes[root].naturalLayout; if (existing && existing != request.layout) return request.operand->getOwner()->emitError() - << kVMIDiagLayoutContractPrefix - << "conflicting natural layouts " + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " << existing << " and " << request.layout; dataNodes[root].naturalLayout = request.layout; } @@ -324,6 +516,71 @@ struct LayoutSolver { return success(); } + bool shouldCommuteTruncFAfterGroupBroadcast(VMIGroupBroadcastOp broadcast) { + auto truncf = broadcast.getSource().getDefiningOp(); + if (!truncf) + return false; + + auto truncSourceType = dyn_cast(truncf.getSource().getType()); + auto truncResultType = dyn_cast(truncf.getResult().getType()); + auto broadcastResultType = + dyn_cast(broadcast.getResult().getType()); + if (!truncSourceType || !truncResultType || !broadcastResultType) + return false; + if (truncSourceType.getElementCount() != + truncResultType.getElementCount() || + truncResultType.getElementCount() != + broadcastResultType.getElementCount()) + return false; + + VMILayoutAttr sourceLayout = truncSourceType.getLayoutAttr(); + bool sourceIsGroupSlotValue = + (sourceLayout && sourceLayout.isGroupSlots()) || + truncf.getSource().getDefiningOp() || + truncf.getSource().getDefiningOp(); + if (!sourceIsGroupSlotValue) + return false; + + unsigned sourceBits = getElementBitWidth(truncSourceType.getElementType()); + unsigned resultBits = getElementBitWidth(truncResultType.getElementType()); + return 
truncSourceType.getElementType().isF32() && sourceBits > resultBits; + } + + LogicalResult commuteTruncFAfterGroupBroadcast() { + SmallVector broadcasts; + module.walk([&](VMIGroupBroadcastOp broadcast) { + if (shouldCommuteTruncFAfterGroupBroadcast(broadcast)) + broadcasts.push_back(broadcast); + }); + + OpBuilder builder(ctx); + for (VMIGroupBroadcastOp broadcast : broadcasts) { + auto truncf = broadcast.getSource().getDefiningOp(); + if (!truncf) + continue; + + auto truncSourceType = cast(truncf.getSource().getType()); + auto broadcastResultType = + cast(broadcast.getResult().getType()); + auto wideBroadcastType = + VMIVRegType::get(ctx, broadcastResultType.getElementCount(), + truncSourceType.getElementType(), + broadcastResultType.getLayoutAttr()); + + builder.setInsertionPoint(broadcast); + auto wideBroadcast = builder.create( + broadcast.getLoc(), wideBroadcastType, truncf.getSource(), + broadcast.getNumGroupsAttr()); + auto narrow = builder.create( + broadcast.getLoc(), broadcastResultType, wideBroadcast.getResult()); + broadcast.getResult().replaceAllUsesWith(narrow.getResult()); + broadcast.erase(); + if (truncf->use_empty()) + truncf.erase(); + } + return success(); + } + LogicalResult addConstraints() { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto maskAnd = dyn_cast(op)) { @@ -504,56 +761,117 @@ struct LayoutSolver { } if (auto compress = dyn_cast(op)) { requestDataUse(compress.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(compress.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(compress.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if 
(failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { - requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getGroupSlotsLayout( - reduce.getNumGroupsAttr().getInt()), - op))) + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( + sourceType, reduce.getNumGroupsAttr().getInt()); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + if (solvedSourceLayout && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { 
+ int64_t groupSize = sourceType.getElementCount() / numGroups; + if (groupSize == 16 && solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 2 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + if (groupSize == 32 && solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 4 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + groupSize)) { + if (groupSize == 16) + sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + if (groupSize == 32) + sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + } + } + if (sourceLayout && sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + sourceLayout.getBlockElems() == 8 && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + if (groupSize == 32) { + if (auto groupMask = + reduce.getMask().getDefiningOp()) { + std::optional activeElems = + getConstantIndexValue(groupMask.getActiveElemsPerGroup()); + if (activeElems && *activeElems >= 0 && + *activeElems < groupSize) { + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addf s32 block8 lowering does " + "not yet support partial create_group_mask " + "active_elems_per_group during layout assignment"; + return WalkResult::interrupt(); + } + } + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + 
if (failed(setNaturalLayout( + reduce.getResult(), + getPreferredGroupSlotsLayout( + resultType, reduce.getNumGroupsAttr().getInt()), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto broadcast = dyn_cast(op)) { + auto sourceType = cast(broadcast.getSource().getType()); requestDataUse(broadcast.getSourceMutable(), - getGroupSlotsLayout( - broadcast.getNumGroupsAttr().getInt())); + getPreferredGroupSlotsLayout( + sourceType, broadcast.getNumGroupsAttr().getInt())); return WalkResult::advance(); } if (auto extf = dyn_cast(op)) { @@ -581,6 +899,14 @@ struct LayoutSolver { auto resultType = cast(truncf.getResult().getType()); unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); + if (sourceBits == 32 && resultBits == 16 && sourceLayout && + sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + requestDataUse(truncf.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (sourceBits == 32 && resultBits == 16) requestDataUse(truncf.getSourceMutable(), VMILayoutAttr::getDeinterleaved(ctx, 2)); @@ -599,8 +925,8 @@ struct LayoutSolver { } if (auto load = dyn_cast(op)) { requestDataUse(load.getPassthruMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -608,27 +934,37 @@ struct LayoutSolver { auto resultType = cast(gather.getResult().getType()); requestDataUse(gather.getIndicesMutable(), getContiguousLayout()); requestDataUse(gather.getPassthruMutable(), getContiguousLayout()); - if (failed(requestMaskUse(gather.getMaskMutable(), - getContiguousLayout(), - 
getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + gather.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(gather.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(gather.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { requestDataUse(load.getPassthruMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed(validateGroupLoadLayoutPlan(load))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + load.getResult(), getPreferredGroupLoadResultLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(setNaturalLayout( + load.getResult(), + getPreferredGroupSlotLoadLayout( + resultType, load.getNumGroupsAttr().getInt()), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -637,17 +973,18 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto store = dyn_cast(op)) { - requestDataUse(store.getValueMutable(), getContiguousLayout()); + requestDataUse( + store.getValueMutable(), + getPreferredGroupStoreUseLayout(store.getValue(), + store.getNumGroupsAttr().getInt())); return WalkResult::advance(); } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); - if (failed(requestMaskUse(store.getMaskMutable(), - 
getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -655,22 +992,18 @@ struct LayoutSolver { auto valueType = cast(scatter.getValue().getType()); requestDataUse(scatter.getValueMutable(), getContiguousLayout()); requestDataUse(scatter.getIndicesMutable(), getContiguousLayout()); - if (failed(requestMaskUse(scatter.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + scatter.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); - if (failed(requestMaskUse(store.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -680,16 +1013,14 @@ struct LayoutSolver { } if (auto split = dyn_cast(op)) { int64_t channels = split.getNumResults(); - VMICapabilityResult capability = - capabilities.supportsChannelCount("pto.vmi.channel_split", - channels); + VMICapabilityResult capability = capabilities.supportsChannelCount( + "pto.vmi.channel_split", channels); if (!capability.isSupported()) { split.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; return WalkResult::interrupt(); } - requestDataUse( - split.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, channels)); + 
requestDataUse(split.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, channels)); for (Value result : split.getResults()) if (failed(setNaturalLayout(result, getContiguousLayout(), op))) return WalkResult::interrupt(); @@ -697,9 +1028,8 @@ struct LayoutSolver { } if (auto merge = dyn_cast(op)) { int64_t channels = merge.getInputs().size(); - VMICapabilityResult capability = - capabilities.supportsChannelCount("pto.vmi.channel_merge", - channels); + VMICapabilityResult capability = capabilities.supportsChannelCount( + "pto.vmi.channel_merge", channels); if (!capability.isSupported()) { merge.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; return WalkResult::interrupt(); @@ -771,9 +1101,8 @@ struct LayoutSolver { if (failed(addBranchConstraints(switchOp.getDefaultDestination(), switchOp.getDefaultOperands(), op))) return WalkResult::interrupt(); - for (auto [dest, operands] : - llvm::zip(switchOp.getCaseDestinations(), - switchOp.getCaseOperands())) { + for (auto [dest, operands] : llvm::zip(switchOp.getCaseDestinations(), + switchOp.getCaseOperands())) { if (failed(addBranchConstraints(dest, operands, op))) return WalkResult::interrupt(); } @@ -825,8 +1154,7 @@ struct LayoutSolver { for (Region *region : {&ifOp.getThenRegion(), &ifOp.getElseRegion()}) { if (region->empty()) continue; - auto yieldOp = - dyn_cast(region->front().getTerminator()); + auto yieldOp = dyn_cast(region->front().getTerminator()); if (!yieldOp || resultNo >= yieldOp.getNumOperands()) continue; if (failed(uniteEquivalentValues(result, yieldOp.getOperand(resultNo), @@ -852,8 +1180,8 @@ struct LayoutSolver { WalkResult result = executeOp.getRegion().walk([&](scf::YieldOp yieldOp) { if (yieldOp->getParentOp() != executeOp.getOperation()) return WalkResult::advance(); - if (failed(addYieldConstraints(executeOp->getResults(), yieldOp, - executeOp))) + if (failed( + addYieldConstraints(executeOp->getResults(), yieldOp, executeOp))) return WalkResult::interrupt(); return 
WalkResult::advance(); }); @@ -903,8 +1231,8 @@ struct LayoutSolver { whileOp))) return failure(); if (index < whileOp.getNumResults() && - failed(uniteEquivalentValues(anchor, whileOp.getResult(index), - whileOp))) + failed( + uniteEquivalentValues(anchor, whileOp.getResult(index), whileOp))) return failure(); } return success(); @@ -927,8 +1255,8 @@ struct LayoutSolver { failed(uniteEquivalentValues(anchor, results[index], forOp))) return failure(); if (yieldOp && index < yieldOp.getNumOperands() && - failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), - forOp))) + failed( + uniteEquivalentValues(anchor, yieldOp.getOperand(index), forOp))) return failure(); } return success(); @@ -963,7 +1291,8 @@ struct LayoutSolver { for (auto [index, operand] : llvm::enumerate(returnOp.getOperands())) { if (index >= firstOperands.size()) break; - if (failed(uniteEquivalentValues(firstOperands[index], operand, returnOp))) + if (failed( + uniteEquivalentValues(firstOperands[index], operand, returnOp))) return failure(); } return success(); @@ -1020,13 +1349,12 @@ struct LayoutSolver { } std::optional rematerializeDataUse(Value value, VMIVRegType resultType, - Location loc, - OpBuilder &builder) { + Location loc, OpBuilder &builder) { if (auto constant = value.getDefiningOp()) { auto denseAttr = dyn_cast(constant.getValue()); if (denseAttr && denseAttr.isSplat()) - return builder.create(loc, resultType, - constant.getValue()) + return builder + .create(loc, resultType, constant.getValue()) .getResult(); } if (auto broadcast = value.getDefiningOp()) @@ -1034,8 +1362,9 @@ struct LayoutSolver { .create(loc, resultType, broadcast.getValue()) .getResult(); if (auto iota = value.getDefiningOp()) - return builder.create(loc, resultType, iota.getBase(), - iota.getOrderAttr()) + return builder + .create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) .getResult(); return std::nullopt; } @@ -1060,9 +1389,8 @@ struct LayoutSolver { VMIVRegType::get(ctx, 
sourceType.getElementCount(), sourceType.getElementType(), request.layout); builder.setInsertionPoint(request.operand->getOwner()); - std::optional rematerialized = - rematerializeDataUse(value, resultType, - request.operand->getOwner()->getLoc(), builder); + std::optional rematerialized = rematerializeDataUse( + value, resultType, request.operand->getOwner()->getLoc(), builder); if (rematerialized) { request.operand->set(*rematerialized); continue; @@ -1078,120 +1406,97 @@ struct LayoutSolver { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto cmpf = dyn_cast(op)) { auto lhsType = cast(cmpf.getLhs().getType()); - if (failed(requestMask(cmpf.getResult(), lhsType.getLayoutAttr(), - getMaskGranularityForElement( - lhsType.getElementType()), - op))) + if (failed(requestMask( + cmpf.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto cmpi = dyn_cast(op)) { auto lhsType = cast(cmpi.getLhs().getType()); - if (failed(requestMask(cmpi.getResult(), lhsType.getLayoutAttr(), - getMaskGranularityForElement( - lhsType.getElementType()), - op))) + if (failed(requestMask( + cmpi.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto select = dyn_cast(op)) { auto resultType = cast(select.getResult().getType()); - if (failed(requestMaskUse(select.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + select.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto activePrefix = dyn_cast(op)) { - auto resultType = - cast(activePrefix.getResult().getType()); - if 
(failed(requestMaskUse(activePrefix.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + auto resultType = cast(activePrefix.getResult().getType()); + if (failed(requestMaskUse( + activePrefix.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto compress = dyn_cast(op)) { auto resultType = cast(compress.getResult().getType()); - if (failed(requestMaskUse(compress.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + compress.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - 
getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); - if (failed(requestMaskUse(load.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); - if (failed(requestMaskUse(load.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + 
getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -1203,8 +1508,8 @@ struct LayoutSolver { void rewriteMaskTypes() { for (MaskNode &node : maskNodes) { MaskNode &root = maskNodes[findMask(maskIds.lookup(node.value))]; - VMILayoutAttr layout = root.requestedLayout ? root.requestedLayout - : getContiguousLayout(); + VMILayoutAttr layout = + root.requestedLayout ? root.requestedLayout : getContiguousLayout(); StringRef granularity = root.requestedGranularity.empty() ? StringRef("b32") : StringRef(root.requestedGranularity); @@ -1214,11 +1519,17 @@ struct LayoutSolver { } std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, - Location loc, - OpBuilder &builder) { + Location loc, OpBuilder &builder) { if (auto createMask = value.getDefiningOp()) - return builder.create(loc, resultType, - createMask.getActiveLanes()) + return builder + .create(loc, resultType, createMask.getActiveLanes()) + .getResult(); + if (auto createGroupMask = value.getDefiningOp()) + return builder + .create( + loc, resultType, createGroupMask.getActiveElemsPerGroup(), + createGroupMask.getNumGroupsAttr(), + createGroupMask.getGroupSizeAttr()) .getResult(); if (auto constantMask = value.getDefiningOp()) return builder @@ -1245,9 +1556,9 @@ struct LayoutSolver { builder.setInsertionPoint(request.operand->getOwner()); Value current = value; VMIMaskType currentType = sourceType; - auto requestedType = VMIMaskType::get(ctx, sourceType.getElementCount(), - request.granularity, - request.layout); + auto requestedType = + VMIMaskType::get(ctx, sourceType.getElementCount(), + request.granularity, request.layout); if (sourceType != requestedType) { std::optional rematerialized = rematerializeMaskUse( value, requestedType, request.operand->getOwner()->getLoc(), @@ -1259,9 +1570,9 @@ struct LayoutSolver { } if (sourceLayout != request.layout) { - auto layoutType = VMIMaskType::get(ctx, 
currentType.getElementCount(), - currentType.getGranularity(), - request.layout); + auto layoutType = + VMIMaskType::get(ctx, currentType.getElementCount(), + currentType.getGranularity(), request.layout); auto ensureLayout = builder.create( request.operand->getOwner()->getLoc(), layoutType, current); current = ensureLayout.getResult(); @@ -1272,10 +1583,8 @@ struct LayoutSolver { auto granularityType = VMIMaskType::get(ctx, currentType.getElementCount(), request.granularity, request.layout); - auto ensureGranularity = - builder.create( - request.operand->getOwner()->getLoc(), granularityType, - current); + auto ensureGranularity = builder.create( + request.operand->getOwner()->getLoc(), granularityType, current); current = ensureGranularity.getResult(); } @@ -1285,6 +1594,143 @@ struct LayoutSolver { return success(); } + std::optional getGroupReduceSelectedPlan(VMIGroupReduceAddFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return std::nullopt; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return std::nullopt; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || sourceType.getElementCount() % numGroups != 0) + return std::nullopt; + int64_t groupSize = sourceType.getElementCount() / numGroups; + + if (sourceLayout.isContiguous()) { + if (groupSize == 8) + return StringRef("s8_reduce_contiguous"); + if (groupSize == 64) + return StringRef("s64_reduce_row_local"); + return std::nullopt; + } + + if (!sourceLayout.isDeinterleaved()) + return std::nullopt; + + if (groupSize == 16 && sourceLayout.getFactor() == 2) { + if (sourceLayout.getBlockElems() == 1) + return StringRef("s16_reduce_parity"); + if (sourceLayout.getBlockElems() == 8) + return StringRef("s16_reduce_block8"); + } + + if (groupSize == 32 && sourceLayout.getFactor() == 4) { + if (sourceLayout.getBlockElems() == 1) + return StringRef("s32_reduce_dintlv4"); + if (sourceLayout.getBlockElems() == 8) + 
return StringRef("s32_reduce_block8_stride"); + } + + return std::nullopt; + } + + std::optional getGroupSlotLoadSelectedPlan(VMIGroupSlotLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return std::nullopt; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return std::nullopt; + if (layout.getSlots() == 8) + return StringRef("group_slot_load_slots8_unit_stride"); + if (layout.getSlots() == 1) + return StringRef("group_slot_load_slots1_row_local"); + return std::nullopt; + } + + std::optional getGroupLoadSelectedPlan(VMIGroupLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return std::nullopt; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return std::nullopt; + if (layout.isContiguous()) + return StringRef("group_load_contiguous_chunks"); + if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) + return std::nullopt; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || resultType.getElementCount() % numGroups != 0) + return std::nullopt; + int64_t groupSize = resultType.getElementCount() / numGroups; + if (groupSize == 16 && layout.getFactor() == 2) + return StringRef("s16_group_load_block8_stride"); + if (groupSize == 32 && layout.getFactor() == 4) + return StringRef("s32_group_load_block8_stride"); + return std::nullopt; + } + + std::optional + getGroupBroadcastSelectedPlan(VMIGroupBroadcastOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return std::nullopt; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || + 
resultLayout.isGroupSlots()) + return std::nullopt; + if (sourceLayout.getSlots() == 8) + return StringRef("group_broadcast_slots8_vselr"); + if (sourceLayout.getSlots() == 1) + return StringRef("group_broadcast_slots1_vselr"); + return std::nullopt; + } + + std::optional getTruncFSelectedPlan(VMITruncFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return std::nullopt; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || sourceLayout != resultLayout || + !sourceLayout.isGroupSlots() || sourceLayout.getSlots() != 1) + return std::nullopt; + + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 32 && resultBits == 16) + return StringRef("group_slot_cast_slots1_f32_to_f16"); + return std::nullopt; + } + + void attachSelectedPlanAttrs() { + Builder builder(ctx); + module.walk([&](Operation *op) { + std::optional plan; + if (auto reduce = dyn_cast(op)) + plan = getGroupReduceSelectedPlan(reduce); + else if (auto load = dyn_cast(op)) + plan = getGroupLoadSelectedPlan(load); + else if (auto load = dyn_cast(op)) + plan = getGroupSlotLoadSelectedPlan(load); + else if (auto broadcast = dyn_cast(op)) + plan = getGroupBroadcastSelectedPlan(broadcast); + else if (auto truncf = dyn_cast(op)) + plan = getTruncFSelectedPlan(truncf); + + if (plan) + op->setAttr(kVMISelectedPlanAttrName, builder.getStringAttr(*plan)); + }); + } + void rewriteFunctionType() { module.walk([&](func::FuncOp func) { if (func.empty()) @@ -1320,6 +1766,8 @@ struct LayoutSolver { } LogicalResult run() { + if (failed(commuteTruncFAfterGroupBroadcast())) + return failure(); if (failed(collect())) return failure(); if (failed(addConstraints())) @@ -1329,6 +1777,7 @@ struct LayoutSolver { 
rewriteDataTypes(); if (failed(insertDataUseMaterializations())) return failure(); + attachSelectedPlanAttrs(); if (failed(inferMaskRequests())) return failure(); rewriteMaskTypes(); @@ -1351,8 +1800,7 @@ struct LayoutSolver { }; struct VMILayoutAssignmentPass - : public mlir::pto::impl::VMILayoutAssignmentBase< - VMILayoutAssignmentPass> { + : public mlir::pto::impl::VMILayoutAssignmentBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutAssignmentPass) void runOnOperation() override { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index cf91af1142..95141bada7 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
//===- VMIToVPTO.cpp - Convert VMI to physical VPTO IR -------------------===// //===----------------------------------------------------------------------===// @@ -48,6 +50,8 @@ using namespace mlir::pto; namespace { +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + bool isVMIType(Type type) { return isa(type); } bool containsVMIType(Type type) { @@ -87,9 +91,8 @@ bool hasVMIType(Attribute attr) { return true; if (auto arrayAttr = dyn_cast(attr)) - return llvm::any_of(arrayAttr, [](Attribute element) { - return hasVMIType(element); - }); + return llvm::any_of(arrayAttr, + [](Attribute element) { return hasVMIType(element); }); if (auto dictAttr = dyn_cast(attr)) return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { @@ -130,9 +133,8 @@ bool isLayoutAssignedVMIType(Type type) { LogicalResult verifyLayoutAssignedVMITypeTree(Operation *op, Type type) { if (!isLayoutAssignedVMIType(type)) - return op->emitError() - << kVMIDiagPassInvariantPrefix - << "vmi-to-vpto requires layout-assigned VMI types"; + return op->emitError() << kVMIDiagPassInvariantPrefix + << "vmi-to-vpto requires layout-assigned VMI types"; if (auto functionType = dyn_cast(type)) { for (Type input : functionType.getInputs()) @@ -233,28 +235,28 @@ class VMIToVPTOTypeConverter final : public OneToNTypeConverter { public: VMIToVPTOTypeConverter() { addConversion([](Type type) { return type; }); - addConversion([](VMIVRegType type, SmallVectorImpl &results) - -> LogicalResult { - FailureOr arity = getVMIPhysicalArity(type); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (failed(arity) || failed(lanesPerPart)) - return failure(); - for (int64_t i = 0; i < *arity; ++i) - results.push_back(VRegType::get(type.getContext(), *lanesPerPart, - type.getElementType())); - return success(); - }); - addConversion([](VMIMaskType type, SmallVectorImpl &results) - -> LogicalResult { - FailureOr arity = getVMIPhysicalArity(type); - if (failed(arity)) - 
return failure(); - for (int64_t i = 0; i < *arity; ++i) - results.push_back(MaskType::get(type.getContext(), - type.getGranularity())); - return success(); - }); + addConversion( + [](VMIVRegType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(arity) || failed(lanesPerPart)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(VRegType::get(type.getContext(), *lanesPerPart, + type.getElementType())); + return success(); + }); + addConversion( + [](VMIMaskType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back( + MaskType::get(type.getContext(), type.getGranularity())); + return success(); + }); TypeConverter::addSourceMaterialization(materializeVPTOToVMI); TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); @@ -284,8 +286,7 @@ FailureOr createAllTrueMaskForVReg(Location loc, VRegType vregType, return failure(); } -FailureOr getMaskTypeForVReg(VRegType vregType, - MLIRContext *ctx) { +FailureOr getMaskTypeForVReg(VRegType vregType, MLIRContext *ctx) { unsigned elementBits = pto::getPTOStorageElemBitWidth(vregType.getElementType()); if (elementBits == 8) @@ -341,10 +342,12 @@ FailureOr createPrefixMask(Location loc, MaskType maskType, return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) .getResult(); if (maskType.isB16()) - return rewriter.create(loc, MaskType::get(ctx, "b16"), patternAttr) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) .getResult(); if (maskType.isB32()) - return rewriter.create(loc, MaskType::get(ctx, "b32"), patternAttr) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) .getResult(); return failure(); 
} @@ -372,18 +375,17 @@ createRuntimePrefixMask(Location loc, MaskType maskType, Value activeLanes, return failure(); } -LogicalResult checkSupportedMaskableVReg( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - std::string *reason = nullptr) { +LogicalResult +checkSupportedMaskableVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMICapabilityResult elementCapability = - capabilities.supportsElementType(type.getElementType(), - VMIElementPurpose::PredicateMask); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + type.getElementType(), VMIElementPurpose::PredicateMask); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -395,10 +397,11 @@ LogicalResult checkSupportedMaskableVReg( return success(); } -LogicalResult checkSupportedTargetElementVReg( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - VMIElementPurpose purpose, StringRef elementContract, - std::string *reason = nullptr) { +LogicalResult +checkSupportedTargetElementVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, VMIElementPurpose purpose, + StringRef elementContract, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -416,19 +419,46 @@ LogicalResult checkSupportedTargetElementVReg( return success(); } -Value createI32Constant(Location loc, int64_t value, PatternRewriter &rewriter) { +Value createI32Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { return rewriter.create(loc, value, 32); } +FailureOr createPrefixMaskForActiveLanes(Location loc, MaskType maskType, + int64_t activeLanes, + PatternRewriter &rewriter) { + if (activeLanes <= 0) + return createPrefixMask(loc, maskType, "PAT_ALLF", rewriter); + + switch 
(activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return createPrefixMask( + loc, maskType, (Twine("PAT_VL") + Twine(activeLanes)).str(), rewriter); + default: { + FailureOr> dynamicMask = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, rewriter), rewriter); + if (failed(dynamicMask)) + return failure(); + return dynamicMask->first; + } + } +} + Value clampDynamicActiveLanes(Location loc, Value activeLanes, int64_t maxActiveLanes, PatternRewriter &rewriter) { - Value activeI32 = - rewriter.create(loc, rewriter.getI32Type(), - activeLanes); + Value activeI32 = rewriter.create( + loc, rewriter.getI32Type(), activeLanes); Value zeroI32 = createI32Constant(loc, 0, rewriter); - Value nonNegative = - rewriter.create(loc, activeI32, zeroI32); + Value nonNegative = rewriter.create(loc, activeI32, zeroI32); Value maxI32 = createI32Constant(loc, maxActiveLanes, rewriter); return rewriter.create(loc, nonNegative, maxI32); } @@ -441,9 +471,8 @@ Value createPartitionActiveLanes(Location loc, Value activeLanesI32, int64_t bias = factor - 1 - part; Value biased = activeLanesI32; if (bias != 0) - biased = - rewriter.create(loc, biased, - createI32Constant(loc, bias, rewriter)); + biased = rewriter.create( + loc, biased, createI32Constant(loc, bias, rewriter)); return rewriter.create( loc, biased, createI32Constant(loc, factor, rewriter)); } @@ -643,8 +672,7 @@ LogicalResult checkSupportedLayoutMaterialization( }; VMICapabilityResult layoutCapability = - capabilities.supportsLayoutConversion(sourceLayout, resultLayout, - Type{}); + capabilities.supportsLayoutConversion(sourceLayout, resultLayout, Type{}); if (!layoutCapability.isSupported()) return fail(layoutCapability.reason); @@ -682,10 +710,10 @@ LogicalResult checkSupportedLayoutMaterialization( return success(); if (failed(sourceFull)) - return fail(Twine("source ") + sourceReason + - "; source materialization " + 
sourceMaterializationReason); - return fail(Twine("result ") + resultReason + - "; result materialization " + resultMaterializationReason); + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); } FailureOr getContiguousMaterializationPartCount(Type type, @@ -861,8 +889,7 @@ buildContiguousIdentityLaneAddressMap(int64_t constantOffset, return map; } -VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, - StringRef role, +VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, Value memoryValue = {}) { auto memrefType = dyn_cast(memoryType); if (!memrefType || memrefType.getLayout().isIdentity()) @@ -878,9 +905,10 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, return VMICapabilityResult::missingCapability(reason); } -VMIMemorySafeReadProof -computeSafeFullReadProof(Type sourceType, std::optional constantOffset, - VMIVRegType resultType) { +VMIMemorySafeReadProof computeSafeFullReadProof( + Type sourceType, std::optional constantOffset, + VMIVRegType resultType, + std::optional explicitFullReadElems = std::nullopt) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -893,9 +921,14 @@ computeSafeFullReadProof(Type sourceType, std::optional constantOffset, if (!constantOffset) return fail("requires constant index offset"); - FailureOr elements = getStaticMemRefElementCount(sourceType); - if (failed(elements)) - return fail("requires statically shaped memref source"); + std::optional elements = explicitFullReadElems; + if (!elements) { + FailureOr staticElements = getStaticMemRefElementCount(sourceType); + if (failed(staticElements)) + return fail("requires statically shaped memref source or explicit " + "full_read_elems"); + elements = *staticElements; + } proof.staticElementCount = *elements; if (*constantOffset < 0) @@ -920,11 
+953,11 @@ computeSafeFullReadProof(Type sourceType, std::optional constantOffset, return proof; } -VMIMemoryAccessPlan -buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, - Value source, Type sourceType, VMIVRegType resultType, - std::optional constantOffset, - VMIMemoryValidMaskKind validMask) { +VMIMemoryAccessPlan buildReadAccessPlan( + const VMITargetCapabilityRegistry &capabilities, Value source, + Type sourceType, VMIVRegType resultType, + std::optional constantOffset, VMIMemoryValidMaskKind validMask, + std::optional explicitFullReadElems = std::nullopt) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -933,19 +966,19 @@ buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, plan.validMask = validMask; plan.permutation = VMIMemoryPermutationKind::Identity; plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; - plan.safeReadProof = - computeSafeFullReadProof(sourceType, constantOffset, resultType); + plan.safeReadProof = computeSafeFullReadProof( + sourceType, constantOffset, resultType, explicitFullReadElems); plan.laneAddressMap = plan.safeReadProof.laneAddressMap; - plan.targetCapability = capabilities.supportsDirectMemory(sourceType, - "source"); + plan.targetCapability = + capabilities.supportsDirectMemory(sourceType, "source"); if (plan.targetCapability.isSupported()) plan.targetCapability = requireIdentityMemRefLayout(sourceType, "source", source); if (validMask == VMIMemoryValidMaskKind::ExplicitMask) plan.trueMaskedLoadCapability = capabilities.supportsTrueMaskedLoad(sourceType, resultType, Type{}); - plan.scratchFallbackCapability = - capabilities.supportsFallbackResource(VMIFallbackResourceKind::ScratchMemory); + plan.scratchFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::ScratchMemory); plan.guardedFallbackCapability = capabilities.supportsFallbackResource( VMIFallbackResourceKind::GuardedControlFlow); return plan; @@ -954,8 +987,7 @@ 
buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, VMIMemoryAccessPlan buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, Value destination, Type destinationType, - VMIVRegType valueType, - VMIMemoryWriteMaskKind writeMask) { + VMIVRegType valueType, VMIMemoryWriteMaskKind writeMask) { VMIMemoryAccessPlan plan; plan.baseType = destinationType; plan.valueType = valueType; @@ -966,9 +998,8 @@ buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, plan.targetCapability = capabilities.supportsDirectMemory(destinationType, "destination"); if (plan.targetCapability.isSupported()) - plan.targetCapability = - requireIdentityMemRefLayout(destinationType, "destination", - destination); + plan.targetCapability = requireIdentityMemRefLayout( + destinationType, "destination", destination); return plan; } @@ -991,18 +1022,18 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { maskedLoadReason + scratchReason + guardedReason); } -FailureOr -verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, - Type sourceType, Value offset, - PatternRewriter &rewriter) { +FailureOr verifyFullOrSafeReadVRegChunks( + Operation *op, VMIVRegType type, Type sourceType, Value offset, + PatternRewriter &rewriter, + std::optional explicitFullReadElems = std::nullopt) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); if (succeeded(lanesPerPart)) return *lanesPerPart; - VMIMemorySafeReadProof safeReadProof = - computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); + VMIMemorySafeReadProof safeReadProof = computeSafeFullReadProof( + sourceType, getConstantIndexValue(offset), type, explicitFullReadElems); if (safeReadProof.proven) { lanesPerPart = getDataLanesPerPart(type.getElementType()); if (succeeded(lanesPerPart)) @@ -1018,16 +1049,16 @@ verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, LogicalResult checkSupportedLoadShape( const 
VMITargetCapabilityRegistry &capabilities, VMIVRegType type, Value source, Type sourceType, std::optional constantOffset, - std::string *reason) { + std::optional explicitFullReadElems, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMIMemoryAccessPlan accessPlan = - buildReadAccessPlan(capabilities, source, sourceType, type, - constantOffset, VMIMemoryValidMaskKind::AllTrue); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, source, sourceType, type, constantOffset, + VMIMemoryValidMaskKind::AllTrue, explicitFullReadElems); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -1039,14 +1070,14 @@ LogicalResult checkSupportedLoadShape( return success(); requireUnavailableReadFallback(accessPlan); return fail(Twine(fullChunkReason) + - "; safe-read proof failed: " + - accessPlan.safeReadProof.reason + + "; safe-read proof failed: " + accessPlan.safeReadProof.reason + "; fallback decision: " + accessPlan.fallbackDecision.reason); } -LogicalResult checkSupportedStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - Value destination, Type destinationType, std::string *reason) { +LogicalResult +checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value destination, + Type destinationType, std::string *reason) { VMIMemoryAccessPlan accessPlan = buildWriteAccessPlan(capabilities, destination, destinationType, type, VMIMemoryWriteMaskKind::AllTrue); @@ -1083,8 +1114,7 @@ LogicalResult checkSupportedStoreShape( return fail(Twine("partial/tail store requires contiguous layout or " "deinterleaved layout that can materialize to contiguous; " "value ") + - fullChunkReason + ", materialization " + - materializationReason); + fullChunkReason + ", materialization " + materializationReason); } FailureOr getGroupSizeFromNumGroups(VMIVRegType type, @@ 
-1102,8 +1132,7 @@ FailureOr getGroupSizeFromNumGroups(VMIVRegType type, return type.getElementCount() / numGroups; } -LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, - int64_t groupSize, +LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, int64_t groupSize, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -1129,29 +1158,211 @@ LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, return success(); } -LogicalResult checkSupportedGroupLoadShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, - std::string *reason) { +LogicalResult +checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto resultType = cast(op.getResult().getType()); - FailureOr groupSize = - getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), - reason); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); - if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), - op.getSource().getType(), - std::nullopt, reason))) - return failure(); - return checkSupportedGroupChunkShape(resultType, *groupSize, reason); + + if (resultLayout.isContiguous()) { + StringRef expectedPlan = "group_load_contiguous_chunks"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + 
+ "'"); + if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), + op.getSource().getType(), std::nullopt, + std::nullopt, reason))) + return failure(); + return checkSupportedGroupChunkShape(resultType, *groupSize, reason); + } + + if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && + resultType.getElementType().isF32()) { + StringRef expectedPlan; + if (*groupSize == 16 && resultLayout.getFactor() == 2) + expectedPlan = "s16_group_load_block8_stride"; + else if (*groupSize == 32 && resultLayout.getFactor() == 4) + expectedPlan = "s32_group_load_block8_stride"; + else + return fail("block8 strided group_load requires S=16/factor=2 or " + "S=32/factor=4"); + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + if (!isa(op.getSource().getType())) + return fail("block8 strided group_load requires !pto.ptr source"); + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return fail("block8 strided group_load requires num_groups multiple " + "of 8"); + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return fail("block8 strided group_load requires constant positive " + "row_stride divisible by 8 f32 elements"); + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("block8 strided group_load requires full physical " + "result chunks; ") + + fullChunkReason); + return success(); + } + + return fail("requires contiguous layout or deinterleaved block8 f32 layout"); } -LogicalResult checkSupportedGroupStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, +LogicalResult checkSupportedGroupSlotLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { + auto fail = [&](const 
Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt() || + layout.getSlots() <= 0) + return fail("requires explicit group_slots result layout matching " + "num_groups"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + + StringRef expectedPlan; + if (layout.getSlots() == 8) + expectedPlan = "group_slot_load_slots8_unit_stride"; + else if (layout.getSlots() == 1) + expectedPlan = "group_slot_load_slots1_row_local"; + else + return fail("supports only slots=8 or slots=1 group_slot_load layouts"); + + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for vsldb lowering"); + if (layout.getSlots() == 8) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return fail("slots=8 group_slot_load requires constant unit " + "source_group_stride"); + return success(); + } + if (layout.getSlots() == 1) { + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_slot_load requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride <= 0 || 
*stride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_slot_load currently lowers as one " + "lane-0 vsldb per group and requires constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B load alignment; packed or unaligned " + "scalar load lowering is not implemented"); + return success(); + } + llvm_unreachable("unsupported group_slot_load slots should be rejected"); +} + +LogicalResult +checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto valueType = cast(op.getValue().getType()); - FailureOr groupSize = - getGroupSizeFromNumGroups(valueType, op.getNumGroupsAttr().getInt(), - reason); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (layout && layout.isGroupSlots()) { + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMIMemoryAccessPlan accessPlan = buildWriteAccessPlan( + capabilities, op.getDestination(), op.getDestination().getType(), + valueType, VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + if (failed(checkSupportedMaskableVReg(capabilities, valueType, reason))) + return failure(); + + FailureOr arity = getVMIPhysicalArity(valueType); + if (failed(arity)) + return fail("requires computable physical arity"); + if (layout.getSlots() == 1) { + if (*arity != numGroups) + return fail("slots=1 group_store requires one physical part per " + "group"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_store requires an 8/16/32-bit element " + 
"type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional rowStride = + getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_store currently lowers as one " + "lane-0 vsts per group and requires constant " + "positive row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B store alignment; packed or unaligned " + "contiguous store lowering is not implemented"); + return success(); + } + if (layout.getSlots() == 8) { + std::optional rowStride = + getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride != 1) + return fail("slots=8 group_store currently requires constant unit " + "row_stride"); + if (*arity != ceilDivNonNegative(numGroups, 8)) + return fail("slots=8 group_store arity must equal ceil(num_groups / " + "8)"); + return success(); + } + return fail("group_slots group_store currently supports only slots=1 or " + "unit-stride slots=8"); + } + + FailureOr groupSize = getGroupSizeFromNumGroups( + valueType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (failed(checkSupportedStoreShape(capabilities, valueType, @@ -1202,9 +1413,9 @@ checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, "; fallback decision: " + accessPlan.fallbackDecision.reason); } -LogicalResult checkSupportedGatherShape( - const VMITargetCapabilityRegistry &capabilities, VMIGatherOp op, - std::string *reason) { +LogicalResult +checkSupportedGatherShape(const VMITargetCapabilityRegistry &capabilities, + VMIGatherOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1260,8 +1471,7 @@ LogicalResult checkSupportedGatherShape( std::string passthruReason; std::string maskReason; if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) - return fail(Twine("result requires full physical chunks; ") 
+ - resultReason); + return fail(Twine("result requires full physical chunks; ") + resultReason); if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) return fail(Twine("indices require full physical chunks; ") + indicesReason); @@ -1274,9 +1484,9 @@ LogicalResult checkSupportedGatherShape( return success(); } -LogicalResult checkSupportedScatterShape( - const VMITargetCapabilityRegistry &capabilities, VMIScatterOp op, - std::string *reason) { +LogicalResult +checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, + VMIScatterOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1300,9 +1510,9 @@ LogicalResult checkSupportedScatterShape( return fail("requires contiguous value, indices, and mask layouts"); VMICapabilityResult destinationCapability = - capabilities.supportsUBPointerMemory( - op.getDestination().getType(), "destination", "pto.vscatter", - "pto.vscatter writes only UB"); + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vscatter", + "pto.vscatter writes only UB"); if (!destinationCapability.isSupported()) return fail(destinationCapability.reason); @@ -1421,9 +1631,8 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, return fail("requires contiguous result, passthru, and mask layouts"); std::string maskReason; - bool staticAllActive = - isStaticAllActiveMask(op.getMask(), resultType.getElementCount(), - &maskReason); + bool staticAllActive = isStaticAllActiveMask( + op.getMask(), resultType.getElementCount(), &maskReason); std::string fullChunkReason; if (staticAllActive && @@ -1435,16 +1644,16 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, std::string allActivePathReason; if (!staticAllActive) { - allActivePathReason = maskReason.empty() ? "requires static all-active mask" - : maskReason; + allActivePathReason = + maskReason.empty() ? 
"requires static all-active mask" : maskReason; } else { requireUnavailableReadFallback(accessPlan); allActivePathReason = (Twine("requires full physical chunks or statically safe full-read " "footprint; value ") + fullChunkReason + ", safe-read proof " + - accessPlan.safeReadProof.reason + "; fallback decision: " + - accessPlan.fallbackDecision.reason) + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason) .str(); } @@ -1487,13 +1696,14 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, return success(); } -LogicalResult checkSupportedMaskedStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType valueType, - VMIMaskType maskType, Value destination, Type destinationType, - std::string *reason) { +LogicalResult +checkSupportedMaskedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType valueType, VMIMaskType maskType, + Value destination, Type destinationType, + std::string *reason) { VMIMemoryAccessPlan accessPlan = - buildWriteAccessPlan(capabilities, destination, destinationType, valueType, - VMIMemoryWriteMaskKind::ExplicitMask); + buildWriteAccessPlan(capabilities, destination, destinationType, + valueType, VMIMemoryWriteMaskKind::ExplicitMask); if (!accessPlan.targetCapability.isSupported()) { if (reason) *reason = accessPlan.targetCapability.reason; @@ -1535,10 +1745,10 @@ LogicalResult checkSupportedMaskedStoreShape( maskType, &maskMaterializationReason); if (failed(maskParts)) return fail(Twine("mask cannot materialize to contiguous; mask ") + - maskReason + ", materialization " + - maskMaterializationReason); + maskReason + ", materialization " + maskMaterializationReason); if (*valueParts != *maskParts) - return fail("requires value/mask contiguous materialization arity to match"); + return fail( + "requires value/mask contiguous materialization arity to match"); return success(); } @@ -1561,8 +1771,7 @@ FailureOr 
createContiguousStoreMask(Location loc, VMIVRegType vmiType, if (failed(lanesPerPart)) return failure(); - FailureOr activeLanes = - getContiguousActiveDataLanes(vmiType, chunk); + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); if (failed(activeLanes)) return failure(); if (*activeLanes == *lanesPerPart) @@ -1572,10 +1781,8 @@ FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, getMaskTypeForVReg(vregType, rewriter.getContext()); if (failed(maskType)) return failure(); - FailureOr> maskAndRemaining = - createRuntimePrefixMask(loc, *maskType, - createI32Constant(loc, *activeLanes, rewriter), - rewriter); + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, *maskType, createI32Constant(loc, *activeLanes, rewriter), rewriter); if (failed(maskAndRemaining)) return failure(); return maskAndRemaining->first; @@ -1590,8 +1797,7 @@ FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, if (failed(lanesPerPart)) return failure(); - FailureOr activeLanes = - getContiguousActiveDataLanes(vmiType, chunk); + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); if (failed(activeLanes)) return failure(); if (*activeLanes == *lanesPerPart) @@ -1651,8 +1857,7 @@ computeShuffleForwardingSourceParts(VMIShuffleOp op, std::string *reason) { continue; FailureOr resultLogicalLane = - mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, - lane); + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); if (failed(resultLogicalLane) || *resultLogicalLane >= static_cast(indices.size())) return fail("failed to map result lane"); @@ -1694,7 +1899,7 @@ struct ShuffleVselrPlan { }; FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, - std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1721,7 +1926,8 @@ FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, FailureOr> 
computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> FailureOr> { + auto fail = + [&](const Twine &message) -> FailureOr> { if (reason) *reason = message.str(); return failure(); @@ -1761,8 +1967,7 @@ computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { return fail("requires full physical result chunks"); FailureOr resultLogicalLane = - mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, - lane); + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); if (failed(resultLogicalLane) || *resultLogicalLane >= static_cast(indices.size())) return fail("failed to map result lane"); @@ -1873,6 +2078,81 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { return materializations; } +FailureOr> +computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant) + return fail("requires constant active_elems_per_group"); + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("active_elems_per_group must be an integer constant"); + + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if (numGroups <= 0 || groupSize <= 0 || + resultVMIType.getElementCount() != numGroups * groupSize) + return fail("requires result 
lane count to match num_groups * group_size"); + + int64_t activeElems = activeAttr.getInt(); + if (activeElems < 0) + activeElems = 0; + if (activeElems > groupSize) + activeElems = groupSize; + + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + int64_t laneInGroup = *logicalLane % groupSize; + materialization.activeLanes.push_back(laneInGroup < activeElems ? 
1 + : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -1897,19 +2177,16 @@ FailureOr materializePrefixMask(Location loc, MaskType maskType, if (pattern) return createPatternMask(loc, maskType, *pattern, rewriter); - FailureOr> maskAndRemaining = - createRuntimePrefixMask(loc, maskType, - createI32Constant(loc, activeLanes, rewriter), - rewriter); + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, rewriter), rewriter); if (failed(maskAndRemaining)) return failure(); return maskAndRemaining->first; } -FailureOr -materializeConstantMaskChunk(Location loc, MaskType maskType, - ArrayRef activeLanes, - PatternRewriter &rewriter) { +FailureOr materializeConstantMaskChunk(Location loc, MaskType maskType, + ArrayRef activeLanes, + PatternRewriter &rewriter) { FailureOr lanesPerPart = getMaskLanesPerPart(maskType.getGranularity()); if (failed(lanesPerPart) || @@ -1952,10 +2229,10 @@ materializeConstantMaskChunk(Location loc, MaskType maskType, Value notPrefixBegin = rewriter.create(loc, maskType, *prefixBegin, *allTrue) .getResult(); - runMask = - rewriter.create(loc, maskType, *prefixEnd, notPrefixBegin, - *allTrue) - .getResult(); + runMask = rewriter + .create(loc, maskType, *prefixEnd, notPrefixBegin, + *allTrue) + .getResult(); } if (!result) { @@ -1996,12 +2273,9 @@ Value createGroupChunkOffset(Location loc, Value baseOffset, Value rowStride, return createChunkOffset(loc, offset, inGroupLaneOffset, rewriter); } -LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, - int64_t groupSize, - int64_t *lanesPerPart, - int64_t *groupCount, - int64_t *chunksPerGroup, - PatternRewriter &rewriter) { +LogicalResult checkContiguousFullGroupChunks( + Operation *op, VMIVRegType type, int64_t groupSize, 
int64_t *lanesPerPart, + int64_t *groupCount, int64_t *chunksPerGroup, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; @@ -2027,19 +2301,15 @@ LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, return success(); } -LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, - int64_t groupSize, - int64_t numGroups, - int64_t *lanesPerPart, - int64_t *groupCount, - PatternRewriter &rewriter) { +LogicalResult checkFullGroupSlotSourceShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t numGroups, + int64_t *lanesPerPart, int64_t *groupCount, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != numGroups) + if (!layout || !layout.isGroupSlots() || layout.getNumGroups() != numGroups) return fail("group slot op requires matching num_groups VMI layout"); if (failed(checkFullDataPhysicalChunks(type, nullptr))) return fail("group slot op requires full physical chunks"); @@ -2047,8 +2317,8 @@ LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, if (failed(lanes)) return fail("group slot op requires known physical lanes per part"); if (groupSize <= 0 || type.getElementCount() % groupSize != 0) - return fail( - "group slot op requires derived group size to evenly divide lane count"); + return fail("group slot op requires derived group size to evenly divide " + "lane count"); if (*lanes % groupSize != 0 && groupSize % *lanes != 0) return fail("group slot op requires group size to divide or be a " "multiple of physical lanes per part"); @@ -2058,13 +2328,9 @@ LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, return success(); } -LogicalResult checkFullGroupBroadcastResultShape(Operation *op, - VMIVRegType type, - 
int64_t groupSize, - int64_t lanesPerPart, - int64_t *layoutFactor, - int64_t *groupCount, - PatternRewriter &rewriter) { +LogicalResult checkFullGroupBroadcastResultShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t lanesPerPart, + int64_t *layoutFactor, int64_t *groupCount, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; @@ -2076,8 +2342,7 @@ LogicalResult checkFullGroupBroadcastResultShape(Operation *op, return fail("group_broadcast result requires a dense VMI layout"); if (failed(checkFullDataPhysicalChunks(type, nullptr))) return fail("group_broadcast result requires full physical chunks"); - FailureOr resultLanes = - getDataLanesPerPart(type.getElementType()); + FailureOr resultLanes = getDataLanesPerPart(type.getElementType()); if (failed(resultLanes) || *resultLanes != lanesPerPart) return fail("group_broadcast result requires matching physical lanes"); if (groupSize <= 0 || type.getElementCount() % groupSize != 0) @@ -2092,9 +2357,13 @@ LogicalResult checkFullGroupBroadcastResultShape(Operation *op, return fail("group_broadcast contiguous result requires group size to " "divide or be a multiple of physical lanes per part"); } else { + bool blockFragmentSmallGroup = + layout.isDeinterleaved() && layout.getBlockElems() > 1 && + groupSize < lanesPerPart && lanesPerPart % layout.getBlockElems() == 0; int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; - if (groupSize < lanesPerPart || - groupSize % logicalSpanPerResultChunk != 0) + if (!blockFragmentSmallGroup && + (groupSize < lanesPerPart || + groupSize % logicalSpanPerResultChunk != 0)) return fail("group_broadcast deinterleaved result requires every " "physical result chunk to stay within one logical group"); } @@ -2111,8 +2380,9 @@ FailureOr createZeroVector(Location loc, VRegType type, FailureOr mask = createAllTrueMaskForVReg(loc, type, rewriter); if (failed(zero) || failed(mask)) return failure(); - 
return rewriter.create(loc, type, *zero, *mask, - /*position=*/nullptr) + return rewriter + .create(loc, type, *zero, *mask, + /*position=*/nullptr) .getResult(); } @@ -2121,8 +2391,7 @@ FailureOr createLaneRangeMask(Location loc, MaskType maskType, PatternRewriter &rewriter) { FailureOr lanesPerPart = getMaskLanesPerPart(maskType.getGranularity()); - if (failed(lanesPerPart) || begin < 0 || begin > end || - end > *lanesPerPart) + if (failed(lanesPerPart) || begin < 0 || begin > end || end > *lanesPerPart) return failure(); SmallVector active(*lanesPerPart, 0); for (int64_t lane = begin; lane < end; ++lane) @@ -2132,34 +2401,38 @@ FailureOr createLaneRangeMask(Location loc, MaskType maskType, FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, int64_t groupSize, + int64_t baseGroupSlot, PatternRewriter &rewriter) { int64_t lanesPerPart = indexType.getElementCount(); - FailureOr zero = - createZeroVector(loc, indexType, rewriter); - FailureOr maskType = getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr baseScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), baseGroupSlot, rewriter); + FailureOr maskType = + getMaskTypeForVReg(indexType, rewriter.getContext()); FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); - if (failed(zero) || failed(maskType) || failed(allMask)) + if (failed(baseScalar) || failed(maskType) || failed(allMask)) return failure(); + Value result = rewriter + .create(loc, indexType, *baseScalar, *allMask, + /*position=*/nullptr) + .getResult(); if (groupSize >= lanesPerPart) - return *zero; + return result; if (lanesPerPart % groupSize != 0) return failure(); - Value result = *zero; int64_t groupsPerChunk = lanesPerPart / groupSize; for (int64_t localGroup = 1; localGroup < groupsPerChunk; ++localGroup) { FailureOr groupScalar = createScalarOffsetConstant( - loc, indexType.getElementType(), localGroup, rewriter); + loc, indexType.getElementType(), baseGroupSlot + 
localGroup, rewriter); FailureOr laneMask = createLaneRangeMask(loc, *maskType, localGroup * groupSize, (localGroup + 1) * groupSize, rewriter); if (failed(groupScalar) || failed(laneMask)) return failure(); - Value splat = - rewriter - .create(loc, indexType, *groupScalar, *allMask, - /*position=*/nullptr) - .getResult(); + Value splat = rewriter + .create(loc, indexType, *groupScalar, *allMask, + /*position=*/nullptr) + .getResult(); result = rewriter.create(loc, indexType, splat, result, *laneMask) .getResult(); } @@ -2189,8 +2462,7 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, int64_t numGroups = sourceType.getElementCount() / groupSize; if (!sourceLayout || !resultLayout || !maskLayout || !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || - !maskLayout.isContiguous()) + resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) return fail("vcgadd group_reduce_addf path requires contiguous source/mask " "layouts and matching num_groups result layout"); std::string sourceFullReason; @@ -2211,6 +2483,135 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, return success(); } +LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("s16 block8 group_reduce_addf requires f32 source/result"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize) || *groupSize != 16) + return fail("s16 block8 group_reduce_addf requires group size 16"); + + VMILayoutAttr 
sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !sourceLayout.isDeinterleaved() || + sourceLayout.getFactor() != 2 || + (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) + return fail("s16 group_reduce_addf requires source layout " + "deinterleaved=2 with block_elems=1 or block_elems=8"); + if (!maskLayout || !maskLayout.isDeinterleaved() || + maskLayout.getFactor() != 2 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s16 group_reduce_addf requires matching mask layout " + "deinterleaved=2 with the same block_elems"); + if (!resultLayout || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) + return fail("s16 block8 group_reduce_addf requires " + "group_slots(num_groups, slots=8) result layout"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("s16 block8 group_reduce_addf requires computable physical " + "arity"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || + *maskArity != *sourceArity) + return fail("s16 block8 group_reduce_addf requires two source/mask " + "parts per result part"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = sourceLayout.getBlockElems() == 1 + ? 
"s16_reduce_parity" + : "s16_reduce_block8"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + + return success(); +} + +LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("s32 block8 group_reduce_addf requires f32 source/result"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize) || *groupSize != 32) + return fail("s32 block8 group_reduce_addf requires group size 32"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !sourceLayout.isDeinterleaved() || + sourceLayout.getFactor() != 4 || + (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) + return fail("s32 group_reduce_addf requires source layout " + "deinterleaved=4 with block_elems=1 or block_elems=8"); + if (!maskLayout || !maskLayout.isDeinterleaved() || + maskLayout.getFactor() != 4 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s32 group_reduce_addf requires matching mask layout " + "deinterleaved=4 with the same block_elems"); + if (!resultLayout || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) + return fail("s32 block8 group_reduce_addf 
requires " + "group_slots(num_groups, slots=8) result layout"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("s32 block8 group_reduce_addf requires computable physical " + "arity"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || + *maskArity != *sourceArity) + return fail("s32 block8 group_reduce_addf requires four source/mask " + "parts per result part"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = sourceLayout.getBlockElems() == 1 + ? "s32_reduce_dintlv4" + : "s32_reduce_block8_stride"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + return success(); +} + std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -2289,7 +2690,8 @@ struct OneToNVMIPackOpPattern : OneToNOpConversionPattern { } }; -LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts, +LogicalResult verifyIdentityPartForwarding(Operation *op, + ValueRange sourceParts, TypeRange resultTypes, PatternRewriter &rewriter) { if (sourceParts.size() != resultTypes.size()) @@ -2303,12 +2705,10 @@ LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts return success(); } -FailureOr> -materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, - TypeRange resultTypes, - VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - PatternRewriter 
&rewriter) { +FailureOr> materializeDataLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { if (!sourceLayout || !resultLayout) { (void)rewriter.notifyMatchFailure( op, "layout materialization requires assigned source/result layouts"); @@ -2332,8 +2732,7 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 2 != 0) { (void)rewriter.notifyMatchFailure( - op, - "deinterleaved=2 layout materialization requires 2*N parts"); + op, "deinterleaved=2 layout materialization requires 2*N parts"); return failure(); } if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, @@ -2378,8 +2777,7 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 4 != 0) { (void)rewriter.notifyMatchFailure( - op, - "deinterleaved=4 layout materialization requires 4*N parts"); + op, "deinterleaved=4 layout materialization requires 4*N parts"); return failure(); } if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, @@ -2395,20 +2793,16 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, Value p1 = sourceParts[groups + i]; Value p2 = sourceParts[2 * groups + i]; Value p3 = sourceParts[3 * groups + i]; - auto even = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], p0, p2); - auto odd = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], p1, p3); - auto low = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], even.getLow(), - odd.getLow()); - auto high = - rewriter.create(op->getLoc(), resultTypes[4 * i + 2], - resultTypes[4 * i + 3], even.getHigh(), - odd.getHigh()); + auto even = rewriter.create(op->getLoc(), resultTypes[4 * i], 
+ resultTypes[4 * i + 1], p0, p2); + auto odd = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3); + auto low = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], + even.getLow(), odd.getLow()); + auto high = rewriter.create( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even.getHigh(), odd.getHigh()); results.append( {low.getLow(), low.getHigh(), high.getLow(), high.getHigh()}); } @@ -2422,21 +2816,19 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, part2.reserve(groups); part3.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - auto low = - rewriter.create(op->getLoc(), resultTypes[i], - resultTypes[groups + i], - sourceParts[4 * i], - sourceParts[4 * i + 1]); + auto low = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1]); auto high = rewriter.create( op->getLoc(), resultTypes[2 * groups + i], resultTypes[3 * groups + i], sourceParts[4 * i + 2], sourceParts[4 * i + 3]); - auto even = rewriter.create( - op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], - low.getLow(), high.getLow()); + auto even = rewriter.create(op->getLoc(), resultTypes[i], + resultTypes[2 * groups + i], + low.getLow(), high.getLow()); auto odd = rewriter.create( - op->getLoc(), resultTypes[groups + i], - resultTypes[3 * groups + i], low.getHigh(), high.getHigh()); + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low.getHigh(), high.getHigh()); part0.push_back(even.getLow()); part1.push_back(odd.getLow()); part2.push_back(even.getHigh()); @@ -2497,12 +2889,10 @@ createPredicateIntlv(Location loc, Type lowType, Type highType, Value lhs, return failure(); } -FailureOr> -materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, - TypeRange resultTypes, - VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - PatternRewriter &rewriter) { +FailureOr> 
materializeMaskLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { if (!sourceLayout || !resultLayout) { (void)rewriter.notifyMatchFailure( op, "mask layout materialization requires assigned source/result " @@ -2540,10 +2930,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, results.reserve(sourceParts.size()); if (deint2ToContiguous) { for (int64_t i = 0; i < groups; ++i) { - FailureOr> materialize = - createPredicateIntlv(op->getLoc(), resultTypes[2 * i], - resultTypes[2 * i + 1], sourceParts[i], - sourceParts[groups + i], rewriter); + FailureOr> materialize = createPredicateIntlv( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i], rewriter); if (failed(materialize)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); @@ -2555,10 +2944,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, part0.reserve(groups); part1.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - FailureOr> materialize = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[groups + i], sourceParts[2 * i], - sourceParts[2 * i + 1], rewriter); + FailureOr> materialize = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1], rewriter); if (failed(materialize)) return rewriter.notifyMatchFailure( op, "unsupported predicate dintlv mask type"); @@ -2607,14 +2995,12 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, if (failed(even) || failed(odd)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); - FailureOr> low = - createPredicateIntlv(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], even->first, - odd->first, rewriter); - FailureOr> high = - createPredicateIntlv(op->getLoc(), resultTypes[4 * i + 
2], - resultTypes[4 * i + 3], even->second, - odd->second, rewriter); + FailureOr> low = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i], resultTypes[4 * i + 1], + even->first, odd->first, rewriter); + FailureOr> high = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even->second, odd->second, rewriter); if (failed(low) || failed(high)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); @@ -2630,11 +3016,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, part2.reserve(groups); part3.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - FailureOr> low = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[groups + i], - sourceParts[4 * i], sourceParts[4 * i + 1], - rewriter); + FailureOr> low = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1], rewriter); FailureOr> high = createPredicateDintlv( op->getLoc(), resultTypes[2 * groups + i], resultTypes[3 * groups + i], sourceParts[4 * i + 2], @@ -2642,14 +3026,12 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, if (failed(low) || failed(high)) return rewriter.notifyMatchFailure( op, "unsupported predicate dintlv mask type"); - FailureOr> even = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[2 * groups + i], low->first, - high->first, rewriter); - FailureOr> odd = - createPredicateDintlv(op->getLoc(), resultTypes[groups + i], - resultTypes[3 * groups + i], low->second, - high->second, rewriter); + FailureOr> even = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], + low->first, high->first, rewriter); + FailureOr> odd = createPredicateDintlv( + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low->second, high->second, rewriter); if (failed(even) || failed(odd)) return rewriter.notifyMatchFailure( op, "unsupported predicate 
dintlv mask type"); @@ -2760,19 +3142,17 @@ FailureOr> materializeAdjacentMaskGranularityConversion( for (int64_t chunk = 0; chunk < *sourceChunks && produced < *resultChunks; ++chunk) { Value source = sourceParts[sourceOffset + chunk]; - results.push_back( - rewriter - .create(op->getLoc(), resultMaskType, source, - partAttr("LOWER")) - .getResult()); + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("LOWER")) + .getResult()); ++produced; if (produced >= *resultChunks) break; - results.push_back( - rewriter - .create(op->getLoc(), resultMaskType, source, - partAttr("HIGHER")) - .getResult()); + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("HIGHER")) + .getResult()); ++produced; } if (produced != *resultChunks) @@ -2786,18 +3166,16 @@ FailureOr> materializeAdjacentMaskGranularityConversion( return fail("narrowing mask granularity conversion ran out of " "source chunks"); Value lowerSource = sourceParts[sourceOffset + consumed++]; - Value packed = - rewriter - .create(op->getLoc(), resultMaskType, lowerSource, - partAttr("LOWER")) - .getResult(); + Value packed = rewriter + .create(op->getLoc(), resultMaskType, + lowerSource, partAttr("LOWER")) + .getResult(); if (consumed < *sourceChunks) { Value higherSource = sourceParts[sourceOffset + consumed++]; - Value higher = - rewriter - .create(op->getLoc(), resultMaskType, higherSource, - partAttr("HIGHER")) - .getResult(); + Value higher = rewriter + .create(op->getLoc(), resultMaskType, + higherSource, partAttr("HIGHER")) + .getResult(); if (!allTrue) { FailureOr mask = createAllTrueMask(op->getLoc(), resultMaskType, rewriter); @@ -2832,9 +3210,8 @@ FailureOr> materializeMaskGranularityConversion( VMIMaskType sourceType, VMIMaskType resultType, ValueRange sourceParts, PatternRewriter &rewriter) { std::string reason; - if (failed(checkSupportedMaskGranularityMaterialization(capabilities, - sourceType, - resultType, &reason))) { + if 
(failed(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) { (void)rewriter.notifyMatchFailure(op, reason); return failure(); } @@ -2856,8 +3233,8 @@ FailureOr> materializeMaskGranularityConversion( VMIMaskType::get(op->getContext(), currentType.getElementCount(), nextGranularity, currentType.getLayoutAttr()); FailureOr> nextParts = - materializeAdjacentMaskGranularityConversion( - op, currentType, nextType, currentParts, rewriter); + materializeAdjacentMaskGranularityConversion(op, currentType, nextType, + currentParts, rewriter); if (failed(nextParts)) return failure(); currentType = nextType; @@ -2942,9 +3319,8 @@ struct OneToNVMIEnsureMaskGranularityOpPattern TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceType.getGranularity() != resultType.getGranularity()) { FailureOr> results = - materializeMaskGranularityConversion(op, capabilities, sourceType, - resultType, sourceParts, - rewriter); + materializeMaskGranularityConversion( + op, capabilities, sourceType, resultType, sourceParts, rewriter); if (failed(results)) return failure(); if (results->size() != resultTypes.size()) @@ -2969,8 +3345,7 @@ struct OneToNVMIEnsureMaskGranularityOpPattern const VMITargetCapabilityRegistry &capabilities; }; -struct OneToNVMIBroadcastOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIBroadcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -2988,8 +3363,7 @@ struct OneToNVMIBroadcastOpPattern for (Type resultType : resultTypes) { auto vregType = dyn_cast(resultType); if (!vregType) - return rewriter.notifyMatchFailure(op, - "broadcast result must be vreg"); + return rewriter.notifyMatchFailure(op, "broadcast result must be vreg"); FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); if (failed(mask)) @@ -2997,11 +3371,10 @@ struct OneToNVMIBroadcastOpPattern op, "unsupported element type for 
broadcast mask"); StringAttr position = inputIsVReg ? rewriter.getStringAttr("LOWEST") : StringAttr{}; - results.push_back( - rewriter - .create(op.getLoc(), resultType, inputParts.front(), - *mask, position) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + inputParts.front(), *mask, position) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3019,18 +3392,15 @@ FailureOr createScalarOffsetConstant(Location loc, Type type, } if (auto floatType = dyn_cast(type)) { return rewriter - .create(loc, - rewriter.getFloatAttr(floatType, - static_cast( - value))) + .create( + loc, rewriter.getFloatAttr(floatType, static_cast(value))) .getResult(); } return failure(); } FailureOr createIotaChunkBase(Location loc, Value base, - int64_t laneOffset, - StringRef order, + int64_t laneOffset, StringRef order, PatternRewriter &rewriter) { if (laneOffset == 0) return base; @@ -3078,8 +3448,8 @@ FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, return failure(); FailureOr mask = createAllTrueMaskForVReg(loc, vregType, rewriter); - FailureOr zero = createScalarOffsetConstant(loc, base.getType(), 0, - rewriter); + FailureOr zero = + createScalarOffsetConstant(loc, base.getType(), 0, rewriter); FailureOr factorScalar = createScalarOffsetConstant(loc, base.getType(), factor, rewriter); if (failed(mask) || failed(zero) || failed(factorScalar)) @@ -3099,11 +3469,10 @@ FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, return failure(); if (order == "DESC") { - Value baseVector = - rewriter - .create(loc, resultType, *biasedBase, *mask, - /*position=*/nullptr) - .getResult(); + Value baseVector = rewriter + .create(loc, resultType, *biasedBase, *mask, + /*position=*/nullptr) + .getResult(); return rewriter.create(loc, resultType, baseVector, scaled, *mask) .getResult(); } @@ -3121,8 +3490,7 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { auto resultVMIType = 
cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); if (!layout) - return rewriter.notifyMatchFailure(op, - "iota requires assigned layout"); + return rewriter.notifyMatchFailure(op, "iota requires assigned layout"); FailureOr lanesPerPart = getDataLanesPerPart(resultVMIType.getElementType()); @@ -3130,9 +3498,8 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "iota requires known physical lanes per part"); - FailureOr base = - getSingleValue(op, adaptor.getBase(), - "iota base must convert to one value", rewriter); + FailureOr base = getSingleValue( + op, adaptor.getBase(), "iota base must convert to one value", rewriter); if (failed(base)) return failure(); @@ -3167,8 +3534,8 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { Type resultType = resultTypes[part * chunksPerPart + chunk]; FailureOr result = createIotaDeinterleavedChunk( - op.getLoc(), resultType, *base, factor, part, chunk, - *lanesPerPart, op.getOrderAttr(), rewriter); + op.getLoc(), resultType, *base, factor, part, chunk, *lanesPerPart, + op.getOrderAttr(), rewriter); if (failed(result)) return rewriter.notifyMatchFailure( op, "failed to materialize deinterleaved iota chunk"); @@ -3193,8 +3560,7 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { op, "only splat dense data constants are supported"); auto splatAttr = dyn_cast(denseAttr.getSplatValue()); if (!splatAttr) - return rewriter.notifyMatchFailure(op, - "splat constant must be typed"); + return rewriter.notifyMatchFailure(op, "splat constant must be typed"); Value scalar = rewriter.create(op.getLoc(), splatAttr).getResult(); @@ -3210,11 +3576,11 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for constant mask"); - results.push_back( - rewriter - .create(op.getLoc(), 
resultType, scalar, *mask, - /*position=*/nullptr) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, scalar, + *mask, + /*position=*/nullptr) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3224,8 +3590,7 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { struct OneToNVMIConstantMaskOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIConstantMaskOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIConstantMaskOp op, OpAdaptor adaptor, @@ -3235,8 +3600,7 @@ struct OneToNVMIConstantMaskOpPattern FailureOr> materializations = computeConstantMaskMaterialization(op, &reason); if (failed(materializations)) - return rewriter.notifyMatchFailure( - op, Twine("constant_mask ") + reason); + return rewriter.notifyMatchFailure(op, Twine("constant_mask ") + reason); SmallVector results; results.reserve(resultTypes.size()); @@ -3276,8 +3640,8 @@ struct OneToNVMICreateMaskOpPattern op.getActiveLanes().getDefiningOp(); auto resultVMIType = cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); - if (!layout || !VMIMaskType::isConcreteGranularity( - resultVMIType.getGranularity())) + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) return rewriter.notifyMatchFailure( op, "create_mask requires concrete layout and granularity"); FailureOr lanesPerPart = @@ -3306,8 +3670,8 @@ struct OneToNVMICreateMaskOpPattern SmallVector results; results.reserve(resultTypes.size()); for (int64_t part = 0; part < factor; ++part) { - Value remaining = createPartitionActiveLanes( - op.getLoc(), activeI32, factor, part, rewriter); + Value remaining = createPartitionActiveLanes(op.getLoc(), activeI32, + factor, part, rewriter); for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { Type resultType = resultTypes[part * chunksPerPart + chunk]; auto 
maskType = dyn_cast(resultType); @@ -3408,6 +3772,49 @@ struct OneToNVMICreateMaskOpPattern } }; +struct OneToNVMICreateGroupMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICreateGroupMaskOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICreateGroupMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> materializations = + computeGroupMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure(op, + Twine("create_group_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -3423,8 +3830,12 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { "load offset must convert to one value", rewriter); if (failed(source) || failed(offset)) return failure(); + std::optional explicitFullReadElems; + if 
(auto attr = op.getFullReadElemsAttr()) + explicitFullReadElems = attr.getInt(); FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( - op, resultVMIType, (*source).getType(), *offset, rewriter); + op, resultVMIType, op.getSource().getType(), *offset, rewriter, + explicitFullReadElems); if (failed(lanesPerPart)) return failure(); @@ -3448,9 +3859,9 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { op, "vldsx2 requires matching low/high result types"); Value chunkOffset = createChunkOffset( op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); - auto load = rewriter.create( - op.getLoc(), lowType, highType, *source, chunkOffset, - rewriter.getStringAttr(*dist)); + auto load = rewriter.create(op.getLoc(), lowType, highType, + *source, chunkOffset, + rewriter.getStringAttr(*dist)); lows.push_back(load.getLow()); highs.push_back(load.getHigh()); } @@ -3469,14 +3880,14 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { auto vregType = dyn_cast(resultType); if (!vregType) return rewriter.notifyMatchFailure(op, "load result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); - contiguousParts.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + contiguousParts.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, + *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); } FailureOr> results = materializeDataLayoutConversion( @@ -3491,10 +3902,8 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } }; -struct OneToNVMIGroupLoadOpPattern - : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupLoadOp>::OneToNOpConversionPattern; +struct OneToNVMIGroupLoadOpPattern : OneToNOpConversionPattern { + using 
OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIGroupLoadOp op, OpAdaptor adaptor, @@ -3502,19 +3911,103 @@ struct OneToNVMIGroupLoadOpPattern auto resultVMIType = cast(op.getResult().getType()); FailureOr source = getSingleValue(op, adaptor.getSource(), - "group_load source must convert to one value", - rewriter); + "group_load source must convert to one value", rewriter); FailureOr offset = getSingleValue(op, adaptor.getOffset(), - "group_load offset must convert to one value", - rewriter); - FailureOr rowStride = - getSingleValue(op, adaptor.getRowStride(), - "group_load row_stride must convert to one value", - rewriter); + "group_load offset must convert to one value", rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_load row_stride must convert to one value", rewriter); if (failed(source) || failed(offset) || failed(rowStride)) return failure(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 8 && + resultVMIType.getElementType().isF32()) { + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if ((*groupSize != 16 || resultLayout.getFactor() != 2) && + (*groupSize != 32 || resultLayout.getFactor() != 4)) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires S=16/factor=2 or S=32/factor=4"); + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires num_groups multiple of 8"); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride <= 0 || + *constantRowStride % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires constant 
positive row_stride " + "divisible by 8 f32 elements"); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = resultLayout.getFactor(); + FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); + if (failed(chunksPerPart) || *chunksPerPart <= 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires known chunks per part"); + for (int64_t part = 1; part < factor; ++part) { + FailureOr currentChunks = + getDataChunksInPart(resultVMIType, part); + if (failed(currentChunks) || *currentChunks != *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires uniform chunks per part"); + } + if (static_cast(resultTypes.size()) != factor * *chunksPerPart) + return rewriter.notifyMatchFailure(op, + "block8 group_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op.getLoc(), value, 16); + }; + Value blockStride = makeI16(*constantRowStride / 8); + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op.getLoc(), (*source).getType(), *source, + elementOffset) + .getResult(); + }; + + SmallVector results; + results.reserve(resultTypes.size()); + constexpr int64_t kGroupsPerBlock8Load = 8; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + int64_t flatIndex = part * *chunksPerPart + chunk; + auto vregType = dyn_cast(resultTypes[flatIndex]); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "block8 group_load result must be vreg"); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create block8 group_load mask"); + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, 
*rowStride, chunk * kGroupsPerBlock8Load, + part * resultLayout.getBlockElems(), rewriter); + Value chunkBase = makePtr(chunkOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, + chunkBase, blockStride, + zeroI16, *allMask) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; @@ -3523,14 +4016,13 @@ struct OneToNVMIGroupLoadOpPattern if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_load requires num_groups to evenly divide lane count"); - if (failed(checkContiguousFullGroupChunks( - op, resultVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) + if (failed(checkContiguousFullGroupChunks(op, resultVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - if (static_cast(resultTypes.size()) != - groupCount * chunksPerGroup) + if (static_cast(resultTypes.size()) != groupCount * chunksPerGroup) return rewriter.notifyMatchFailure(op, "group_load arity mismatch"); SmallVector results; @@ -3542,15 +4034,156 @@ struct OneToNVMIGroupLoadOpPattern "group_load result must be vreg"); int64_t group = index / chunksPerGroup; int64_t chunkInGroup = index % chunksPerGroup; - Value chunkOffset = createGroupChunkOffset( - op.getLoc(), *offset, *rowStride, group, - chunkInGroup * lanesPerPart, rewriter); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, 
results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGroupSlotLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupSlotLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupSlotLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires explicit group_slots layout"); + + FailureOr source = getSingleValue( + op, adaptor.getSource(), + "group_slot_load source must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "group_slot_load offset must convert to one value", rewriter); + FailureOr sourceGroupStride = getSingleValue( + op, adaptor.getSourceGroupStride(), + "group_slot_load source_group_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(sourceGroupStride)) + return failure(); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t slots = layout.getSlots(); + int64_t expectedArity = ceilDivNonNegative(numGroups, slots); + if (static_cast(resultTypes.size()) != expectedArity) + return rewriter.notifyMatchFailure(op, "group_slot_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op.getLoc(), value, 16); + }; + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op.getLoc(), (*source).getType(), *source, + elementOffset) + .getResult(); + }; + + SmallVector results; + results.reserve(resultTypes.size()); + + 
if (slots == 8) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load requires constant unit stride"); + if (resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load expects one physical result"); + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(resultType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_slot_load mask"); + Value slotBase = makePtr(*offset); + results.push_back(rewriter + .create(op.getLoc(), resultType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (slots != 1) + return rewriter.notifyMatchFailure( + op, "group_slot_load supports only slots=8 or slots=1"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_slot_load requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantStride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!constantStride || *constantStride <= 0 || + *constantStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_slot_load requires constant positive " + "source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsldb alignment"); + + for (auto 
[group, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_slot_load mask"); + Value groupOffset = *offset; + if (group != 0) { + Value groupIndex = + rewriter.create(op.getLoc(), group); + Value rowOffset = rewriter + .create( + op.getLoc(), *sourceGroupStride, groupIndex) + .getResult(); + groupOffset = + rewriter.create(op.getLoc(), groupOffset, rowOffset) + .getResult(); + } + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3560,21 +4193,18 @@ struct OneToNVMIGroupLoadOpPattern struct OneToNVMIMaskedLoadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIMaskedLoadOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIMaskedLoadOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto resultVMIType = cast(op.getResult().getType()); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "masked_load source must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "masked_load offset must convert to one value", - rewriter); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "masked_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, 
adaptor.getOffset(), "masked_load offset must convert to one value", + rewriter); if (failed(source) || failed(offset)) return failure(); @@ -3588,22 +4218,21 @@ struct OneToNVMIMaskedLoadOpPattern TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (maskParts.size() != passthruParts.size() || passthruParts.size() != resultTypes.size()) - return rewriter.notifyMatchFailure( - op, "masked_load physical arity mismatch"); + return rewriter.notifyMatchFailure(op, + "masked_load physical arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); - for (auto [index, maskPassthruAndType] : - llvm::enumerate(llvm::zip_equal(maskParts, passthruParts, - resultTypes))) { + for (auto [index, maskPassthruAndType] : llvm::enumerate( + llvm::zip_equal(maskParts, passthruParts, resultTypes))) { auto [mask, passthru, resultType] = maskPassthruAndType; if (!isa(mask.getType()) || passthru.getType() != resultType || !isa(resultType)) return rewriter.notifyMatchFailure( op, "masked_load physical part type mismatch"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); Value loaded = rewriter .create(op.getLoc(), resultType, @@ -3645,22 +4274,19 @@ struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { SmallVector results; results.reserve(resultTypes.size()); for (auto [indices, mask, passthru, resultType] : - llvm::zip_equal(indicesParts, maskParts, passthruParts, - resultTypes)) { + llvm::zip_equal(indicesParts, maskParts, passthruParts, resultTypes)) { if (!isa(indices.getType()) || !isa(mask.getType()) || passthru.getType() != resultType || !isa(resultType)) - return rewriter.notifyMatchFailure(op, - "gather physical part type mismatch"); + return rewriter.notifyMatchFailure( + op, "gather physical part type mismatch"); - Value gathered = - rewriter - .create(op.getLoc(), resultType, *source, 
indices, - mask) - .getResult(); + Value gathered = rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult(); results.push_back( rewriter - .create(op.getLoc(), resultType, gathered, passthru, - mask) + .create(op.getLoc(), resultType, gathered, passthru, mask) .getResult()); } @@ -3671,21 +4297,18 @@ struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { struct OneToNVMIExpandLoadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIExpandLoadOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIExpandLoadOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto resultVMIType = cast(op.getResult().getType()); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "expand_load source must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "expand_load offset must convert to one value", - rewriter); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "expand_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "expand_load offset must convert to one value", + rewriter); if (failed(source) || failed(offset)) return failure(); @@ -3700,16 +4323,16 @@ struct OneToNVMIExpandLoadOpPattern results.reserve(resultTypes.size()); for (auto [index, resultType] : llvm::enumerate(resultTypes)) { if (!isa(resultType)) - return rewriter.notifyMatchFailure( - op, "expand_load result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + return rewriter.notifyMatchFailure(op, + "expand_load result must be vreg"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * 
*lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3725,7 +4348,8 @@ struct OneToNVMIExpandLoadOpPattern auto resultType = dyn_cast(resultTypes.front()); auto maskType = dyn_cast(maskParts.front().getType()); - if (!resultType || !maskType || passthruParts.front().getType() != resultType) + if (!resultType || !maskType || + passthruParts.front().getType() != resultType) return rewriter.notifyMatchFailure( op, "runtime expand_load requires physical result/passthru/mask"); @@ -3733,11 +4357,10 @@ struct OneToNVMIExpandLoadOpPattern if (!baseType) return rewriter.notifyMatchFailure(op, "runtime expand_load requires ptr"); - Value gatherBase = - rewriter - .create(op.getLoc(), (*source).getType(), *source, - *offset) - .getResult(); + Value gatherBase = rewriter + .create(op.getLoc(), (*source).getType(), + *source, *offset) + .getResult(); auto indexType = VRegType::get(rewriter.getContext(), resultType.getElementCount(), rewriter.getI32Type()); @@ -3754,19 +4377,17 @@ struct OneToNVMIExpandLoadOpPattern .getResult(); Value indices = rewriter - .create(op.getLoc(), indexType, carrier, - maskParts.front()) + .create(op.getLoc(), indexType, carrier, maskParts.front()) .getResult(); Value gathered = rewriter .create(op.getLoc(), resultType, gatherBase, indices, maskParts.front()) .getResult(); - Value result = - rewriter - .create(op.getLoc(), resultType, gathered, - passthruParts.front(), maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, gathered, + passthruParts.front(), maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -3854,18 +4475,16 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { if (*activeLanes == 0) continue; } - FailureOr mask = 
fullPhysicalChunks - ? createAllTrueMaskForVReg(op.getLoc(), - vregType, rewriter) - : createContiguousStoreMask(op.getLoc(), - valueVMIType, - index, vregType, - rewriter); + FailureOr mask = + fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), valueVMIType, index, + vregType, rewriter); if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for store mask"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -3878,44 +4497,133 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { struct OneToNVMIGroupStoreOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupStoreOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIGroupStoreOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto valueVMIType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueVMIType.getLayoutAttr(); + + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "group_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "group_store offset must convert to one value", + rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_store row_stride must convert to one value", rewriter); + if (failed(destination) || failed(offset) || failed(rowStride)) + return failure(); + + if (layout && layout.isGroupSlots() && layout.getSlots() == 1 && + layout.getNumGroups() == op.getNumGroupsAttr().getInt()) { + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) 
!= layout.getNumGroups()) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store arity mismatch"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride <= 0 || + *constantRowStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_store requires constant positive " + "row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsts alignment"); + + for (auto [group, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + FailureOr mask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=1 group_store mask"); + Value groupOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } + + if (layout && layout.isGroupSlots() && layout.getSlots() == 8 && + layout.getNumGroups() == op.getNumGroupsAttr().getInt()) { + int64_t numGroups = layout.getNumGroups(); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride != 1) + return 
rewriter.notifyMatchFailure( + op, "slots=8 group_store requires constant unit row_stride"); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != + ceilDivNonNegative(numGroups, 8)) + return rewriter.notifyMatchFailure( + op, "slots=8 group_store arity mismatch"); + + for (auto [slotBlock, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t activeGroups = std::min(8, numGroups - slotBlock * 8); + FailureOr mask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, activeGroups, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_store mask"); + Value groupOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, slotBlock * 8, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; - FailureOr groupSize = getGroupSizeFromNumGroups( - valueVMIType, op.getNumGroupsAttr().getInt()); + FailureOr groupSize = + getGroupSizeFromNumGroups(valueVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_store requires num_groups to evenly divide lane count"); - if (failed(checkContiguousFullGroupChunks( - op, valueVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) - return failure(); - - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "group_store destination must convert to one value", - rewriter); - 
FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "group_store offset must convert to one value", - rewriter); - FailureOr rowStride = - getSingleValue(op, adaptor.getRowStride(), - "group_store row_stride must convert to one value", - rewriter); - if (failed(destination) || failed(offset) || failed(rowStride)) + if (failed(checkContiguousFullGroupChunks(op, valueVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); ValueRange valueParts = adaptor.getValue(); - if (static_cast(valueParts.size()) != - groupCount * chunksPerGroup) + if (static_cast(valueParts.size()) != groupCount * chunksPerGroup) return rewriter.notifyMatchFailure(op, "group_store arity mismatch"); for (auto [index, value] : llvm::enumerate(valueParts)) { @@ -3930,9 +4638,9 @@ struct OneToNVMIGroupStoreOpPattern op, "unsupported element type for group_store mask"); int64_t group = index / chunksPerGroup; int64_t chunkInGroup = index % chunksPerGroup; - Value chunkOffset = createGroupChunkOffset( - op.getLoc(), *offset, *rowStride, group, - chunkInGroup * lanesPerPart, rewriter); + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -3945,8 +4653,7 @@ struct OneToNVMIGroupStoreOpPattern struct OneToNVMIMaskedStoreOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIMaskedStoreOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIMaskedStoreOp op, OpAdaptor adaptor, @@ -3958,14 +4665,12 @@ struct OneToNVMIMaskedStoreOpPattern return rewriter.notifyMatchFailure( op, "masked_store requires known physical lanes per part"); - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "masked_store destination must convert to one value", - 
rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "masked_store offset must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "masked_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "masked_store offset must convert to one value", rewriter); if (failed(destination) || failed(offset)) return failure(); @@ -4019,8 +4724,8 @@ struct OneToNVMIMaskedStoreOpPattern if (failed(storeMask)) return rewriter.notifyMatchFailure( op, "failed to materialize masked_store predicate"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *storeMask); @@ -4037,10 +4742,9 @@ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { LogicalResult matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "scatter destination must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "scatter destination must convert to one value", rewriter); if (failed(destination)) return failure(); @@ -4049,8 +4753,7 @@ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { ValueRange maskParts = adaptor.getMask(); if (valueParts.size() != indicesParts.size() || valueParts.size() != maskParts.size()) - return rewriter.notifyMatchFailure(op, - "scatter physical arity mismatch"); + return rewriter.notifyMatchFailure(op, "scatter physical arity mismatch"); for (auto [value, indices, mask] : llvm::zip_equal(valueParts, indicesParts, maskParts)) { @@ -4067,8 +4770,7 @@ struct OneToNVMIScatterOpPattern : 
OneToNOpConversionPattern { } }; -struct OneToNVMITileReadOpPattern - : OneToNOpConversionPattern { +struct OneToNVMITileReadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -4107,9 +4809,9 @@ struct OneToNVMITileReadOpPattern op, "vldsx2 requires matching low/high result types"); Value chunkOffset = createChunkOffset( op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); - auto load = rewriter.create( - op.getLoc(), lowType, highType, *source, chunkOffset, - rewriter.getStringAttr(*dist)); + auto load = rewriter.create(op.getLoc(), lowType, highType, + *source, chunkOffset, + rewriter.getStringAttr(*dist)); lows.push_back(load.getLow()); highs.push_back(load.getHigh()); } @@ -4128,14 +4830,14 @@ struct OneToNVMITileReadOpPattern auto vregType = dyn_cast(resultType); if (!vregType) return rewriter.notifyMatchFailure(op, "tile_read result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, index * *lanesPerPart, rewriter); - contiguousParts.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = + createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); + contiguousParts.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, + *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); } FailureOr> results = materializeDataLayoutConversion( @@ -4150,8 +4852,7 @@ struct OneToNVMITileReadOpPattern } }; -struct OneToNVMITileWriteOpPattern - : OneToNOpConversionPattern { +struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -4231,18 +4932,16 @@ struct OneToNVMITileWriteOpPattern if (*activeLanes == 0) continue; } - FailureOr mask = fullPhysicalChunks - ? 
createAllTrueMaskForVReg(op.getLoc(), - vregType, rewriter) - : createContiguousStoreMask(op.getLoc(), - valueVMIType, - index, vregType, - rewriter); + FailureOr mask = + fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), valueVMIType, index, + vregType, rewriter); if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for tile_write mask"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, index * *lanesPerPart, rewriter); + Value chunkOffset = + createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -4257,10 +4956,10 @@ template struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange lhsParts = adaptor.getLhs(); ValueRange rhsParts = adaptor.getRhs(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); @@ -4275,8 +4974,8 @@ struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { auto vregType = dyn_cast(resultType); if (!vregType || lhs.getType() != resultType || rhs.getType() != resultType) - return rewriter.notifyMatchFailure(op, - "physical binary part type mismatch"); + return rewriter.notifyMatchFailure( + op, "physical binary part type mismatch"); FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); if (failed(mask)) @@ -4322,8 +5021,8 @@ struct OneToNVMIFmaOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure(op, "unsupported element type for fma"); 
results.push_back( - rewriter.create(op.getLoc(), resultType, acc, lhs, rhs, - *mask) + rewriter + .create(op.getLoc(), resultType, acc, lhs, rhs, *mask) .getResult()); } @@ -4336,10 +5035,10 @@ template struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) @@ -4347,7 +5046,8 @@ struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { SmallVector results; results.reserve(resultTypes.size()); - for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { auto vregType = dyn_cast(resultType); if (!vregType || source.getType() != resultType) return rewriter.notifyMatchFailure(op, @@ -4371,10 +5071,10 @@ template struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange lhsParts = adaptor.getLhs(); ValueRange rhsParts = adaptor.getRhs(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); @@ -4398,8 +5098,8 @@ struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { return 
rewriter.notifyMatchFailure( op, "unsupported mask type for all-true mask binary seed"); results.push_back( - rewriter.create(op.getLoc(), resultType, lhs, rhs, - *seedMask) + rewriter + .create(op.getLoc(), resultType, lhs, rhs, *seedMask) .getResult()); } @@ -4412,10 +5112,10 @@ template struct OneToNVMIMaskUnaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) @@ -4449,10 +5149,10 @@ template struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { std::optional cmpMode = getVPTOCmpMode(op.getPredicate()); if (!cmpMode) return op.emitOpError() @@ -4484,11 +5184,11 @@ struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { if (failed(seedMask)) return rewriter.notifyMatchFailure( op, "unsupported mask type for all-true cmp seed"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, lhs, rhs, *seedMask, - rewriter.getStringAttr(*cmpMode)) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, lhs, rhs, + *seedMask, + rewriter.getStringAttr(*cmpMode)) + .getResult()); } rewriter.replaceOp(op, results, 
adaptor.getResultMapping()); @@ -4519,11 +5219,10 @@ struct OneToNVMISelectOpPattern : OneToNOpConversionPattern { falseValue.getType() != resultType || !isa(resultType)) return rewriter.notifyMatchFailure( op, "physical select part type mismatch"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, trueValue, falseValue, - mask) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, trueValue, + falseValue, mask) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -4562,18 +5261,17 @@ struct OneToNVMIActivePrefixIndexOpPattern return rewriter.notifyMatchFailure( op, "unsupported element type for active_prefix_index seed mask"); - Value zero = rewriter.create( - op.getLoc(), 0, intType.getWidth()); + Value zero = rewriter.create(op.getLoc(), 0, + intType.getWidth()); Value carrier = rewriter .create(op.getLoc(), resultType, zero, *seedMask, /*position=*/nullptr) .getResult(); - Value result = - rewriter - .create(op.getLoc(), resultType, carrier, - maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, carrier, + maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -4600,11 +5298,10 @@ struct OneToNVMICompressOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "compress requires physical source/mask/result parts"); - Value result = - rewriter - .create(op.getLoc(), resultType, sourceParts.front(), - maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, + sourceParts.front(), maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -4619,14 +5316,12 @@ struct OneToNVMICompressStoreOpPattern LogicalResult matchAndRewrite(VMICompressStoreOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - FailureOr destination = - 
getSingleValue(op, adaptor.getDestination(), - "compress_store destination must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "compress_store offset must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "compress_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "compress_store offset must convert to one value", rewriter); if (failed(destination) || failed(offset)) return failure(); @@ -4648,14 +5343,12 @@ struct OneToNVMICompressStoreOpPattern .create(op.getLoc(), (*destination).getType(), *destination, *offset) .getResult(); - Value squeezed = - rewriter - .create(op.getLoc(), valueType, valueParts.front(), - maskParts.front()) - .getResult(); - auto align = - rewriter.create(op.getLoc(), - AlignType::get(rewriter.getContext())); + Value squeezed = rewriter + .create(op.getLoc(), valueType, + valueParts.front(), maskParts.front()) + .getResult(); + auto align = rewriter.create( + op.getLoc(), AlignType::get(rewriter.getContext())); auto store = rewriter.create( op.getLoc(), align.getResult().getType(), align.getResult(), squeezed, storeBase, rewriter.getStringAttr("POST_UPDATE")); @@ -4667,8 +5360,7 @@ struct OneToNVMICompressStoreOpPattern struct OneToNVMIReduceAddIOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIReduceAddIOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIReduceAddIOp op, OpAdaptor adaptor, @@ -4708,16 +5400,16 @@ struct OneToNVMIReduceAddIOpPattern op, "failed to create reduce_addi first-lane mask"); Value accumulator = initParts.front(); - for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { Value reduced = - rewriter.create(op.getLoc(), 
resultType, sourcePart, - maskPart) - .getResult(); - accumulator = rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) + .create(op.getLoc(), resultType, sourcePart, maskPart) .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -4728,8 +5420,7 @@ struct OneToNVMIReduceAddIOpPattern struct OneToNVMIReduceAddFOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIReduceAddFOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIReduceAddFOp op, OpAdaptor adaptor, @@ -4769,16 +5460,16 @@ struct OneToNVMIReduceAddFOpPattern op, "failed to create reduce_addf first-lane mask"); Value accumulator = initParts.front(); - for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { Value reduced = - rewriter.create(op.getLoc(), resultType, sourcePart, - maskPart) - .getResult(); - accumulator = rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) + .create(op.getLoc(), resultType, sourcePart, maskPart) .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -4808,8 +5499,7 @@ struct OneToNVMIGroupReduceAddFOpPattern op, "group_reduce_addf requires num_groups to evenly divide lane count"); if (succeeded(checkVcgaddGroupReduceShape( - sourceVMIType, maskVMIType, resultVMIType, - *groupSize, nullptr))) { + sourceVMIType, maskVMIType, resultVMIType, *groupSize, nullptr))) { if (sourceParts.size() != maskParts.size() || sourceParts.size() != resultTypes.size() || sourceParts.empty()) return rewriter.notifyMatchFailure( @@ -4832,10 +5522,63 @@ struct 
OneToNVMIGroupReduceAddFOpPattern SmallVector results; results.reserve(resultTypes.size()); for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, + maskParts[sourceIndex]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (succeeded(checkS16Block8GroupReduceShape(op, nullptr))) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 2 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + FailureOr combineMask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s16 block8 combine mask"); + Value loSource = sourceParts[resultIndex]; + Value hiSource = sourceParts[resultPartCount + resultIndex]; + Value loMask = maskParts[resultIndex]; + Value hiMask = maskParts[resultPartCount + resultIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + loSource.getType() != resultType || + hiSource.getType() != resultType || loMask.getType() != maskType || + hiMask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires uniform physical " + "types"); + Value lo = + 
rewriter.create(op.getLoc(), resultType, loSource, loMask) + .getResult(); + Value hi = + rewriter.create(op.getLoc(), resultType, hiSource, hiMask) + .getResult(); results.push_back( rewriter - .create(op.getLoc(), resultType, sourcePart, - maskParts[sourceIndex]) + .create(op.getLoc(), resultType, lo, hi, *combineMask) .getResult()); } @@ -4843,12 +5586,71 @@ struct OneToNVMIGroupReduceAddFOpPattern return success(); } + if (succeeded(checkS32Block8GroupReduceShape(op, nullptr))) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 4 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + FailureOr combineMask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s32 block8 combine mask"); + SmallVector partials; + partials.reserve(4); + for (int64_t part = 0; part < 4; ++part) { + int64_t sourceIndex = part * resultPartCount + resultIndex; + Value source = sourceParts[sourceIndex]; + Value mask = maskParts[sourceIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + source.getType() != resultType || mask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires uniform physical " + "types"); + 
partials.push_back( + rewriter.create(op.getLoc(), resultType, source, mask) + .getResult()); + } + Value sum01 = rewriter + .create(op.getLoc(), resultType, partials[0], + partials[1], *combineMask) + .getResult(); + Value sum23 = rewriter + .create(op.getLoc(), resultType, partials[2], + partials[3], *combineMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, sum01, + sum23, *combineMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; - if (failed(checkContiguousFullGroupChunks( - op, sourceVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) + if (failed(checkContiguousFullGroupChunks(op, sourceVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); if (sourceParts.size() != maskParts.size() || static_cast(sourceParts.size()) != @@ -4901,11 +5703,10 @@ struct OneToNVMIGroupReduceAddFOpPattern .create(op.getLoc(), resultType, sourceParts[index], maskParts[index]) .getResult(); - *accumulator = - rewriter - .create(op.getLoc(), resultType, reduced, - *accumulator, *firstLaneMask) - .getResult(); + *accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + *accumulator, *firstLaneMask) + .getResult(); } int64_t destChunk = group * chunksPerGroup; @@ -4960,8 +5761,8 @@ struct OneToNVMIGroupBroadcastOpPattern auto firstSourceType = dyn_cast(sourceParts.front().getType()); if (!firstSourceType) - return rewriter.notifyMatchFailure( - op, "group_broadcast source must be vreg"); + return rewriter.notifyMatchFailure(op, + "group_broadcast source must be vreg"); unsigned indexBits = pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); if (indexBits != 8 && indexBits != 16 && indexBits != 32) @@ -4971,20 +5772,18 @@ struct OneToNVMIGroupBroadcastOpPattern auto indexType = 
VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), indexElementType); - std::optional groupSlotIndex; FailureOr allMask = createAllTrueMaskForVReg(op.getLoc(), firstSourceType, rewriter); if (failed(allMask)) return rewriter.notifyMatchFailure( op, "failed to create group_broadcast all mask"); - if (*groupSize < lanesPerPart) { - FailureOr index = createGroupSlotIndexVector( - op.getLoc(), indexType, *groupSize, rewriter); - if (failed(index)) - return rewriter.notifyMatchFailure( - op, "failed to create group_broadcast group-slot index vector"); - groupSlotIndex = *index; - } + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t selectionGroupSize = *groupSize; + if (resultLayoutFactor != 1 && resultLayout && + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < lanesPerPart) + selectionGroupSize = resultLayout.getBlockElems(); SmallVector results; results.resize(resultTypes.size()); @@ -4994,44 +5793,102 @@ struct OneToNVMIGroupBroadcastOpPattern return rewriter.notifyMatchFailure( op, "group_broadcast requires uniform physical vreg types"); int64_t sourceChunk = flatIndex; + int64_t baseGroupSlot = 0; if (resultLayoutFactor == 1) { if (*groupSize >= lanesPerPart) { int64_t chunksPerGroup = *groupSize / lanesPerPart; int64_t group = flatIndex / chunksPerGroup; sourceChunk = group * chunksPerGroup; + } else { + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + int64_t groupsPerResultChunk = lanesPerPart / *groupSize; + int64_t firstGroup = flatIndex * groupsPerResultChunk; + 
sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; } } else { - int64_t runningFlatIndex = 0; - bool found = false; - for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { - FailureOr chunks = getDataChunksInPart(resultVMIType, part); - if (failed(chunks)) - return rewriter.notifyMatchFailure( - op, "group_broadcast failed to enumerate result chunks"); - for (int64_t chunk = 0; chunk < *chunks; ++chunk, ++runningFlatIndex) { - if (runningFlatIndex != static_cast(flatIndex)) - continue; - FailureOr firstLogical = - mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); - FailureOr lastLogical = mapPhysicalLaneToLogical( - resultVMIType, part, chunk, lanesPerPart - 1); - if (failed(firstLogical) || failed(lastLogical)) + bool blockFragmentSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + if (blockFragmentSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) return rewriter.notifyMatchFailure( - op, "group_broadcast failed to map result chunk lanes"); - int64_t firstGroup = *firstLogical / *groupSize; - int64_t lastGroup = *lastLogical / *groupSize; - if (firstGroup != lastGroup) + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t groupsPerResultChunk = + lanesPerPart / resultLayout.getBlockElems(); + int64_t firstGroup = chunk * groupsPerResultChunk; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, + "group_broadcast block-fragment source requires explicit " + "group_slots slots or 
derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) return rewriter.notifyMatchFailure( - op, "group_broadcast result chunk crosses logical groups"); - int64_t chunksPerGroup = *groupSize / lanesPerPart; - sourceChunk = firstGroup * chunksPerGroup; - found = true; - break; + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + sourceChunk = firstGroup * chunksPerGroup; + found = true; + break; + } } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); } - if (!found) - return rewriter.notifyMatchFailure( - op, "group_broadcast result chunk index is out of range"); } if (*groupSize >= lanesPerPart) { if (sourceChunk < 0 || @@ -5040,11 +5897,15 @@ struct OneToNVMIGroupBroadcastOpPattern op, 
"group_broadcast source chunk is out of range"); results[flatIndex] = rewriter - .create(op.getLoc(), resultType, sourceParts[sourceChunk], - *allMask, rewriter.getStringAttr("LOWEST")) + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) .getResult(); } else { - if (resultLayoutFactor != 1) + bool blockFragmentSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup) return rewriter.notifyMatchFailure( op, "group_broadcast small-group deinterleaved result is not " "supported"); @@ -5052,6 +5913,12 @@ struct OneToNVMIGroupBroadcastOpPattern sourceChunk >= static_cast(sourceParts.size())) return rewriter.notifyMatchFailure( op, "group_broadcast source chunk is out of range"); + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); results[flatIndex] = rewriter .create(op.getLoc(), resultType, @@ -5066,15 +5933,13 @@ struct OneToNVMIGroupBroadcastOpPattern }; template -struct OneToNVMIReduceMinMaxFOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite( + LogicalResult matchAndRewrite( SourceOp op, typename OneToNOpConversionPattern::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); ValueRange initParts = adaptor.getInit(); ValueRange maskParts = adaptor.getMask(); @@ -5112,15 +5977,14 @@ struct OneToNVMIReduceMinMaxFOpPattern Value accumulator = initParts.front(); for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { 
- Value reduced = - rewriter.create(op.getLoc(), resultType, sourcePart, - maskPart) - .getResult(); - accumulator = - rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) - .getResult(); + Value reduced = rewriter + .create(op.getLoc(), resultType, + sourcePart, maskPart) + .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -5156,9 +6020,8 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { for (Type resultType : resultTypes) { auto resultVRegType = dyn_cast(resultType); if (!resultVRegType || - (resultVRegTypes.empty() - ? !resultVRegType.getElementType().isF32() - : resultVRegType != resultVRegTypes.front())) + (resultVRegTypes.empty() ? !resultVRegType.getElementType().isF32() + : resultVRegType != resultVRegTypes.front())) return rewriter.notifyMatchFailure( op, "unsupported physical extf result type"); resultVRegTypes.push_back(resultVRegType); @@ -5185,8 +6048,7 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); if (failed(mask)) - return rewriter.notifyMatchFailure(op, - "failed to build extf seed mask"); + return rewriter.notifyMatchFailure(op, "failed to build extf seed mask"); SmallVector results; results.reserve(resultTypes.size()); @@ -5214,12 +6076,59 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { LogicalResult matchAndRewrite(VMITruncFOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if 
(sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceVMIType.getElementType().isF32() || + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr even = rewriter.getStringAttr("EVEN"); + FailureOr lane0Mask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", + rewriter); + if (failed(lane0Mask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot truncf lane0 mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || !sourceType.getElementType().isF32() || + !resultType || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf physical type"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *lane0Mask, rnd, sat, + even) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + if ((sourceParts.size() != 2 && sourceParts.size() != 4) || resultTypes.size() != 1) return rewriter.notifyMatchFailure( - op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is supported"); + op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is " + "supported"); auto sourceType0 = dyn_cast(sourceParts.front().getType()); auto resultType = dyn_cast(resultTypes.front()); @@ -5252,36 
+6161,33 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { FailureOr resultMask = createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); if (failed(sourceMask) || failed(resultMask)) - return rewriter.notifyMatchFailure(op, - "failed to build truncf masks"); + return rewriter.notifyMatchFailure(op, "failed to build truncf masks"); StringAttr rnd = rewriter.getStringAttr("R"); StringAttr sat = rewriter.getStringAttr("SAT"); SmallVector partials; partials.reserve(parts.size()); for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { - partials.push_back( - rewriter - .create(op.getLoc(), resultType, sourcePart, *sourceMask, - rnd, sat, rewriter.getStringAttr(part)) - .getResult()); + partials.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *sourceMask, rnd, sat, + rewriter.getStringAttr(part)) + .getResult()); } Value merged = partials.front(); for (Value partial : llvm::drop_begin(partials)) - merged = - rewriter - .create(op.getLoc(), resultType, merged, partial, - *resultMask) - .getResult(); + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); rewriter.replaceOp(op, merged, adaptor.getResultMapping()); return success(); } }; -struct OneToNVMIBitcastOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -5290,8 +6196,7 @@ struct OneToNVMIBitcastOpPattern ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) - return rewriter.notifyMatchFailure(op, - "physical bitcast arity mismatch"); + return rewriter.notifyMatchFailure(op, "physical bitcast arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); @@ -5312,8 +6217,7 @@ struct OneToNVMIBitcastOpPattern struct OneToNVMIChannelSplitOpPattern : 
OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIChannelSplitOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIChannelSplitOp op, OpAdaptor adaptor, @@ -5342,9 +6246,9 @@ struct OneToNVMIChannelSplitOpPattern } TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(); - FailureOr> results = materializeDataLayoutConversion( - op, adaptor.getSource(), resultTypes, sourceLayout, channelLayout, - rewriter); + FailureOr> results = + materializeDataLayoutConversion(op, adaptor.getSource(), resultTypes, + sourceLayout, channelLayout, rewriter); if (failed(results)) return failure(); @@ -5355,8 +6259,7 @@ struct OneToNVMIChannelSplitOpPattern struct OneToNVMIChannelMergeOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIChannelMergeOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIChannelMergeOp op, OpAdaptor adaptor, @@ -5417,8 +6320,8 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { results.push_back(sourceParts[sourceFlatIndex]); } - if (failed(verifyIdentityPartForwarding(op, results, resultTypes, - rewriter))) + if (failed( + verifyIdentityPartForwarding(op, results, resultTypes, rewriter))) return failure(); rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5448,11 +6351,11 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { if (failed(mask)) return rewriter.notifyMatchFailure( op, "failed to create shuffle lane0 splat mask"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, sourcePart, *mask, - rewriter.getStringAttr("LOWEST")) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *mask, + rewriter.getStringAttr("LOWEST")) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5463,12 +6366,11 @@ struct 
OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { FailureOr> vselrPlans = computeShuffleVselrPlans(op, &vselrReason); if (failed(vselrPlans)) - return rewriter.notifyMatchFailure( - op, Twine("shuffle vselr ") + vselrReason); + return rewriter.notifyMatchFailure(op, + Twine("shuffle vselr ") + vselrReason); if (vselrPlans->size() != resultTypes.size()) - return rewriter.notifyMatchFailure(op, - "shuffle vselr arity mismatch"); + return rewriter.notifyMatchFailure(op, "shuffle vselr arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); @@ -5496,8 +6398,8 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); Type indexType = - VRegType::get(rewriter.getContext(), - sourceVRegType.getElementCount(), indexElementType); + VRegType::get(rewriter.getContext(), sourceVRegType.getElementCount(), + indexElementType); FailureOr base = createScalarOffsetConstant( op.getLoc(), indexElementType, plan.baseLane, rewriter); if (failed(base)) @@ -5508,11 +6410,11 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { Value indexVector = rewriter.create(op.getLoc(), indexType, *base, orderAttr) .getResult(); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - sourceParts[plan.sourceFlatIndex], indexVector) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourceParts[plan.sourceFlatIndex], + indexVector) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5590,17 +6492,15 @@ struct OneToNCFCondBranchOpPattern const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); unsigned operandIndex = 1; for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) - llvm::append_range( - trueOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(trueOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); for 
(unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) - llvm::append_range( - falseOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(falseOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); - rewriter.replaceOpWithNewOp( - op, condition.front(), trueDest, trueOperands, falseDest, - falseOperands); + rewriter.replaceOpWithNewOp(op, condition.front(), + trueDest, trueOperands, + falseDest, falseOperands); return success(); } }; @@ -5613,9 +6513,8 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { OneToNPatternRewriter &rewriter) const override { auto *converter = getTypeConverter(); llvm::DenseMap convertedBlocks; - Block *defaultDest = - convertBranchDestBlock(op.getDefaultDestination(), rewriter, - *converter, convertedBlocks); + Block *defaultDest = convertBranchDestBlock( + op.getDefaultDestination(), rewriter, *converter, convertedBlocks); SmallVector caseDests; caseDests.reserve(op.getCaseDestinations().size()); @@ -5633,7 +6532,8 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { ValueRange flag = adaptor.getFlag(); if (flag.size() != 1) - return rewriter.notifyMatchFailure(op, "flag converted to multiple values"); + return rewriter.notifyMatchFailure(op, + "flag converted to multiple values"); SmallVector defaultOperands; SmallVector> caseOperandStorage; @@ -5643,18 +6543,16 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { unsigned operandIndex = 1; for (unsigned i = 0, e = op.getDefaultOperands().size(); i < e; ++i) - llvm::append_range( - defaultOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(defaultOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); caseOperandStorage.reserve(op.getCaseOperandSegments().size()); caseOperands.reserve(op.getCaseOperandSegments().size()); for (int32_t segmentSize : op.getCaseOperandSegments()) { SmallVector operands; for (int32_t 
i = 0; i < segmentSize; ++i) - llvm::append_range( - operands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(operands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); caseOperandStorage.push_back(std::move(operands)); } for (SmallVector &operands : caseOperandStorage) @@ -5694,7 +6592,8 @@ struct OneToNSCFExecuteRegionOpPattern struct OneToNSCFIndexSwitchOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + using OneToNOpConversionPattern< + scf::IndexSwitchOp>::OneToNOpConversionPattern; LogicalResult matchAndRewrite(scf::IndexSwitchOp op, OpAdaptor adaptor, @@ -5712,11 +6611,9 @@ struct OneToNSCFIndexSwitchOpPattern return failure(); auto newOp = rewriter.create( - op.getLoc(), resultTypes, arg.front(), op.getCases(), - op.getNumCases()); + op.getLoc(), resultTypes, arg.front(), op.getCases(), op.getNumCases()); newOp->setAttrs(op->getAttrs()); - rewriter.inlineRegionBefore(op.getDefaultRegion(), - newOp.getDefaultRegion(), + rewriter.inlineRegionBefore(op.getDefaultRegion(), newOp.getDefaultRegion(), newOp.getDefaultRegion().end()); for (auto [srcRegion, dstRegion] : llvm::zip(op.getCaseRegions(), newOp.getCaseRegions())) @@ -5731,80 +6628,59 @@ void populateVMIOneToNConversionPatterns( const VMITargetCapabilityRegistry &capabilities) { populateFuncTypeConversionPatterns(typeConverter, patterns); scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); - patterns - .add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, - patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); patterns.add( typeConverter, patterns.getContext()); - patterns.add, - OneToNVMIMaskBinaryOpPattern, - OneToNVMIMaskBinaryOpPattern, - OneToNVMIMaskUnaryOpPattern, - OneToNVMILoadOpPattern, - OneToNVMIGroupLoadOpPattern, - OneToNVMIMaskedLoadOpPattern, - 
OneToNVMIGatherOpPattern, - OneToNVMIExpandLoadOpPattern, - OneToNVMIStoreOpPattern, - OneToNVMIGroupStoreOpPattern, - OneToNVMIMaskedStoreOpPattern, - OneToNVMIScatterOpPattern, - OneToNVMITileReadOpPattern, - OneToNVMITileWriteOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIFmaOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMICmpOpPattern, - OneToNVMICmpOpPattern, - OneToNVMISelectOpPattern, - OneToNVMIActivePrefixIndexOpPattern, - OneToNVMICompressOpPattern, - OneToNVMICompressStoreOpPattern, - OneToNVMIReduceAddIOpPattern, - OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddFOpPattern, - OneToNVMIGroupBroadcastOpPattern, - OneToNVMIReduceMinMaxFOpPattern, - OneToNVMIReduceMinMaxFOpPattern, - OneToNVMIExtFOpPattern, - OneToNVMITruncFOpPattern, - OneToNVMIBitcastOpPattern, - OneToNVMIChannelSplitOpPattern, - OneToNVMIChannelMergeOpPattern, - OneToNVMIShuffleOpPattern>(typeConverter, - patterns.getContext()); + patterns.add< + OneToNVMIEnsureLayoutOpPattern, OneToNVMIEnsureMaskLayoutOpPattern, + OneToNVMIBroadcastOpPattern, OneToNVMIIotaOpPattern, + OneToNVMIConstantOpPattern, OneToNVMIConstantMaskOpPattern, + OneToNVMICreateMaskOpPattern, OneToNVMICreateGroupMaskOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, + OneToNVMIGroupLoadOpPattern, OneToNVMIGroupSlotLoadOpPattern, + OneToNVMIMaskedLoadOpPattern, 
OneToNVMIGatherOpPattern, + OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, + OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, + OneToNVMIScatterOpPattern, OneToNVMITileReadOpPattern, + OneToNVMITileWriteOpPattern, OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, OneToNVMIFmaOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMICmpOpPattern, OneToNVMICmpOpPattern, + OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, + OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, + OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, + OneToNVMIGroupReduceAddFOpPattern, OneToNVMIGroupBroadcastOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, + OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, + OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( + typeConverter, patterns.getContext()); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -5812,9 +6688,8 @@ void populateVMIOneToNConversionPatterns( LogicalResult verifyNoResidualVMIIR(ModuleOp module) { WalkResult result = module.walk([&](Operation *op) { if (isa(op)) { - op->emitError() - << kVMIDiagResidualOpPrefix - << "unrealized conversion cast remains after vmi-to-vpto"; + op->emitError() << kVMIDiagResidualOpPrefix + << "unrealized conversion cast remains after vmi-to-vpto"; return WalkResult::interrupt(); } if (auto createMask = dyn_cast(op)) 
{ @@ -5837,9 +6712,8 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { } } if (isVMIOp(op) || hasVMIType(op)) { - op->emitError() - << kVMIDiagResidualOpPrefix - << "failed to convert all VMI ops/types to VPTO"; + op->emitError() << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; return WalkResult::interrupt(); } return WalkResult::advance(); @@ -5856,8 +6730,7 @@ LogicalResult checkSupportedExtFShape(VMIExtFOp op) { FailureOr resultArity = getVMIPhysicalArity(resultType); if (!sourceLayout || !resultLayout || failed(sourceArity) || failed(resultArity) || !sourceLayout.isContiguous() || - !resultLayout.isDeinterleaved() || - !resultType.getElementType().isF32()) + !resultLayout.isDeinterleaved() || !resultType.getElementType().isF32()) return failure(); unsigned sourceBits = @@ -5871,7 +6744,14 @@ LogicalResult checkSupportedExtFShape(VMIExtFOp op) { return failure(); } -LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { +LogicalResult checkSupportedTruncFShape(VMITruncFOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); @@ -5879,18 +6759,45 @@ LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity) || !sourceLayout.isDeinterleaved() || - !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || - *resultArity != 1) - return failure(); + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); unsigned resultBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if 
(sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceType.getElementType().isF32() || resultBits != 16 || + *sourceArity != *resultArity) + return fail("group-slot truncf requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "f32 source, f16 result, and matching physical arity"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = "group_slot_cast_slots1_f32_to_f16"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + return success(); + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + !sourceType.getElementType().isF32() || *resultArity != 1) + return fail("requires f32 deinterleaved source and contiguous result"); + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) return success(); if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) return success(); - return failure(); + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); } FailureOr> @@ -5900,8 +6807,7 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return failure(); FailureOr factor = getDataLayoutFactor(type); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); if (failed(factor) || failed(lanesPerPart)) return failure(); @@ -5925,8 +6831,7 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return bits; } -LogicalResult 
checkSupportedBitcastShape(VMIBitcastOp op, - std::string *reason) { +LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -5967,9 +6872,10 @@ LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, return success(); } -LogicalResult checkSupportedChannelSplitShape( - const VMITargetCapabilityRegistry &capabilities, VMIChannelSplitOp op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedChannelSplitShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelSplitOp op, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -6024,9 +6930,10 @@ LogicalResult checkSupportedChannelSplitShape( return success(); } -LogicalResult checkSupportedChannelMergeShape( - const VMITargetCapabilityRegistry &capabilities, VMIChannelMergeOp op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedChannelMergeShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelMergeOp op, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -6176,9 +7083,9 @@ LogicalResult checkSupportedCompressStoreShape( return fail("requires contiguous value and mask layouts"); VMICapabilityResult destinationCapability = - capabilities.supportsUBPointerMemory( - op.getDestination().getType(), "destination", "pto.vstur", - "pto.vstur stores only to UB"); + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vstur", + "pto.vstur stores only to UB"); if (!destinationCapability.isSupported()) return fail(destinationCapability.reason); @@ -6275,6 +7182,15 @@ LogicalResult checkSupportedGroupReduceAddFShape( VMILayoutAttr maskLayout = maskType.getLayoutAttr(); if (!sourceLayout || !resultLayout || !maskLayout) return fail("requires assigned source, mask, 
and result layouts"); + + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + if (succeeded(checkS16Block8GroupReduceShape(op, reason))) + return success(); + if (succeeded(checkS32Block8GroupReduceShape(op, reason))) + return success(); if (!sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || !maskLayout.isContiguous()) @@ -6296,15 +7212,43 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires computable source/result/mask physical arity"); if (*sourceArity != *resultArity || *sourceArity != *maskArity) return fail("requires source/result/mask physical arity to match"); - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), - reason); - if (failed(groupSize)) + if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, + *groupSize, nullptr))) { + if (resultLayout.getSlots() > 0) { + auto selectedPlan = + op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = "s8_reduce_contiguous"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + + expectedPlan + "'"); + } + return success(); + } + if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) return failure(); - if (succeeded(checkVcgaddGroupReduceShape( - sourceType, maskType, resultType, *groupSize, nullptr))) + if (resultLayout.getSlots() <= 0) return success(); - return checkSupportedGroupChunkShape(sourceType, *groupSize, reason); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + 
StringRef expectedPlan; + if (sourceLayout.isContiguous() && *groupSize == 64 && + resultLayout.getSlots() == 1) + expectedPlan = "s64_reduce_row_local"; + else + return fail("explicit group_slots group_reduce_addf chunk path has no " + "registered selected_plan for the assigned layouts"); + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + return success(); } LogicalResult checkSupportedGroupBroadcastShape( @@ -6335,6 +7279,27 @@ LogicalResult checkSupportedGroupBroadcastShape( if (resultLayout.isGroupSlots()) return fail("requires dense result layout"); + if (sourceLayout.getSlots() > 0) { + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + + StringRef expectedPlan; + if (sourceLayout.getSlots() == 8) + expectedPlan = "group_broadcast_slots8_vselr"; + else if (sourceLayout.getSlots() == 1) + expectedPlan = "group_broadcast_slots1_vselr"; + else + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source layout; expected '" + expectedPlan + + "'"); + } + std::string fullChunkReason; if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) return fail(Twine("requires full source physical chunks; ") + @@ -6350,9 +7315,8 @@ LogicalResult checkSupportedGroupBroadcastShape( if (failed(lanesPerPart) || failed(resultLanesPerPart) || *lanesPerPart != *resultLanesPerPart) return fail("requires matching physical lanes per part"); - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), - reason); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceType, op.getNumGroupsAttr().getInt(), 
reason); if (failed(groupSize)) return failure(); if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) @@ -6364,9 +7328,14 @@ LogicalResult checkSupportedGroupBroadcastShape( return fail("requires known result layout factor"); if (*resultFactor == 1) return success(); + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return success(); int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; - if (*groupSize < *lanesPerPart || - *groupSize % logicalSpanPerResultChunk != 0) + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) return fail("deinterleaved result requires every physical result chunk to " "stay within one logical group"); return success(); @@ -6382,9 +7351,8 @@ checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, }; auto lhsType = cast(op.getLhs().getType()); - VMICapabilityResult elementCapability = - capabilities.supportsElementType(lhsType.getElementType(), - VMIElementPurpose::VMula); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + lhsType.getElementType(), VMIElementPurpose::VMula); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -6408,9 +7376,8 @@ checkSupportedReluShape(const VMITargetCapabilityRegistry &capabilities, if (failed(checkSupportedMaskableVReg(capabilities, resultType, reason))) return failure(); - VMICapabilityResult elementCapability = - capabilities.supportsElementType(resultType.getElementType(), - VMIElementPurpose::VRelu); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + resultType.getElementType(), VMIElementPurpose::VRelu); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -6447,17 +7414,19 @@ void emitEnsureLayoutMaterializationError(VMIEnsureLayoutOp ensure, 
"packing plan"; } -LogicalResult verifySupportedVMIToVPTOOps( - ModuleOp module, const VMITargetCapabilityRegistry &capabilities, - bool enableStableGatherMaskedLoad) { - auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, - VMIVRegType type, Value source, - std::optional constantOffset) - -> WalkResult { +LogicalResult +verifySupportedVMIToVPTOOps(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities, + bool enableStableGatherMaskedLoad) { + auto emitMemoryUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, Value source, + std::optional constantOffset, + std::optional explicitFullReadElems = + std::nullopt) -> WalkResult { std::string reason; if (succeeded(checkSupportedLoadShape(capabilities, type, source, source.getType(), constantOffset, - &reason))) + explicitFullReadElems, &reason))) return WalkResult::advance(); op->emitError() @@ -6486,13 +7455,13 @@ LogicalResult verifySupportedVMIToVPTOOps( [&](Operation *op, StringRef opName, VMIVRegType type, VMIElementPurpose purpose, StringRef elementContract) -> WalkResult { std::string reason; - if (succeeded(checkSupportedTargetElementVReg( - capabilities, type, purpose, elementContract, &reason))) + if (succeeded(checkSupportedTargetElementVReg(capabilities, type, purpose, + elementContract, &reason))) return WalkResult::advance(); op->emitError() - << kVMIDiagUnsupportedPrefix << opName - << " direct lowering requires " << elementContract + << kVMIDiagUnsupportedPrefix << opName << " direct lowering requires " + << elementContract << " and physical vreg parts with b8/b16/b32 predicate masks (" << reason << ")"; return WalkResult::interrupt(); @@ -6532,10 +7501,15 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } - if (auto load = dyn_cast(op)) + if (auto load = dyn_cast(op)) { + std::optional explicitFullReadElems; + if (auto attr = load.getFullReadElemsAttr()) + explicitFullReadElems = attr.getInt(); return emitMemoryUnsupported( op, 
"pto.vmi.load", cast(load.getResult().getType()), - load.getSource(), getConstantIndexValue(load.getOffset())); + load.getSource(), getConstantIndexValue(load.getOffset()), + explicitFullReadElems); + } if (auto load = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) @@ -6548,6 +7522,20 @@ LogicalResult verifySupportedVMIToVPTOOps( << reason << ")"; return WalkResult::interrupt(); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupSlotLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_slot_load requires explicit group_slots result " + "layout matching num_groups, a supported UB pointer source, " + "and either slots=8 with constant unit source_group_stride or " + "slots=1 row-local lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { if (enableStableGatherMaskedLoad) { load.emitError() @@ -6557,8 +7545,7 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } std::string reason; - if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, - &reason))) + if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, &reason))) return WalkResult::advance(); load.emitError() << kVMIDiagUnsupportedPrefix @@ -6582,8 +7569,7 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto load = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedExpandLoadShape(capabilities, load, - &reason))) + if (succeeded(checkSupportedExpandLoadShape(capabilities, load, &reason))) return WalkResult::advance(); load.emitError() << kVMIDiagUnsupportedPrefix @@ -6611,8 +7597,8 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto store = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedGroupStoreShape(capabilities, store, - &reason))) + if (succeeded( + 
checkSupportedGroupStoreShape(capabilities, store, &reason))) return WalkResult::advance(); store.emitError() << kVMIDiagUnsupportedPrefix @@ -6640,8 +7626,7 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto scatter = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedScatterShape(capabilities, scatter, - &reason))) + if (succeeded(checkSupportedScatterShape(capabilities, scatter, &reason))) return WalkResult::advance(); scatter.emitError() << kVMIDiagUnsupportedPrefix @@ -6660,8 +7645,7 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto tileWrite = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedStoreShape( - capabilities, - cast(tileWrite.getValue().getType()), + capabilities, cast(tileWrite.getValue().getType()), tileWrite.getDestination(), tileWrite.getDestination().getType(), &reason))) return WalkResult::advance(); @@ -6728,88 +7712,62 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto addf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.addf", - cast(addf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.addf", cast(addf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto addi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.addi", - cast( - addi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.addi", cast(addi.getResult().getType())); if (auto subf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.subf", - cast(subf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.subf", cast(subf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto subi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.subi", - cast( - subi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.subi", cast(subi.getResult().getType())); if (auto 
mulf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.mulf", - cast(mulf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.mulf", cast(mulf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto muli = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.muli", - cast( - muli.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.muli", cast(muli.getResult().getType())); if (auto divf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.divf", - cast(divf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.divf", cast(divf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto minf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.minf", - cast(minf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.minf", cast(minf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto maxf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.maxf", - cast(maxf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.maxf", cast(maxf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto negf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.negf", - cast(negf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.negf", cast(negf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto absf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.absf", - cast(absf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.absf", cast(absf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element 
type"); if (auto absi = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.absi", - cast(absi.getResult().getType()), + op, "pto.vmi.absi", cast(absi.getResult().getType()), VMIElementPurpose::SignlessOrSignedI8I16I32, "signless/signed i8/i16/i32 element type"); if (auto sqrt = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.sqrt", - cast(sqrt.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.sqrt", cast(sqrt.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto exp = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.exp", - cast(exp.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.exp", cast(exp.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto ln = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.ln", - cast(ln.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.ln", cast(ln.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto relu = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReluShape(capabilities, relu, &reason))) @@ -6822,32 +7780,27 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } if (auto andi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.andi", - cast( - andi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.andi", cast(andi.getResult().getType())); if (auto ori = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.ori", - cast(ori.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.ori", cast(ori.getResult().getType())); if (auto xori = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.xori", - cast( - xori.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.xori", cast(xori.getResult().getType())); if 
(auto shli = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.shli", - cast( - shli.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.shli", cast(shli.getResult().getType())); if (auto shrui = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.shrui", - cast( - shrui.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.shrui", cast(shrui.getResult().getType())); if (auto notOp = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.not", - cast( - notOp.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.not", cast(notOp.getResult().getType())); if (auto select = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.select", - cast( - select.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.select", + cast(select.getResult().getType())); if (auto cmpf = dyn_cast(op)) { WalkResult target = emitTargetElementUnsupported( @@ -6874,8 +7827,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto activePrefix = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedActivePrefixIndexShape(activePrefix, - &reason))) + if (succeeded( + checkSupportedActivePrefixIndexShape(activePrefix, &reason))) return WalkResult::advance(); activePrefix.emitError() << kVMIDiagUnsupportedPrefix @@ -7013,14 +7966,18 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto truncf = dyn_cast(op)) { - if (succeeded(checkSupportedTruncFShape(truncf))) + std::string reason; + if (succeeded(checkSupportedTruncFShape(truncf, &reason))) return WalkResult::advance(); truncf.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " "to one contiguous f16 result chunk or f32 deinterleaved=4 " - "source parts to one contiguous fp8-like result chunk"; + "source parts to one contiguous fp8-like result chunk, or f32 " + "group_slots(num_groups=G, slots=1) to f16 " + "group_slots(num_groups=G, slots=1) with selected_plan (" 
+ << reason << ")"; return WalkResult::interrupt(); } @@ -7041,8 +7998,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto split = dyn_cast(op)) { int64_t channels = split.getNumResults(); std::string reason; - if (succeeded(checkSupportedChannelSplitShape(capabilities, split, - &reason))) + if (succeeded( + checkSupportedChannelSplitShape(capabilities, split, &reason))) return WalkResult::advance(); if (channels != 2 && channels != 4) @@ -7062,8 +8019,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto merge = dyn_cast(op)) { int64_t channels = merge.getInputs().size(); std::string reason; - if (succeeded(checkSupportedChannelMergeShape(capabilities, merge, - &reason))) + if (succeeded( + checkSupportedChannelMergeShape(capabilities, merge, &reason))) return WalkResult::advance(); if (channels != 2 && channels != 4) @@ -7119,8 +8076,7 @@ LogicalResult verifySupportedVMIToVPTOOps( return failure(result.wasInterrupted()); } -struct VMIToVPTOPass - : public mlir::pto::impl::VMIToVPTOBase { +struct VMIToVPTOPass : public mlir::pto::impl::VMIToVPTOBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMIToVPTOPass) void runOnOperation() override { @@ -7130,8 +8086,8 @@ struct VMIToVPTOPass return; } VMITargetCapabilityRegistry capabilities; - if (failed(verifySupportedVMIToVPTOOps( - module, capabilities, enableStableGatherMaskedLoad))) { + if (failed(verifySupportedVMIToVPTOOps(module, capabilities, + enableStableGatherMaskedLoad))) { signalPassFailure(); return; } @@ -7140,13 +8096,11 @@ struct VMIToVPTOPass VMIToVPTOTypeConverter typeConverter; RewritePatternSet patterns(context); - populateVMIOneToNConversionPatterns(typeConverter, patterns, - capabilities); + populateVMIOneToNConversionPatterns(typeConverter, patterns, capabilities); if (failed(applyPartialOneToNConversion(module, typeConverter, std::move(patterns)))) { - module.emitError() - << kVMIDiagResidualOpPrefix - << "failed to convert all VMI ops/types to VPTO"; + module.emitError() << 
kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; signalPassFailure(); return; } diff --git a/test/lit/vmi/vmi_create_group_mask_invalid.pto b/test/lit/vmi/vmi_create_group_mask_invalid.pto new file mode 100644 index 0000000000..0c3aec3d65 --- /dev/null +++ b/test/lit/vmi/vmi_create_group_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_create_group_mask_lane_count_invalid() { + %c12 = arith.constant 12 : index + // CHECK: pto.vmi.create_group_mask + // CHECK-SAME: requires result lane count to equal num_groups * group_size + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<127xpred> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto new file mode 100644 index 0000000000..dce36f1b5d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %scale: f32) { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY:.*]] = pto.vmi.addf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> 
!pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY_DENSE:.*]] = pto.vmi.ensure_layout %[[COPY]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[COPY_DENSE]] +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[PROD:.*]] = pto.vmi.mulf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto new file mode 100644 index 0000000000..49f2c5e2a8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + + func.func @caller(%base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c32 = arith.constant 32 : index + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return + } +} + +// ASSIGN-LABEL: func.func private @consume( +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-LABEL: func.func @caller( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> 
!pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: call @consume(%[[X]], %[[MASK]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> + +// LOWER-LABEL: func.func private @consume( +// LOWER-SAME: !pto.vreg<64xf32> +// LOWER-SAME: !pto.mask +// LOWER: pto.vdintlv +// LOWER: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-LABEL: func.func @caller( +// LOWER: call @consume( +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto new file mode 100644 index 0000000000..f4790b5432 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s16( + %base: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// LOWER: pto.pset_b32 "PAT_ALL" +// LOWER: pto.plt_b32 +// LOWER: pto.pnot +// LOWER: pto.pand +// LOWER: pto.por +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: PAT_M4 +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto new file mode 100644 index 0000000000..147245c484 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_dense_f32_to_f16_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %x16, %dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: 
-> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// LOWER: pto.vlds +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} +// LOWER: pto.vcvt {{.*}} {part = "ODD"} +// LOWER: pto.vintlv +// LOWER-COUNT-2: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X16]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// LOWER: pto.vldsx2 +// LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto new file mode 100644 index 0000000000..a93ae52c17 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %copy_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// 
ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// LOWER-COUNT-4: pto.vlds +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-COUNT-4: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto new file mode 100644 index 0000000000..af6623a995 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_dense_store_group_slots_invalid( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.store operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto new file mode 100644 index 0000000000..e43d2e5591 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f32_f8_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_dintlv4" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// 
LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER-COUNT-3: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto new file mode 100644 index 0000000000..0ce6b6b295 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f8_compute_f8( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %off: index) { + %x8 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %y32 = pto.vmi.mulf %x32, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %y8 = pto.vmi.truncf %y32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %y8, %dst[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// ASSIGN: %[[X8:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X8]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y32:.*]] = pto.vmi.mulf %[[X32]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[Y8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// LOWER: pto.vlds +// LOWER-COUNT-4: pto.vcvt {{.*}} {part = "P{{[0-3]}}"} +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER-COUNT-3: pto.vor +// LOWER: 
pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto new file mode 100644 index 0000000000..7df6946741 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_broadcast_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %dense_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b_for_mul + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %h = pto.vmi.truncf %b_for_cast + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_out[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN: pto.vmi.group_store %[[YSUM]] +// ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[H]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vmul +// LOWER: pto.vmul +// LOWER: pto.vcgadd +// LOWER: pto.vsts +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto new file mode 100644 index 0000000000..7c1e569bf3 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_broadcast_slots8( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<1024xf32> { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32> + return %out : !pto.vmi.vreg<1024xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_broadcast +// CHECK-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load.pto b/test/lit/vmi/vmi_layout_assignment_group_load.pto new file mode 100644 index 0000000000..2a90d02d08 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load( + %source: !pto.ptr, + %row_stride: index) -> !pto.vmi.vreg<512xf32> { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_load( +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_load +// CHECK-SAME: vmi.selected_plan = "group_load_contiguous_chunks" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto new file mode 100644 index 0000000000..c928df5320 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_block8_truncf_invalid( + %src: !pto.ptr, + %sum_dst: !pto.ptr, + %dense_dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.group_load %src[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + %h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto new file mode 100644 index 0000000000..113467b492 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_compact_stride12_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride12 = arith.constant 12 : index + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan + // CHECK-SAME: stable gather fallback is not implemented + %x = pto.vmi.group_load %base[%off], %stride12 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto new file mode 100644 index 0000000000..67215442e5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s16_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %x = pto.vmi.group_load %base[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// LOWER-COUNT-2: pto.vsldb +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..ed2ed892f9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride20 = arith.constant 20 : index + %x = pto.vmi.group_load %base[%off], %stride20 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto new file mode 100644 index 0000000000..c97a35855b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %ysum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] +// ASSIGN-SAME: -> 
!pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]], %[[MASK2]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[YSUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-4: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto new file mode 100644 index 0000000000..0f506a3a1f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..7cd5ffd85d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s32_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride34 = arith.constant 34 : index + %x = pto.vmi.group_load %base[%off], %stride34 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 32 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto new file mode 100644 index 0000000000..3bea54d83f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s12_invalid( + %source: !pto.vmi.vreg<96xf32>, + %mask: !pto.vmi.mask<96xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf lowers through pto.vcgadd + // CHECK-SAME: num_groups deriving a group size aligned to physical chunks + // CHECK-SAME: found padding lane in physical chunk + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> + -> !pto.vmi.vreg<96xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<96xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto new file mode 100644 index 0000000000..c4652169d4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// LOWER: %[[LO:.*]], %[[HI:.*]] = pto.vdintlv +// LOWER: %[[MLO:.*]], %[[MHI:.*]] = pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: %[[SLO:.*]] = pto.vcgadd %[[LO]], %[[MLO]] +// LOWER: %[[SHI:.*]] = pto.vcgadd %[[HI]], %[[MHI]] +// LOWER: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg4[%arg5], %[[STORE_MASK]] +// 
LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto new file mode 100644 index 0000000000..e9a3e7c9e9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %rows, %dst[%off] : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE:.*]] = pto.vmi.ensure_layout %arg0 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32_SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[B16]] +// ASSIGN-SAME: 
!pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// LOWER: pto.vcgadd +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto new file mode 100644 index 0000000000..9fb03c80b2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// 
LOWER-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C2]] +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C6]] +// LOWER: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto new file mode 100644 index 0000000000..1d61b4196e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 16, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 16} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// LOWER-COUNT-8: pto.vdintlv +// LOWER-COUNT-8: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-8: pto.vcgadd +// LOWER: %[[STORE_MASK0:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[%arg17], %[[STORE_MASK0]] +// LOWER: %[[STORE_MASK1:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[{{.*}}], %[[STORE_MASK1]] +// LOWER-NOT: 
pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto new file mode 100644 index 0000000000..b51dd875b5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_store( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], %[[STORE_MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto new file mode 100644 index 0000000000..0a7550d004 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( + %src: memref<256xf32>, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] {full_read_elems = 256} + : !pto.ptr -> !pto.vmi.vreg<192xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = 
"s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER-DAG: %[[C48:.*]] = arith.constant 48 : i32 +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.plt_b32 %[[C48]] : i32 -> !pto.mask, i32 +// LOWER: %[[SLOTS:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd {{.*}}, {{.*}}, %[[SLOTS]] +// LOWER: %[[STORE:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER: pto.vsts {{.*}}, {{.*}}, %[[STORE]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// ASSIGN: %[[PX:.*]] = pto.vmi.load +// ASSIGN-SAME: {full_read_elems = 256 : i64} +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto new file mode 100644 index 0000000000..c66ff0eb3c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid( + %source: !pto.vmi.vreg<192xf32>, + %mask: !pto.vmi.mask<192xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: requires source and result to have the same physical arity + // CHECK-SAME: partial/tail layout materialization requires an explicit packing plan + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto new file mode 100644 index 0000000000..2e4c9dd02f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s64( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: vmi.selected_plan = "s64_reduce_row_local" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto new file mode 100644 index 0000000000..6fffb7c636 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast 
%[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots1_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vdup {{.*}} {position = "LOWEST"} +// LOWER-COUNT-8: pto.vmul +// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto new file mode 100644 index 0000000000..ec8816fbeb --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_tail_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c8 = arith.constant 8 : index + %c384 = arith.constant 384 : index + %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<384xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> + -> !pto.vmi.vreg<384xf32> + pto.vmi.group_store %sum, %dst[%off], %c8 {num_groups = 6} + : !pto.vmi.vreg<384xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// LOWER-COUNT-6: pto.vlds +// LOWER-COUNT-6: pto.vcadd +// LOWER-COUNT-6: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto new file mode 100644 index 0000000000..bf38aee552 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_truncf( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c16 = arith.constant 16 : index + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> + pto.vmi.group_store %sum16, %dst[%off], %c16 {num_groups = 8} + : !pto.vmi.vreg<512xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_truncf( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] +// ASSIGN-SAME: vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM16]] +// ASSIGN-SAME: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.ptr, !pto.mask -> !pto.vreg<128xf16> +// LOWER: pto.pge_b16 "PAT_VL1" +// LOWER: pto.vsts {{.*}} : 
!pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto new file mode 100644 index 0000000000..c3e876be05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: vmi.selected_plan = "s8_reduce_contiguous" +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto new file mode 100644 index 0000000000..1329965530 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_slots8_store( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: vmi.selected_plan = "s8_reduce_contiguous" +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// LOWER: %[[SUM:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg2[%arg3], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto new file mode 100644 index 0000000000..9f4349d40e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<128xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots1( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func 
@vmi_layout_assignment_group_slot_load_slots1( +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8_store( +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.group_store %[[OUT]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto new file mode 100644 index 0000000000..a96b847256 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slot_load_dual_layout( + %rhs_base: !pto.ptr, + %source16: !pto.vmi.vreg<128xf32>, + %mask16: !pto.vmi.mask<128xpred>, + %source64: !pto.vmi.vreg<512xf32>, + %mask64: !pto.vmi.mask<512xpred>, + %out16: !pto.ptr, + %out64: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %rhs16 = pto.vmi.group_slot_load %rhs_base[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.group_reduce_addf %source16, %mask16 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %outv16 = pto.vmi.addf %sum16, %rhs16 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %outv16, %out16[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %rhs64 = pto.vmi.group_slot_load %rhs_base[%off], %c8 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum64 = pto.vmi.group_reduce_addf %source64, %mask64 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %outv64 = pto.vmi.addf %sum64, %rhs64 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %outv64, %out64[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, 
#pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RHS64:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM64]], %[[RHS64]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// LOWER: pto.vsldb +// LOWER: pto.vsts {{.*}}, %arg21[%arg23], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-COUNT-8: pto.vsldb +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto new file mode 100644 index 0000000000..e6e459c435 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..f8d7bc8af8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( + %src: !pto.ptr, %off: index) { + %c2 = arith.constant 2 : index + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto new file mode 100644 index 0000000000..d327a7b8bc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slots_cf_join( + %cond: i1, + %src: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + scf.yield %a : !pto.vmi.vreg<128xf32> + } else { + %b = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %b : !pto.vmi.vreg<128xf32> + } + %bias = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %out = pto.vmi.addf %sum, %bias + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slots_cf_join( +// CHECK: %[[IF:.*]] = scf.if +// CHECK-SAME: -> (!pto.vreg<64xf32>) +// CHECK: pto.vldsx2 +// CHECK: pto.vcgadd +// CHECK: pto.vcgadd +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: else +// CHECK: pto.vsldb +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: %[[BIAS:.*]] = pto.vsldb +// CHECK: pto.vadd %[[IF]], %[[BIAS]] +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto new file mode 100644 index 0000000000..d0ac525849 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_fanout( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %sum_dst: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, 
!pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %scaled_sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER: %[[FIRST_SUM:.*]] = pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: pto.vsts %[[FIRST_SUM]], %arg4[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER-COUNT-2: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg5[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto new file mode 100644 index 0000000000..e4b48121bc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_scf_for( + %init: !pto.ptr, + %base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %acc0 = pto.vmi.group_slot_load %init[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.group_store %acc, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// ASSIGN: %[[ACC0:.*]] = 
pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[ACC:.*]] = scf.for +// ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) +// ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[ACC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// LOWER: pto.vsldb +// LOWER: scf.for +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: scf.yield +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto new file mode 100644 index 0000000000..452ee085ac --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_store_slots1_unit_stride_invalid( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_store + // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group + // CHECK-SAME: requires constant positive row_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented + // CHECK: note: see current operation: "pto.vmi.group_store" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto new file mode 100644 index 0000000000..8a74de4097 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( + %src: !pto.ptr, + %out32: !pto.ptr, + %out16: !pto.ptr, + %off: index) { + %c96 = arith.constant 96 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %out32[%off], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %out16[%off], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[X]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[X_SPLIT]] +// ASSIGN-SAME: -> 
!pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[H]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.pge_b32 "PAT_ALL" +// LOWER: pto.pge_b32 "PAT_VL32" +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.plt_b16 +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto new file mode 100644 index 0000000000..62ef723511 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_select_store( + %src: !pto.ptr, + %rhs: !pto.ptr, + %dense: !pto.ptr, + %masked: !pto.ptr, + %off: index) { + %c48 = arith.constant 48 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %y = pto.vmi.load %rhs[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred> + %sum = pto.vmi.addf %x, %y + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + %passthrough = pto.vmi.select %mask, %sum, %x + : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %passthrough, %dense[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + pto.vmi.masked_store %sum, %masked[%off], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[X]], %[[Y]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[PASS:.*]] = pto.vmi.select %[[MASK]], %[[SUM]], %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[PASS]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.masked_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_mask_layout +// ASSIGN-NOT: pto.vmi.ensure_mask_granularity + +// 
LOWER-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.plt_b32 +// LOWER: pto.vadd +// LOWER: pto.vsel +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto new file mode 100644 index 0000000000..4004ff6fcc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_masked_load_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[ZERO:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.masked_load +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, 
#pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// LOWER-COUNT-4: pto.vsel +// LOWER-COUNT-4: pto.vsts +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto new file mode 100644 index 0000000000..bad43bb869 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_masked_load_group_tail_s32( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c25 = arith.constant 25 : index + %mask = pto.vmi.create_group_mask %c25 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED: pto.vmi.group_reduce_addf +// CHECK-SAME: s32 block8 lowering does not yet support partial create_group_mask active_elems_per_group during layout assignment +// CHECK-NOT: vmi.selected_plan = "s32_reduce_block8_stride" diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto new file mode 100644 index 0000000000..a2d4cab4d9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_non_load_s32_reduce( + %base: !pto.ptr, + %bias: f32, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %a = pto.vmi.load %base[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.addf %a, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// ASSIGN-SAME: %[[BASE:arg[0-9]+]]: !pto.ptr +// ASSIGN-SAME: %[[BIAS:arg[0-9]+]]: f32 +// ASSIGN: %[[A:.*]] = pto.vmi.load %[[BASE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[BIASV:.*]] = pto.vmi.broadcast %[[BIAS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.addf %[[A]], %[[BIASV]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = 
"s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// LOWER-COUNT-4: pto.vdup %arg1 +// LOWER-COUNT-4: pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: pto.vsts {{.*}}, %arg2[%arg3], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto new file mode 100644 index 0000000000..3005e53c0a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_packed_group_slots_truncf_invalid( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + %h = pto.vmi.truncf %sum + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto new file mode 100644 index 0000000000..01e8e55caf --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_f16_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %dense: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x16 = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %x32, %dense[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_widen_f16_store_reduce( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_parity" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X32_DENSE]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_f16_store_reduce( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vcgadd +// 
LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto new file mode 100644 index 0000000000..f354adb6e8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_group_slots_invalid( + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout requires slots to be positive and divide num_groups when specified diff --git a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto new file mode 100644 index 0000000000..102efd4f0e --- /dev/null +++ b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_load_full_read_elems_invalid(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] {full_read_elems = 0} + : !pto.ptr -> !pto.vmi.vreg<100xf32> + return + } +} + +// CHECK: 'pto.vmi.load' op requires full_read_elems to be positive diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto index bff24c6e07..3ba8eb29dc 100644 --- a/test/lit/vmi/vmi_op_verifier_basic.pto +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -16,6 +16,7 @@ module { %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index %f32 = arith.constant 1.000000e+00 : f32 %f16 = arith.constant 1.000000e+00 : f16 %active = arith.constant 64 : index @@ -40,6 +41,10 @@ module { %trunc = pto.vmi.truncf %ext : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %slot_loaded = pto.vmi.group_slot_load %ptr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %slot_loaded, %ptr[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr %tile_read = pto.vmi.tile_read %tile : memref<128xf32> -> !pto.vmi.vreg<128xf32> pto.vmi.tile_write %tile_read, %tile : !pto.vmi.vreg<128xf32>, memref<128xf32> @@ -94,6 +99,8 @@ module { // CHECK: pto.vmi.extf // CHECK: pto.vmi.truncf // CHECK: pto.vmi.load +// CHECK: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_store // CHECK: 
pto.vmi.store // CHECK: pto.vmi.tile_read // CHECK: pto.vmi.tile_write diff --git a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto new file mode 100644 index 0000000000..950215e5e4 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%value: f32, %dst: !pto.ptr, %off: index) { + pto.vecscope { + %x = pto.vmi.broadcast %value : f32 -> !pto.vmi.vreg<128xf32> + %r = func.call @callee(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %r, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + return + } + } +} + +// CHECK: cannot infer resultless pto.vecscope because VPTO vector-scope data cannot have external users +// CHECK-SAME: escaping value type is '!pto.vreg<64xf32>' diff --git 
a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto new file mode 100644 index 0000000000..3a96e94d67 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source + {num_groups = 128, vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + 
!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto new file mode 100644 index 0000000000..a03cdfd9df --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_broadcast requires full source chunks +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto new file mode 100644 index 0000000000..563f939f77 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_load_missing_plan_invalid( + %source: !pto.ptr, + %row_stride: index) { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_load requires contiguous full result chunks +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index 6a10e168dd..e757c583f6 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -15,7 +15,8 @@ module { %row_stride: index, %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index - %v = pto.vmi.group_load %src[%c0], %row_stride {num_groups = 2} + %v = pto.vmi.group_load %src[%c0], %row_stride + {num_groups = 2, vmi.selected_plan = "group_load_contiguous_chunks"} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto new file mode 100644 index 0000000000..ee12b742e8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc, vmi.selected_plan = "s64_reduce_row_local"} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK: pto.vcadd +// CHECK: pto.vadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: pto.vcadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto new file mode 100644 index 0000000000..96d975ab7d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64_missing_plan_invalid( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto new file mode 100644 index 
0000000000..305c488dd5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc, vmi.selected_plan = "s8_reduce_contiguous"} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto new file mode 100644 index 0000000000..b67cb34f2d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8_missing_plan_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto new file mode 100644 index 0000000000..5927f63069 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots1( + %src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots1_row_local"} + : !pto.ptr + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, 
!pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8( +// CHECK-DAG: %[[MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[MASK]] : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots1( +// CHECK-COUNT-8: pto.vsldb + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8_store( +// CHECK: %[[LOAD_MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg2 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[LOAD_MASK]] +// CHECK: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts %[[OUT]], %arg1[%arg2], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto new file mode 100644 index 0000000000..f442e2fbbe --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_missing_plan_invalid( + %src: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto new file mode 100644 index 0000000000..10d9a2d3fa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_nonunit_slots8_invalid( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout +// CHECK: slots=8 group_slot_load requires constant unit source_group_stride diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto new file mode 100644 index 0000000000..d24f504e67 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %source + {vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16"} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK-COUNT-8: pto.vcvt {{.*}}, %[[VL1]] {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto new file mode 100644 index 0000000000..f265dc0912 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) { + %narrow = pto.vmi.truncf %source + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.truncf supports only +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto new file mode 100644 index 0000000000..305b039d72 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots8_nonunit_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index, + %row_stride: index) { + pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK-SAME: pto.vmi.group_store +// CHECK-SAME: slots=8 group_store currently requires constant unit row_stride diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index 7d302805d6..a0cc8215cb 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -252,9 +252,8 @@ module { // CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> @@ -302,7 +301,6 @@ module { // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> 
!pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: scf.if // CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 // CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto index 04613c441d..5798114cc7 100644 --- a/test/lit/vmi/vmi_type_attr_parse.pto +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -11,17 +11,23 @@ module attributes { pto.vmi_contiguous = #pto.vmi.layout, pto.vmi_deinterleaved2 = #pto.vmi.layout, - pto.vmi_deinterleaved4 = #pto.vmi.layout + pto.vmi_deinterleaved4 = #pto.vmi.layout, + pto.vmi_deinterleaved4_block8 = + #pto.vmi.layout, + pto.vmi_group_slots8 = #pto.vmi.layout } { func.func @vmi_type_attr_parse( %surface: !pto.vmi.vreg<128xf32>, %contiguous: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %group_slots8: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %surface_mask: !pto.vmi.mask<128xpred>, %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, - %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %mask_b32_block8: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { return } } @@ -29,12 +35,17 @@ module attributes { // CHECK: pto.vmi_contiguous = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved2 = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4_block8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots8 = #pto.vmi.layout // CHECK-LABEL: func.func @vmi_type_attr_parse( // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> // 
CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py new file mode 100644 index 0000000000..7df1eedef3 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SCALE = np.float32(0.5) +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src + SCALE + golden_sum = np.sum(src * SCALE, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto new file mode 100644 index 0000000000..3881dfc10f --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_broadcast_dense_group_users_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 5.000000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8, reassoc} + : 
!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp new file mode 100644 index 0000000000..21e26d6cf5 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_broadcast_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_broadcast_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp new file mode 100644 index 0000000000..b43a794cdb --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_broadcast_dense_group_users_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py new file mode 100644 index 0000000000..837961af76 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v2", "v3"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py new file mode 100644 index 0000000000..6e5edd801a --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + copy_out = np.full(INPUT_ELEMS, SENTINEL, dtype=np.float32) + golden_sum = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + golden_sum[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + copy_out.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + src.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto new file mode 100644 index 0000000000..2d0dcd2c64 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dense_group_reduce_multi_consumer_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %copy_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + 
pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp new file mode 100644 index 0000000000..1249378267 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dense_group_reduce_multi_consumer_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *copy); + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream) { + vmi_dense_group_reduce_multi_consumer_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)copy); +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp new file mode 100644 index 0000000000..0482d8339d --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kCopyElems = kInputElems; + size_t srcBytes = kInputElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t copyBytes = kCopyElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *sumHost = nullptr; + float *sumDevice = nullptr; + float *copyHost = nullptr; + float *copyDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", copyBytes, copyHost, copyBytes); + 
ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dense_group_reduce_multi_consumer_kernel(srcDevice, sumDevice, + copyDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", copyHost, copyBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(copyDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(copyHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py new file mode 100644 index 0000000000..d00c9b8b26 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def check_f32(name: str, atol: float, rtol: float) -> bool:
    """Compare ``{name}.bin`` with ``golden_{name}.bin`` as float32 arrays.

    Both files are read from the current working directory.

    Args:
        name: File stem of the output to check (for example ``"v2"``).
        atol: Absolute tolerance passed to ``np.isclose``.
        rtol: Relative tolerance passed to ``np.isclose``.

    Returns:
        True when both arrays have the same length and every element is
        within tolerance; otherwise prints an ``[ERROR]`` line and returns
        False.
    """
    golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32)
    output = np.fromfile(f"{name}.bin", dtype=np.float32)
    # A size mismatch must be reported before any element-wise comparison:
    # np.isclose cannot broadcast arrays of different lengths.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed {name} shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        return False
    close = np.isclose(golden, output, atol=atol, rtol=rtol)
    if close.all():
        return True
    idx = int(np.argmax(~close))  # first element outside tolerance
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    return False


def check_u8(name: str) -> bool:
    """Compare ``{name}.bin`` with ``golden_{name}.bin`` byte for byte.

    Returns:
        True when both files hold identical bytes; otherwise prints an
        ``[ERROR]`` line (shape mismatch, or the first differing index with
        both byte values in hex) and returns False.
    """
    golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8)
    output = np.fromfile(f"{name}.bin", dtype=np.uint8)
    # Without this guard a truncated output would be indexed with idx=-1,
    # reporting the last byte or raising IndexError on an empty file.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed {name} shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        return False
    mismatch = np.flatnonzero(golden != output)
    if not mismatch.size:
        return True
    idx = int(mismatch[0])
    print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}")
    return False


def main() -> None:
    """Check ``v2`` (float32 sums) and ``v3`` (f8 bytes); exit 2 on failure."""
    if not check_f32("v2", 1e-4, 1e-4) or not check_u8("v3"):
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py
# new file mode 100644
# index 0000000000..9034fe8d42
# --- /dev/null
# +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py
# @@ -0,0 +1,55 @@
# +#!/usr/bin/env python3
# +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 32
VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32)
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
SENTINEL_F32 = np.float32(-777.0)
SENTINEL_U8 = np.uint8(0xA5)


def generate(output_dir: Path) -> None:
    """Write the input, sentinel-filled output and golden files for the case.

    Row 0 of the source cycles through every entry of ``VALUES``; each other
    row ``r`` is filled with ``VALUES[r % len(VALUES)]``. ``golden_v3.bin``
    holds the paired ``F8E4M3FN_BYTES`` entries and ``golden_v2.bin`` the
    float32 row sums.
    """
    # One table index per row, then stretch each chosen entry across the row.
    pick = np.arange(ROWS) % len(VALUES)
    src = np.repeat(VALUES[pick][:, np.newaxis], GROUP_SIZE, axis=1)
    golden_out8 = np.repeat(F8E4M3FN_BYTES[pick][:, np.newaxis], GROUP_SIZE, axis=1)
    # Row 0 is the only row that mixes all table entries.
    src[0, :] = np.tile(VALUES, GROUP_SIZE // len(VALUES))
    golden_out8[0, :] = np.tile(F8E4M3FN_BYTES, GROUP_SIZE // len(F8E4M3FN_BYTES))

    golden_sum = np.sum(src, axis=1, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = (
        ("v1.bin", src.reshape(-1)),
        ("v2.bin", np.full(ROWS, SENTINEL_F32, dtype=np.float32)),
        ("v3.bin", np.full(ROWS * GROUP_SIZE, SENTINEL_U8, dtype=np.uint8)),
        ("golden_v2.bin", golden_sum.astype(np.float32)),
        ("golden_v3.bin", golden_out8.reshape(-1)),
    )
    for file_name, data in payloads:
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()

# --- Remaining patch text from these lines (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto
# new file mode 100644
# index 0000000000..6f68510ede
# --- /dev/null
# +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto
# @@ -0,0 +1,62 @@
# +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
# +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# +// CANN Open Software License Agreement Version 2.0 (the "License").
# +// Please refer to the License for details. You may not use this file except in compliance with the License.
# +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# +// See LICENSE in the root of the software repository for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_f32_to_f8_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x32 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %ub_out8_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, 
i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp new file mode 100644 index 0000000000..eef7fac9d0 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_f32_to_f8_store_reduce_kernel(__gm__ float *src, __gm__ float *sum, + __gm__ uint8_t *out8); + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream) { + vmi_f32_to_f8_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ float 
*)src, (__gm__ float *)sum, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp new file mode 100644 index 0000000000..1e3e7e8a86 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kSrcElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kOut8Elems = kSrcElems; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t out8Bytes = kOut8Elems * sizeof(uint8_t); + float *srcHost = nullptr; + float *sumHost = nullptr; + uint8_t *out8Host = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + uint8_t *out8Device = nullptr; + int rc = 0; + bool 
aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_f32_to_f8_store_reduce_kernel(srcDevice, sumDevice, out8Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", out8Host, out8Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(out8Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(out8Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git 
a/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f8-compute-f8/compare.py b/test/vpto/cases/vmi/f8-compute-f8/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` with ``golden_v2.bin`` byte for byte.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` on success. On a size mismatch or the first
    differing byte, prints an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.uint8)
    output = np.fromfile("v2.bin", dtype=np.uint8)
    # Report a size mismatch explicitly; otherwise the element-wise compare
    # fails to broadcast or idx=-1 silently indexes the last byte.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed shape mismatch golden={golden.shape} output={output.shape}")
        sys.exit(2)
    mismatch = np.flatnonzero(golden != output)
    if mismatch.size:
        idx = int(mismatch[0])
        print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/f8-compute-f8/golden.py b/test/vpto/cases/vmi/f8-compute-f8/golden.py
# new file mode 100644
# index 0000000000..e150b09545
# --- /dev/null
# +++ b/test/vpto/cases/vmi/f8-compute-f8/golden.py
# @@ -0,0 +1,40 @@
# +#!/usr/bin/env python3
# +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# +# CANN Open Software License Agreement Version 2.0 (the "License").
# +# Please refer to the License for details. You may not use this file except in compliance with the License.
# +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# +# See LICENSE in the root of the software repository for the full text of the License.
import argparse
from pathlib import Path

import numpy as np

ELEMS = 256
F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8)
F8E4M3FN_TIMES2 = np.array([0x00, 0x40, 0xC0, 0x38, 0x48, 0xC8, 0x50, 0xD0], dtype=np.uint8)


def generate(output_dir: Path) -> None:
    """Write the source, sentinel-filled destination and golden byte files.

    ``v1.bin`` cycles ``F8E4M3FN_BYTES`` over ``ELEMS`` bytes, ``v2.bin`` is
    ``ELEMS`` bytes of 0xA5, and ``golden_v2.bin`` cycles ``F8E4M3FN_TIMES2``
    the same way.
    """
    # np.resize repeats a 1-D table cyclically until ELEMS entries are filled.
    payloads = {
        "v1.bin": np.resize(F8E4M3FN_BYTES, ELEMS),
        "v2.bin": np.full(ELEMS, 0xA5, dtype=np.uint8),
        "golden_v2.bin": np.resize(F8E4M3FN_TIMES2, ELEMS),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    for file_name, data in payloads.items():
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()

# --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/f8-compute-f8/kernel.pto b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto
# new file mode 100644
# index 0000000000..568cf5fbde
# --- /dev/null
# +++ b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto
# @@ -0,0 +1,55 @@
# +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
# +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# +// CANN Open Software License Agreement Version 2.0 (the "License").
# +// Please refer to the License for details. You may not use this file except in compliance with the License.
# +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# +// See LICENSE in the root of the software repository for the full text of the License.
// VMI test kernel: load 256 f8E4M3FN values, widen to f32, multiply by 2.0,
// narrow back to f8E4M3FN and store.
// NOTE(review): angle-bracket parameters appear to have been stripped from
// this text (bare `!pto.ptr`, `#pto.kernel_kind`, `#pto.pipe`) — TODO restore
// them from the repository copy of kernel.pto before using this file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_f8_compute_f8_kernel(%src_gm: !pto.ptr,
                                      %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %scale = arith.constant 2.000000e+00 : f32

    // Source and destination are each cast twice from the same integer
    // address (0 and 4096): the *_u8 values feed the MTE copies, the *_f8
    // values feed the vector load/store.
    %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Copy the input from %src_gm into %ub_src_u8 (presumably 256 bytes, one
    // burst — confirm operand meaning against the op definition).
    pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Set then wait on EVENT_ID0 between PIPE_MTE2 and PIPE_V before the
    // vector scope reads the copied data.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %x8 = pto.vmi.load %ub_src_f8[%c0]
        : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN>
      // f8 -> f32 so the multiply runs in f32.
      %x32 = pto.vmi.extf %x8
        : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32>
      %scale_vec = pto.vmi.broadcast %scale
        : f32 -> !pto.vmi.vreg<256xf32>
      %y32 = pto.vmi.mulf %x32, %scale_vec
        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
          -> !pto.vmi.vreg<256xf32>
      // f32 -> f8 for the stored result.
      %y8 = pto.vmi.truncf %y32
        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
      pto.vmi.store %y8, %ub_dst_f8[%c0]
        : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
    }

    // Set then wait on EVENT_ID0 between PIPE_V and PIPE_MTE3 before copying
    // the result out to %dst_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}

// --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
// diff --git a/test/vpto/cases/vmi/f8-compute-f8/launch.cpp b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp
// new file mode 100644
// index 0000000000..63b5269670
// --- /dev/null
// +++ b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp
// @@ -0,0 +1,40 @@
// +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Launch shim for the f8-compute-f8 VMI test case.

// Define __VEC_SCOPE__ as empty when the toolchain has not provided it.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// One-byte placeholder structs for the 8-bit and packed 4-bit float type
// names, declared only for AI-core builds with __NPU_ARCH__ == 2201.
#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
typedef struct { unsigned char v; } hifloat8_t;
typedef struct { unsigned char v; } float8_e4m3_t;
typedef struct { unsigned char v; } float8_e5m2_t;
typedef struct { unsigned char v; } float8_e8m0_t;
typedef struct { unsigned char v; } float4_e1m2x2_t;
typedef struct { unsigned char v; } float4_e2m1x2_t;
#endif
// NOTE(review): the operand of this #include was lost when the patch text was
// extracted (an angle-bracket header name was stripped) — TODO restore it from
// the repository copy of launch.cpp.
#include
// Declared only when not compiling for the AI core and TMRGSORT_HPP is not
// defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL runtime header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; it is declared here and defined elsewhere.
extern "C" __global__ [aicore] void
vmi_f8_compute_f8_kernel(__gm__ uint8_t *src, __gm__ uint8_t *dst);

// Launches the kernel on `stream` with a launch dimension of 1, casting both
// byte buffers to __gm__ pointers.
void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst,
                                    void *stream) {
  vmi_f8_compute_f8_kernel<<<1, nullptr, stream>>>(
      (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst);
}

// --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
// diff --git a/test/vpto/cases/vmi/f8-compute-f8/main.cpp b/test/vpto/cases/vmi/f8-compute-f8/main.cpp
// new file mode 100644
// index 0000000000..fffc2d6e65
// --- /dev/null
// +++ b/test/vpto/cases/vmi/f8-compute-f8/main.cpp
// @@ -0,0 +1,76 @@
// +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst, void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t bytes = kElems * sizeof(uint8_t); + uint8_t *srcHost = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", bytes, srcHost, bytes); + ReadFile("./v2.bin", bytes, dstHost, 
bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, bytes, srcHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, bytes, dstHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_f8_compute_f8_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, bytes, dstDevice, bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py new file mode 100644 index 0000000000..da96a2ff71 --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def _fail_shape(tag: str, golden: np.ndarray, output: np.ndarray) -> None:
    """Report an element-count mismatch for ``tag`` and exit with status 2."""
    print(f"[ERROR] compare failed {tag} shape mismatch golden={golden.shape} output={output.shape}")
    sys.exit(2)


def _fail_value(tag: str, idx: int, golden: np.ndarray, output: np.ndarray) -> None:
    """Report the first differing element of ``tag`` and exit with status 2."""
    print(
        f"[ERROR] compare failed {tag} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    sys.exit(2)


def main() -> None:
    """Check ``v2.bin`` (float32, tolerance 1e-4) and ``v3.bin`` (float16, exact).

    Files are read from the current working directory and compared with
    their ``golden_`` counterparts. Prints ``[INFO] compare passed`` on
    success; on the first failure prints an ``[ERROR]`` line and exits with
    status 2.
    """
    golden_sum = np.fromfile("golden_v2.bin", dtype=np.float32)
    output_sum = np.fromfile("v2.bin", dtype=np.float32)
    # Arrays of different length cannot be compared element-wise, so a size
    # mismatch is reported on its own instead of raising inside np.isclose.
    if golden_sum.shape != output_sum.shape:
        _fail_shape("v2", golden_sum, output_sum)
    close = np.isclose(golden_sum, output_sum, atol=1e-4, rtol=1e-4)
    if not close.all():
        _fail_value("v2", int(np.argmax(~close)), golden_sum, output_sum)

    golden_dense = np.fromfile("golden_v3.bin", dtype=np.float16)
    output_dense = np.fromfile("v3.bin", dtype=np.float16)
    if golden_dense.shape != output_dense.shape:
        _fail_shape("v3", golden_dense, output_dense)
    # Value-wise inequality is the same criterion np.array_equal applies, so
    # the reported index always points at an element that really differs.
    differs = golden_dense != output_dense
    if differs.any():
        _fail_value("v3", int(np.argmax(differs)), golden_dense, output_dense)

    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py
# new file mode 100644
# index 0000000000..a238aaf082
# --- /dev/null
# +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py
# @@ -0,0 +1,54 @@
# +#!/usr/bin/env python3
# +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# +# CANN Open Software License Agreement Version 2.0 (the "License").
# +# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ELEMS = ROWS * GROUP_SIZE
SEED = 29
SENTINEL = np.float16(-17.5)
SUM_SENTINEL = np.float32(-911.0)


def generate(output_dir: Path, seed: int) -> None:
    """Write random input, sentinel-filled outputs and golden files.

    For each of the ``ROWS`` groups of ``GROUP_SIZE`` inputs, the golden
    dense output is the float32 group sum cast to float16 and repeated
    across the group; the golden sum is the float32 sum of the group's
    values multiplied by that group sum.
    """
    rng = np.random.default_rng(seed)
    src = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32)
    groups = src.reshape(ROWS, GROUP_SIZE)  # one contiguous row per group

    golden_sum = np.empty(ROWS, dtype=np.float32)
    golden_dense = np.full(ELEMS, SENTINEL, dtype=np.float16)
    dense_rows = golden_dense.reshape(ROWS, GROUP_SIZE)  # view, writes through
    for row, values in enumerate(groups):
        total = np.sum(values, dtype=np.float32)
        golden_sum[row] = np.sum(values * total, dtype=np.float32)
        dense_rows[row, :] = total.astype(np.float16)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = (
        ("v1.bin", src),
        ("v2.bin", np.full(ROWS, SUM_SENTINEL, dtype=np.float32)),
        ("v3.bin", np.full(ELEMS, SENTINEL, dtype=np.float16)),
        ("golden_v2.bin", golden_sum),
        ("golden_v3.bin", golden_dense),
    )
    for file_name, data in payloads:
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` and ``--seed`` (default ``SEED``), then generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    opts = cli.parse_args()
    generate(opts.output_dir, opts.seed)


if __name__ == "__main__":
    main()

# --- Remaining patch text from this line (next file in the patch), kept verbatim: ---
# diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto
# new file mode 100644
# index 0000000000..3c14b7fc38
# --- /dev/null
# +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto
# @@ -0,0 +1,69 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// VMI test kernel: one group-reduced sum (%sum32) is group-broadcast twice —
// once into a multiply that is reduced again and group-stored, once into an
// f32 -> f16 truncation that is stored densely.
// NOTE(review): angle-bracket parameters appear to have been stripped from
// this text (bare `!pto.ptr`, `#pto.kernel_kind`, `#pto.pipe`) — TODO restore
// them from the repository copy of kernel.pto before using this file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_broadcast_multi_consumer_kernel(%src_gm: !pto.ptr,
                                                       %sum_gm: !pto.ptr,
                                                       %dense_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64

    // Three buffers cast from integer addresses 0, 4096 and 8192.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr

    // Copy the input from %src_gm into %ub_src (presumably 512 bytes =
    // 128 x f32, one burst — confirm operand meaning against the op
    // definition).
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Set then wait on EVENT_ID0 between PIPE_MTE2 and PIPE_V before the
    // vector scope reads the copied data.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // Sum of the input over 8 groups.
      %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
          -> !pto.vmi.vreg<128xf32>
      // Consumer 1: broadcast the sums, multiply with the input, reduce
      // again, and group-store the result to %ub_sum.
      %b_for_mul = pto.vmi.group_broadcast %sum32 {num_groups = 8}
        : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32>
      %y = pto.vmi.mulf %x, %b_for_mul
        : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
          -> !pto.vmi.vreg<128xf32>
      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
          -> !pto.vmi.vreg<128xf32>
      pto.vmi.group_store %ysum, %ub_sum[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<128xf32>, !pto.ptr
      // Consumer 2: broadcast the same sums again, truncate to f16, and
      // store to %ub_dense.
      %b_for_cast = pto.vmi.group_broadcast %sum32 {num_groups = 8}
        : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32>
      %h = pto.vmi.truncf %b_for_cast
        : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16>
      pto.vmi.store %h, %ub_dense[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr
    }

    // Set then wait on EVENT_ID0 between PIPE_V and PIPE_MTE3 before copying
    // both results out.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}

// --- Remaining patch text from these lines (next file in the patch), kept verbatim: ---
// diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp
// new file mode 100644
// index 0000000000..2a562a57e3
// --- /dev/null
// +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp
// @@ -0,0 +1,42 @@
// +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// +// CANN Open Software License Agreement Version 2.0 (the "License").
// +// Please refer to the License for details. You may not use this file except in compliance with the License.
// +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__  // defined empty when the toolchain has not already provided it
+#endif
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;  // one-byte placeholder structs for these type names, only under this arch guard
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {  // host-side definition, skipped when TMRGSORT_HPP is already defined
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // device entry point; name matches the func.func in this case's kernel.pto
+vmi_group_broadcast_multi_consumer_kernel(__gm__ float *src, __gm__ float *sum,
+                                          __gm__ half *dense);
+
+void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum,  // host wrapper called from main.cpp
+                                                     uint16_t *dense,  // raw 16-bit storage, cast to __gm__ half * below
+                                                     void *stream) {
+  vmi_group_broadcast_multi_consumer_kernel<<<1, nullptr, stream>>>(  // launch configuration <<<1, nullptr, stream>>>
+      (__gm__ float *)src, (__gm__ float *)sum, (__gm__ half *)dense);
+}
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp
new file mode 100644
index 0000000000..dc39a0c47d
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp
@@ -0,0 +1,92 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include
+#include
+#include
+
+using namespace PtoTestCommon;
+
+#define ACL_CHECK(expr) /* on failure: log, set rc = 1, jump to cleanup; needs `rc` and `cleanup:` in the caller */ \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum,  // defined in this case's launch.cpp
+                                                     uint16_t *dense,
+                                                     void *stream);
+
+int main() {  // load v1/v2/v3.bin, run the kernel once, write v2/v3.bin back; returns 0 or 1
+  constexpr size_t kElems = 128;
+  constexpr size_t kRows = 8;
+  size_t srcBytes = kElems * sizeof(float);
+  size_t sumBytes = kRows * sizeof(float);
+  size_t denseBytes = kElems * sizeof(uint16_t);  // 16-bit storage for the half-precision output
+  float *srcHost = nullptr;  // everything used after `cleanup:` is declared and initialised before the first ACL_CHECK
+  float *srcDevice = nullptr;
+  float *sumHost = nullptr;
+  float *sumDevice = nullptr;
+  uint16_t *denseHost = nullptr;
+  uint16_t *denseDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))  // optional device override; stays 0 when unset
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);  // NOTE(review): ReadFile results are not checked -- confirm whether it reports failure
+  ReadFile("./v2.bin", sumBytes, sumHost, sumBytes);  // output buffers are preloaded from file as well
+  ReadFile("./v3.bin", denseBytes, denseHost, denseBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_broadcast_multi_consumer_kernel(srcDevice, sumDevice,
+                                                  denseDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));  // wait for the stream before copying results back
+  ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", sumHost, sumBytes);  // results overwrite the files that were read above
+  WriteFile("./v3.bin", denseHost, denseBytes);
+
+cleanup:  // reached on success and from every ACL_CHECK failure
+  aclrtFree(srcDevice);  // NOTE(review): pointers may still be nullptr here -- confirm aclrtFree/aclrtFreeHost accept that
+  aclrtFree(sumDevice);
+  aclrtFree(denseDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(sumHost);
+  aclrtFreeHost(denseHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py
new file mode 100644
index 0000000000..28299087e5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py new file mode 100644 index 0000000000..5c25033808 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8  # number of groups (rows)
+GROUP_SIZE = 16  # meaningful elements at the start of each row
+ROW_STRIDE = 24  # distance in elements between row starts; the last 8 of each row are padding
+INPUT_ELEMS = ROWS * ROW_STRIDE  # total input elements including padding (192)
+SENTINEL = np.float32(-777.0)  # initial fill of the output buffer (v2.bin)
+
+
+def generate(output_dir: Path) -> None:  # write input, sentinel-filled output and golden as raw float32 .bin files
+    src = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32)  # padding keeps -9.0, which would skew a sum that read it
+    dst = np.full(ROWS, SENTINEL, dtype=np.float32)
+    golden = np.empty(ROWS, dtype=np.float32)
+
+    base_row = np.linspace(-0.5, 0.25, GROUP_SIZE, dtype=np.float32)
+    for row in range(ROWS):
+        begin = row * ROW_STRIDE  # row r starts at r * ROW_STRIDE
+        values = base_row + np.float32(row) * np.float32(0.125)  # base pattern shifted by 0.125 per row
+        src[begin : begin + GROUP_SIZE] = values
+        golden[row] = np.sum(values, dtype=np.float32)  # expected per-row sum, accumulated in float32
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    src.tofile(output_dir / "v1.bin")  # raw bytes, no header
+    dst.tofile(output_dir / "v2.bin")
+    golden.tofile(output_dir / "golden_v2.bin")
+
+
+def main() -> None:  # CLI entry point: --output-dir (default ".")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto
new file mode 100644
index 0000000000..f28676f8d5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto
@@ -0,0 +1,51 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_load_s16_stride_store_kernel(%src_gm: !pto.ptr,  // strided group load -> per-group add-reduction -> group store
+                                                    %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index  // equals ROW_STRIDE in this case's golden.py
+    %c128 = arith.constant 128 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64  // presumably bytes: 8 f32 results * 4 B -- TODO confirm operand unit
+    %c768_i64 = arith.constant 768 : i64  // presumably bytes: 8 rows * 24 f32 * 4 B -- TODO confirm operand unit
+    %c4096_i64 = arith.constant 4096 : i64
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr  // integer constants turned into pointers; presumably on-chip buffer offsets -- TODO confirm
+    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64  // bring the padded input from %src_gm into %ub_src
+      nburst(%c1_i64, %c768_i64, %c768_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]  // NOTE(review): looks like an MTE2 -> V event handshake -- confirm
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>  // 128-lane predicate built from %c128
+      %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8}  // 8 groups into 128 lanes; %c24 presumably the element stride between group starts -- TODO confirm
+        : !pto.ptr -> !pto.vmi.vreg<128xf32>
+      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}  // golden.py: one float32 sum per 16-element row
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
+        -> !pto.vmi.vreg<128xf32>
+      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}  // presumably %c1 is the per-group output stride -- TODO confirm
+        : !pto.vmi.vreg<128xf32>, !pto.ptr
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]  // NOTE(review): looks like a V -> MTE3 handshake before the copy out -- confirm
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64  // per-group sums back to %dst_gm
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
new file mode 100644
index 0000000000..ef8fa0d082
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp
@@ -0,0 +1,32 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__  // defined empty when the toolchain has not already provided it
+#endif
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {  // host-side definition, skipped when TMRGSORT_HPP is already defined
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // device entry point; name matches the func.func in this case's kernel.pto
+vmi_group_load_s16_stride_store_kernel(__gm__ float *src, __gm__ float *dst);
+
+void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst,  // host wrapper called from main.cpp
+                                                  void *stream) {
+  vmi_group_load_s16_stride_store_kernel<<<1, nullptr, stream>>>(  // launch configuration <<<1, nullptr, stream>>>
+      (__gm__ float *)src, (__gm__ float *)dst);
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp
new file mode 100644
index 0000000000..414e34200e
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp
@@ -0,0 +1,80 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include
+#include
+
+using namespace PtoTestCommon;
+
+#define ACL_CHECK(expr) /* on failure: log, set rc = 1, jump to cleanup; needs `rc` and `cleanup:` in the caller */ \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst,  // defined in this case's launch.cpp
+                                                  void *stream);
+
+int main() {  // load v1/v2.bin, run the kernel once, write v2.bin back; returns 0 or 1
+  constexpr size_t kRows = 8;
+  constexpr size_t kRowStride = 24;  // same value as ROW_STRIDE in this case's golden.py
+  constexpr size_t kInputElems = kRows * kRowStride;
+  constexpr size_t kOutputElems = kRows;  // one result per row
+  size_t srcBytes = kInputElems * sizeof(float);
+  size_t dstBytes = kOutputElems * sizeof(float);
+  float *srcHost = nullptr;  // everything used after `cleanup:` is declared and initialised before the first ACL_CHECK
+  float *srcDevice = nullptr;
+  float *dstHost = nullptr;
+  float *dstDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))  // optional device override; stays 0 when unset
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);  // NOTE(review): ReadFile results are not checked -- confirm whether it reports failure
+  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);  // the output buffer is preloaded from file as well
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_load_s16_stride_store_kernel(srcDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));  // wait for the stream before copying the result back
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, dstBytes);  // result overwrites the file that was read above
+
+cleanup:  // reached on success and from every ACL_CHECK failure
+  aclrtFree(srcDevice);  // NOTE(review): pointers may still be nullptr here -- confirm aclrtFree/aclrtFreeHost accept that
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py
new file mode 100644
index 0000000000..28299087e5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py new file mode 100644 index 0000000000..8cb473640d --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8  # number of groups (rows)
+GROUP_SIZE = 32  # meaningful elements at the start of each row
+ROW_STRIDE = 40  # distance in elements between row starts; the last 8 of each row are padding
+INPUT_ELEMS = ROWS * ROW_STRIDE  # total input elements including padding (320)
+SENTINEL = np.float32(-777.0)  # initial fill of the output buffer (v2.bin)
+
+
+def generate(output_dir: Path) -> None:  # write input, sentinel-filled output and golden as raw float32 .bin files
+    src = np.zeros(INPUT_ELEMS, dtype=np.float32)  # padding between rows stays 0.0
+    dst = np.full(ROWS, SENTINEL, dtype=np.float32)
+    golden = np.empty(ROWS, dtype=np.float32)
+
+    base_row = np.linspace(-1.0, 1.0, GROUP_SIZE, dtype=np.float32)
+    for row in range(ROWS):
+        begin = row * ROW_STRIDE  # row r starts at r * ROW_STRIDE
+        values = base_row + np.float32(row) * np.float32(0.125)  # base pattern shifted by 0.125 per row
+        src[begin : begin + GROUP_SIZE] = values
+        reduction = np.sum(values, dtype=np.float32)  # per-row sum, accumulated in float32
+        golden[row] = np.sum(values * reduction, dtype=np.float32)  # sum over the row of value * row sum
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    src.tofile(output_dir / "v1.bin")  # raw bytes, no header
+    dst.tofile(output_dir / "v2.bin")
+    golden.tofile(output_dir / "golden_v2.bin")
+
+
+def main() -> None:  # CLI entry point: --output-dir (default ".")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto
new file mode 100644
index 0000000000..cf2aea21d7
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto
@@ -0,0 +1,59 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_load_s32_stride_broadcast_reduce_kernel(%src_gm: !pto.ptr,  // strided group load -> reduce -> broadcast -> multiply -> reduce -> group store
+                                                               %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c40 = arith.constant 40 : index  // equals ROW_STRIDE in this case's golden.py
+    %c256 = arith.constant 256 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i64 = arith.constant 32 : i64  // presumably bytes: 8 f32 results * 4 B -- TODO confirm operand unit
+    %c1280_i64 = arith.constant 1280 : i64  // presumably bytes: 8 rows * 40 f32 * 4 B -- TODO confirm operand unit
+    %c4096_i64 = arith.constant 4096 : i64
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr  // integer constants turned into pointers; presumably on-chip buffer offsets -- TODO confirm
+    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64  // bring the padded input from %src_gm into %ub_src
+      nburst(%c1_i64, %c1280_i64, %c1280_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]  // NOTE(review): looks like an MTE2 -> V event handshake -- confirm
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>  // 256-lane predicate built from %c256
+      %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8}  // 8 groups into 256 lanes; %c40 presumably the element stride between group starts -- TODO confirm
+        : !pto.ptr -> !pto.vmi.vreg<256xf32>
+      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}  // golden.py: `reduction`, one float32 sum per 32-element row
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+        -> !pto.vmi.vreg<256xf32>
+      %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8}  // per-group result made available to every lane of its group
+        : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+      %scaled = pto.vmi.mulf %x, %broadcast  // golden.py: values * reduction
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc}  // golden.py: golden[row]
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+        -> !pto.vmi.vreg<256xf32>
+      pto.vmi.group_store %scaled_sum, %ub_dst[%c0], %c1 {num_groups = 8}  // presumably %c1 is the per-group output stride -- TODO confirm
+        : !pto.vmi.vreg<256xf32>, !pto.ptr
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]  // NOTE(review): looks like a V -> MTE3 handshake before the copy out -- confirm
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64  // per-group results back to %dst_gm
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp
new file mode 100644
index 0000000000..d9218a9389
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp
@@ -0,0 +1,34 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__  // defined empty when the toolchain has not already provided it
+#endif
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+struct MrgSortExecutedNumList {  // host-side definition, skipped when TMRGSORT_HPP is already defined
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+extern "C" __global__ [aicore] void  // device entry point; name matches the func.func in this case's kernel.pto
+vmi_group_load_s32_stride_broadcast_reduce_kernel(__gm__ float *src,
+                                                  __gm__ float *dst);
+
+void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src,  // host wrapper called from main.cpp
+                                                             float *dst,
+                                                             void *stream) {
+  vmi_group_load_s32_stride_broadcast_reduce_kernel<<<1, nullptr, stream>>>(  // launch configuration <<<1, nullptr, stream>>>
+      (__gm__ float *)src, (__gm__ float *)dst);
+}
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp
new file mode 100644
index 0000000000..b994c2192f
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp
@@ -0,0 +1,82 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+#include
+#include
+
+using namespace PtoTestCommon;
+
+#define ACL_CHECK(expr) /* on failure: log, set rc = 1, jump to cleanup; needs `rc` and `cleanup:` in the caller */ \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src,  // defined in this case's launch.cpp
+                                                             float *dst,
+                                                             void *stream);
+
+int main() {  // load v1/v2.bin, run the kernel once, write v2.bin back; returns 0 or 1
+  constexpr size_t kRows = 8;
+  constexpr size_t kRowStride = 40;  // same value as ROW_STRIDE in this case's golden.py
+  constexpr size_t kInputElems = kRows * kRowStride;
+  constexpr size_t kOutputElems = kRows;  // one result per row
+  size_t srcBytes = kInputElems * sizeof(float);
+  size_t dstBytes = kOutputElems * sizeof(float);
+  float *srcHost = nullptr;  // everything used after `cleanup:` is declared and initialised before the first ACL_CHECK
+  float *srcDevice = nullptr;
+  float *dstHost = nullptr;
+  float *dstDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))  // optional device override; stays 0 when unset
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);  // NOTE(review): ReadFile results are not checked -- confirm whether it reports failure
+  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);  // the output buffer is preloaded from file as well
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(srcDevice, dstDevice,
+                                                          stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));  // wait for the stream before copying the result back
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, dstBytes);  // result overwrites the file that was read above
+
+cleanup:  // reached on success and from every ACL_CHECK failure
+  aclrtFree(srcDevice);  // NOTE(review): pointers may still be nullptr here -- confirm aclrtFree/aclrtFreeHost accept that
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py
new file mode 100644
index 0000000000..28299087e5
--- /dev/null
+++ b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py b/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py new file mode 100644 index 0000000000..efe2d5f3b9 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 32
ROW_STRIDE = 40
INPUT_ELEMS = ROWS * ROW_STRIDE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ``ROWS`` rows spaced ``ROW_STRIDE`` floats apart; only
    the first ``GROUP_SIZE`` floats of each row carry data, the rest stay at
    the filler value -9.0. ``v2.bin`` is ``ROWS`` sentinel values and
    ``golden_v2.bin`` is the float32 sum of each row's data.
    """
    flat_input = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32)
    # 2-D view onto flat_input so rows can be addressed directly.
    padded_rows = flat_input.reshape(ROWS, ROW_STRIDE)
    row_sums = np.empty(ROWS, dtype=np.float32)

    ramp = np.linspace(-0.75, 0.5, GROUP_SIZE, dtype=np.float32)
    step = np.float32(0.0625)
    for index in range(ROWS):
        group = ramp + np.float32(index) * step
        padded_rows[index, :GROUP_SIZE] = group
        row_sums[index] = np.sum(group, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    flat_input.tofile(output_dir / "v1.bin")
    np.full(ROWS, SENTINEL, dtype=np.float32).tofile(output_dir / "v2.bin")
    row_sums.tofile(output_dir / "golden_v2.bin")


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel for the group-load-s32-stride-store case.
// NOTE(review): op semantics below are inferred from op names and from the
// golden.py of this case (ROWS = 8, GROUP_SIZE = 32, ROW_STRIDE = 40); confirm
// against the VMI dialect documentation. Several attribute/type parameters
// (e.g. on !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to be missing from
// the text as received.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_load_s32_stride_store_kernel(%src_gm: !pto.ptr,
                                                     %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // 40 matches ROW_STRIDE in golden.py; it is the second operand of group_load.
    %c40 = arith.constant 40 : index
    // 256 = 8 groups * 32 lanes, the lane count of the mask/vreg types below.
    %c256 = arith.constant 256 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 32 = 8 f32 results * 4 bytes.
    %c32_i64 = arith.constant 32 : i64
    // 1280 = 8 * 40 f32 * 4 bytes, the size of v1.bin written by golden.py.
    %c1280_i64 = arith.constant 1280 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Pointers built from integer addresses 0 and 4096.
    // Presumably on-chip (UB) buffers, going by the names — TODO confirm.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Presumably copies the 1280-byte input from %src_gm into %ub_src.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64
      nburst(%c1_i64, %c1280_i64, %c1280_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfer before the vector work below.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      // 8 groups read from %ub_src with %c40 as the extra operand
      // (presumably the per-group stride in elements).
      %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8}
        : !pto.ptr -> !pto.vmi.vreg<256xf32>
      // Per-group float add reduction; golden.py expects one sum per row.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
        -> !pto.vmi.vreg<256xf32>
      // Stores with %c1 as the extra operand (presumably one element per group).
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<256xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 transfer below.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Presumably copies the 32-byte result from %ub_dst to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Host-side launch stub for vmi_group_load_s32_stride_store_kernel.

#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name is missing from the #include below in the
// text as received; confirm it against the original file.
#include
// Fallback definition, compiled only when neither __CCE_AICORE__ nor
// TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; it is declared here and defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_load_s32_stride_store_kernel(__gm__ float *src, __gm__ float *dst);

// Launches the kernel with launch arguments <<<1, nullptr, stream>>>,
// passing src and dst cast to __gm__ float pointers.
void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst,
                                                  void *stream) {
  vmi_group_load_s32_stride_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 40; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s32_stride_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/compare.py b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py new file mode 100644 index 0000000000..dc3a89703c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import sys

import numpy as np


def check(output_name: str, golden_name: str) -> None:
    """Compare one float32 output file against its golden file.

    Returns normally when both files hold the same number of values and every
    value matches within ``atol=1e-4`` / ``rtol=1e-4``. Otherwise prints an
    ``[ERROR]`` line naming ``output_name`` and exits with status 2.
    """
    golden = np.fromfile(golden_name, dtype=np.float32)
    output = np.fromfile(output_name, dtype=np.float32)

    # Lengths must be checked first: element-wise comparison needs equal shapes.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed {output_name}: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    within_tol = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if within_tol.all():
        return

    mismatches = np.nonzero(~within_tol)[0]
    idx = -1 if mismatches.size == 0 else int(mismatches[0])
    print(
        f"[ERROR] compare failed {output_name} idx={idx} "
        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
        f"output={output[idx] if idx >= 0 else 'n/a'}"
    )
    sys.exit(2)


def main() -> None:
    """Check the three output files of this case against their golden files."""
    for output_name, golden_name in (
        ("v4.bin", "golden_v4.bin"),
        ("v5.bin", "golden_v5.bin"),
        ("v6.bin", "golden_v6.bin"),
    ):
        check(output_name, golden_name)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
SENTINEL = np.float32(-777.0)


def fill_matrix(cols: int, base_start: float, row_step: float) -> np.ndarray:
    """Return a ``(ROWS, cols)`` float32 matrix of shifted ramps.

    Row ``r`` is a ``cols``-point ramp from ``base_start`` to
    ``base_start + 1.0`` plus ``r * row_step``, all computed in float32.
    """
    ramp = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32)
    shifts = np.arange(ROWS, dtype=np.float32) * np.float32(row_step)
    # Broadcast (1, cols) + (ROWS, 1) to build every row at once.
    return ramp[np.newaxis, :] + shifts[:, np.newaxis]


def write_case(output_dir: Path, matrix: np.ndarray, src_name: str, dst_name: str, golden_name: str) -> None:
    """Write one sub-case: flattened input, sentinel output and row sums."""
    row_sums = np.sum(matrix, axis=1, dtype=np.float32).astype(np.float32)
    sentinel_out = np.full(ROWS, SENTINEL, dtype=np.float32)
    for array, file_name in (
        (matrix.reshape(-1), src_name),
        (sentinel_out, dst_name),
        (row_sums, golden_name),
    ):
        array.tofile(output_dir / file_name)


def generate(output_dir: Path) -> None:
    """Create ``output_dir`` and write the 8-, 16- and 32-column sub-cases."""
    output_dir.mkdir(parents=True, exist_ok=True)
    sub_cases = (
        (8, -0.5, 0.03125, "v1.bin", "v4.bin", "golden_v4.bin"),
        (16, -0.75, 0.046875, "v2.bin", "v5.bin", "golden_v5.bin"),
        (32, -0.875, 0.0625, "v3.bin", "v6.bin", "golden_v6.bin"),
    )
    for cols, base_start, row_step, src_name, dst_name, golden_name in sub_cases:
        write_case(output_dir, fill_matrix(cols, base_start, row_step), src_name, dst_name, golden_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Test kernel for the group-reduce-basic-store case: three independent
// sub-cases with 8 groups each over 64, 128 and 256 f32 lanes.
// NOTE(review): op semantics below are inferred from op names and from the
// golden.py of this case (row sums of 8x8, 8x16 and 8x32 matrices); confirm
// against the VMI dialect documentation. Several attribute/type parameters
// (e.g. on !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to be missing from
// the text as received.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_basic_store_kernel(%src8_gm: !pto.ptr,
                                                  %src16_gm: !pto.ptr,
                                                  %src32_gm: !pto.ptr,
                                                  %dst8_gm: !pto.ptr,
                                                  %dst16_gm: !pto.ptr,
                                                  %dst32_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Lane counts of the three sub-cases: 8 groups * 8 / 16 / 32 lanes.
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %c256 = arith.constant 256 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 32 = 8 f32 results * 4 bytes.
    %c32_i64 = arith.constant 32 : i64
    // 256 / 512 / 1024 = input sizes in bytes (64 / 128 / 256 f32 values);
    // 1024 is also the address of the second source buffer.
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    // Remaining values are used only as buffer addresses below.
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %c12288_i64 = arith.constant 12288 : i64

    // Pointers built from fixed integer addresses.
    // Presumably on-chip (UB) buffers, going by the names — TODO confirm.
    %ub_src8 = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_src16 = pto.castptr %c1024_i64 : i64 -> !pto.ptr
    %ub_src32 = pto.castptr %c2048_i64 : i64 -> !pto.ptr
    %ub_dst8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
    %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr

    // Presumably copies each input from its *_gm pointer into its ub_src buffer.
    pto.mte_gm_ub %src8_gm, %ub_src8, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfers before the vector work below.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Sub-case 1: 64 lanes in 8 groups.
      %mask8 = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred>
      %x8 = pto.vmi.load %ub_src8[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32>
      %sum8 = pto.vmi.group_reduce_addf %x8, %mask8 {num_groups = 8, reassoc}
        : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred>
        -> !pto.vmi.vreg<64xf32>
      pto.vmi.group_store %sum8, %ub_dst8[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<64xf32>, !pto.ptr

      // Sub-case 2: 128 lanes in 8 groups.
      %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
      %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
      %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<128xf32>
      pto.vmi.group_store %sum16, %ub_dst16[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<128xf32>, !pto.ptr

      // Sub-case 3: 256 lanes in 8 groups.
      %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32>
      %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
        -> !pto.vmi.vreg<256xf32>
      pto.vmi.group_store %sum32, %ub_dst32[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<256xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 transfers below.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Presumably copies each 32-byte result to its *_gm destination.
    pto.mte_ub_gm %ub_dst8, %dst8_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Host-side launch stub for vmi_group_reduce_basic_store_kernel.

#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name is missing from the #include below in the
// text as received; confirm it against the original file.
#include
// Fallback definition, compiled only when neither __CCE_AICORE__ nor
// TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; it is declared here and defined elsewhere.
// Takes three source and three destination pointers.
extern "C" __global__ [aicore] void
vmi_group_reduce_basic_store_kernel(__gm__ float *src8,
                                    __gm__ float *src16,
                                    __gm__ float *src32,
                                    __gm__ float *dst8,
                                    __gm__ float *dst16,
                                    __gm__ float *dst32);

// Launches the kernel with launch arguments <<<1, nullptr, stream>>>,
// passing all six buffers cast to __gm__ float pointers in the same order.
void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16,
                                               float *src32, float *dst8,
                                               float *dst16, float *dst32,
                                               void *stream) {
  vmi_group_reduce_basic_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src8, (__gm__ float *)src16, (__gm__ float *)src32,
      (__gm__ float *)dst8, (__gm__ float *)dst16, (__gm__ float *)dst32);
}
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16, + float *src32, float *dst8, + float *dst16, float *dst32, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kSrc8Elems = kRows * 8; + constexpr size_t kSrc16Elems = kRows * 16; + constexpr size_t kSrc32Elems = kRows * 32; + constexpr size_t kOutputElems = kRows; + size_t src8Bytes = kSrc8Elems * sizeof(float); + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src8Host = nullptr; + float *src16Host = nullptr; + float *src32Host = nullptr; + float *dst8Host = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src8Device = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *dst8Device = nullptr; + float *dst16Device = nullptr; + float *dst32Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream 
stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src8Host), src8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst8Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src8Device, src8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst8Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src8Bytes, src8Host, src8Bytes); + ReadFile("./v2.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v3.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v4.bin", dstBytes, dst8Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v6.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src8Device, src8Bytes, src8Host, src8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst8Device, dstBytes, dst8Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + 
ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_basic_store_kernel( + src8Device, src16Device, src32Device, dst8Device, dst16Device, + dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst8Host, dstBytes, dst8Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst8Host, dstBytes); + WriteFile("./v5.bin", dst16Host, dstBytes); + WriteFile("./v6.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src8Device); + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(dst8Device); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src8Host); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(dst8Host); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` in the current directory.

    Both files are read as flat float32 arrays. Prints
    ``[INFO] compare passed`` when the arrays have the same length and every
    element matches within ``atol=1e-4`` / ``rtol=1e-4``. Otherwise prints an
    ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Handle a length mismatch before any element-wise comparison: np.isclose
    # raises ValueError for 1-D arrays of different lengths, which would turn
    # a test failure into an unhandled traceback.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        # Report the first element that is out of tolerance.
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    ``v1.bin`` holds ``ROWS`` contiguous rows of ``GROUP_SIZE`` floats,
    ``v2.bin`` is ``ROWS`` sentinel values, and ``golden_v2.bin`` holds, per
    row, the float32 sum of ``row * sum(row)``.
    """
    flat_input = np.empty(INPUT_ELEMS, dtype=np.float32)
    # 2-D view onto flat_input; every element is overwritten in the loop.
    rows_view = flat_input.reshape(ROWS, GROUP_SIZE)
    expected = np.empty(ROWS, dtype=np.float32)

    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
    step = np.float32(0.125)
    for index in range(ROWS):
        group = ramp + np.float32(index) * step
        rows_view[index, :] = group
        group_total = np.sum(group, dtype=np.float32)
        expected[index] = np.sum(group * group_total, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    flat_input.tofile(output_dir / "v1.bin")
    np.full(ROWS, SENTINEL, dtype=np.float32).tofile(output_dir / "v2.bin")
    expected.tofile(output_dir / "golden_v2.bin")


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Test kernel for the group-reduce-s16-broadcast-reduce-store case.
// NOTE(review): op semantics below are inferred from op names and from the
// golden.py of this case (per row of 16 values: sum(values * sum(values)));
// confirm against the VMI dialect documentation. Several attribute/type
// parameters (e.g. on !pto.ptr, #pto.kernel_kind, #pto.pipe) appear to be
// missing from the text as received.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_broadcast_reduce_store_kernel(%src_gm: !pto.ptr,
                                                                 %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // 128 = 8 groups * 16 lanes, the lane count of the mask/vreg types below.
    %c128 = arith.constant 128 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 32 = 8 f32 results * 4 bytes.
    %c32_i64 = arith.constant 32 : i64
    // 512 = 128 f32 * 4 bytes, the size of v1.bin written by golden.py.
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Pointers built from integer addresses 0 and 4096.
    // Presumably on-chip (UB) buffers, going by the names — TODO confirm.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Presumably copies the 512-byte input from %src_gm into %ub_src.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfer before the vector work below.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // First per-group reduction (golden.py: `reduction`).
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<128xf32>
      // Presumably replicates each group's sum across that group's lanes.
      %b = pto.vmi.group_broadcast %sum {num_groups = 8}
        : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32>
      // Multiply input by the broadcast value (golden.py: `values * reduction`).
      %y = pto.vmi.mulf %x, %b
        : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
        -> !pto.vmi.vreg<128xf32>
      // Second per-group reduction; this is the value that is stored.
      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
        -> !pto.vmi.vreg<128xf32>
      // Stores with %c1 as the extra operand (presumably one element per group).
      pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<128xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 transfer below.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Presumably copies the 32-byte result from %ub_dst to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Define __VEC_SCOPE__ as an empty macro when the toolchain has not done so.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name is missing in this copy of the patch; the
// struct below uses uint16_t, so presumably a fixed-width integer header was
// included here - restore from the upstream file.
#include 
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_broadcast_reduce_store_kernel(__gm__ float *src,
                                                   __gm__ float *dst);

// Host-side wrapper: launches the kernel with the configuration
// <<<1, nullptr, stream>>> and forwards both pointers cast to __gm__ float *.
void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream) {
  vmi_group_reduce_s16_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two header names below are missing in this copy of the
// patch; the code uses std::fprintf and std::getenv/std::atoi, so presumably
// the C stdio and stdlib headers - restore from the upstream file.
#include 
#include 

using namespace PtoTestCommon;

// Evaluate an ACL call; on any result other than ACL_SUCCESS print the failing
// expression, error code and source location to stderr, set rc = 1 and jump to
// the cleanup label. Requires `rc` and `cleanup` in the enclosing function.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream);

// Host driver: reads ./v1.bin (input) and ./v2.bin (initial output), runs the
// kernel once, and writes the device output back to ./v2.bin.
// Returns 0 on success, 1 if any ACL_CHECK failed.
int main() {
  // 8 rows of 16 floats in, one float per row out.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 16;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared before the first ACL_CHECK, so `goto cleanup`
  // never jumps over an initialization.
  float *srcHost = nullptr;
  float *srcDevice = nullptr;
  float *dstHost = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override from the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): the ReadFile results are not checked here.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(srcDevice, dstDevice,
                                                           stream);
  // Wait for the stream before copying the result back.
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. The frees are
  // unconditional (pointers may still be nullptr); stream, device and runtime
  // teardown are guarded by what was actually set up.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` with ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (``atol=1e-4``, ``rtol=1e-4``).

    Raises:
        SystemExit: with code 2 when the shapes differ or any element is
            outside the tolerance.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shapes before comparing element-wise: np.isclose on arrays
    # that cannot be broadcast raises ValueError, which would hide the real
    # failure behind a traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed: shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first offending element.
        idx = int(np.flatnonzero(~close)[0])
        print(
            f"[ERROR] compare failed idx={idx} "
            f"golden={golden[idx]} output={output[idx]}"
        )
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ACTIVE = 12
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    Each of the ROWS rows holds ACTIVE ramp values followed by
    GROUP_SIZE - ACTIVE filler values. The golden value of a row is the sum
    over its active lanes of ``lane * (sum of the active lanes)``.

    Files written into ``output_dir`` (created if missing): ``v1.bin``
    (flattened input), ``v2.bin`` (ROWS sentinel values) and
    ``golden_v2.bin`` (expected output), all float32.
    """
    active_ramp = np.linspace(-0.5, 0.75, ACTIVE, dtype=np.float32)
    filler_ramp = np.linspace(21.0, 24.0, GROUP_SIZE - ACTIVE, dtype=np.float32)

    src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32)
    golden = np.empty(ROWS, dtype=np.float32)
    # Iterating a 2-D array yields row views, so writes land in ``src``.
    for row_index, lanes in enumerate(src):
        shift = np.float32(row_index)
        lanes[:ACTIVE] = active_ramp + shift * np.float32(0.046875)
        lanes[ACTIVE:] = filler_ramp + shift * np.float32(1.5)
        active = lanes[:ACTIVE]
        group_sum = np.sum(active, dtype=np.float32)
        golden[row_index] = np.sum(active * group_sum, dtype=np.float32)

    sentinel_dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    src.ravel().tofile(output_dir / "v1.bin")
    sentinel_dst.tofile(output_dir / "v2.bin")
    golden.tofile(output_dir / "golden_v2.bin")


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Test kernel: group load -> group mask -> grouped add-reduce -> group
// broadcast -> multiply -> grouped add-reduce -> group store, all with
// num_groups = 8 over a 128 x f32 VMI register.
//
// NOTE(review): in this copy of the patch the parameters that normally follow
// `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` are missing; restore them
// from the upstream file before using this text.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel(
      %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c16 = arith.constant 16 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Source and destination staging pointers are created from the integer
    // values 0 and 4096 (presumably UB byte offsets - TODO confirm).
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Copy in the input (size operands 512) and the current contents of the
    // destination (size operands 32); sizes are presumably bytes - TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
        nburst(%c1_i64, %c512_i64, %c512_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
        nburst(%c1_i64, %c32_i64, %c32_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair naming PIPE_MTE2 -> PIPE_V; presumably orders the copy-in
    // before the vector work below.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Group load with a trailing operand of 16 (presumably the row stride
      // in elements - TODO confirm).
      %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8}
          : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // Group mask built from the constant 12 with group_size = 16; the
      // golden script for this case treats 12 lanes per row as active.
      %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16}
          : index -> !pto.vmi.mask<128xpred>
      // First grouped reduction of %x under the mask.
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
          : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
          -> !pto.vmi.vreg<128xf32>
      // Broadcast the reduction result and multiply it with the input.
      %b = pto.vmi.group_broadcast %sum {num_groups = 8}
          : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32>
      %y = pto.vmi.mulf %x, %b
          : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
          -> !pto.vmi.vreg<128xf32>
      // Second grouped reduction, over the products, with the same mask.
      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
          : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
          -> !pto.vmi.vreg<128xf32>
      // Store with a trailing operand of 1 (presumably the per-group stride
      // in elements - TODO confirm).
      pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8}
          : !pto.vmi.vreg<128xf32>, !pto.ptr
    }

    // Flag pair naming PIPE_V -> PIPE_MTE3, then copy %ub_dst to %dst_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
        nburst(%c1_i64, %c32_i64, %c32_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Define __VEC_SCOPE__ as an empty macro when the toolchain has not done so.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name is missing in this copy of the patch; the
// struct below uses uint16_t, so presumably a fixed-width integer header was
// included here - restore from the upstream file.
#include 
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel(
    __gm__ float *src, __gm__ float *dst);

// Host-side wrapper: launches the kernel with the configuration
// <<<1, nullptr, stream>>> and forwards both pointers cast to __gm__ float *.
void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel(
    float *src, float *dst, void *stream) {
  vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel<<<1, nullptr,
                                                                  stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two header names below are missing in this copy of the
// patch; the code uses std::fprintf and std::getenv/std::atoi, so presumably
// the C stdio and stdlib headers - restore from the upstream file.
#include 
#include 

using namespace PtoTestCommon;

// Evaluate an ACL call; on any result other than ACL_SUCCESS print the failing
// expression, error code and source location to stderr, set rc = 1 and jump to
// the cleanup label. Requires `rc` and `cleanup` in the enclosing function.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel(
    float *src, float *dst, void *stream);

// Host driver: reads ./v1.bin (input) and ./v2.bin (initial output), runs the
// kernel once, and writes the device output back to ./v2.bin.
// Returns 0 on success, 1 if any ACL_CHECK failed.
int main() {
  // 8 rows of 16 floats in, one float per row out.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 16;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared before the first ACL_CHECK, so `goto cleanup`
  // never jumps over an initialization.
  float *srcHost = nullptr;
  float *dstHost = nullptr;
  float *srcDevice = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override from the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): the ReadFile results are not checked here.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel(
      srcDevice, dstDevice, stream);
  // Wait for the stream before copying the result back.
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. The frees are
  // unconditional (pointers may still be nullptr); stream, device and runtime
  // teardown are guarded by what was actually set up.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` with ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (``atol=1e-4``, ``rtol=1e-4``).

    Raises:
        SystemExit: with code 2 when the shapes differ or any element is
            outside the tolerance.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shapes before comparing element-wise: np.isclose on arrays
    # that cannot be broadcast raises ValueError, which would hide the real
    # failure behind a traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed: shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first offending element.
        idx = int(np.flatnonzero(~close)[0])
        print(
            f"[ERROR] compare failed idx={idx} "
            f"golden={golden[idx]} output={output[idx]}"
        )
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ACTIVE = 12
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    Each of the ROWS rows holds ACTIVE ramp values followed by
    GROUP_SIZE - ACTIVE filler values. The golden value of a row is the
    float32 sum of its first ACTIVE lanes.

    Files written into ``output_dir`` (created if missing): ``v1.bin``
    (flattened input), ``v2.bin`` (ROWS sentinel values) and
    ``golden_v2.bin`` (expected output), all float32.
    """
    active_ramp = np.linspace(-0.75, 0.375, ACTIVE, dtype=np.float32)
    filler_ramp = np.linspace(25.0, 28.0, GROUP_SIZE - ACTIVE, dtype=np.float32)

    # Column vector of row numbers so each row gets its own offset through
    # broadcasting; every operand stays float32.
    row_ids = np.arange(ROWS, dtype=np.float32)[:, np.newaxis]
    src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32)
    src[:, :ACTIVE] = active_ramp + row_ids * np.float32(0.0625)
    src[:, ACTIVE:] = filler_ramp + row_ids * np.float32(2.0)

    sentinel_dst = np.full(ROWS, SENTINEL, dtype=np.float32)
    golden = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    src.ravel().tofile(output_dir / "v1.bin")
    sentinel_dst.tofile(output_dir / "v2.bin")
    golden.tofile(output_dir / "golden_v2.bin")


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// See LICENSE in the root of the software repository for the full text of the License.

// Test kernel: group load -> group mask -> grouped add-reduce -> group store,
// all with num_groups = 8 over a 128 x f32 VMI register.
//
// NOTE(review): in this copy of the patch the parameters that normally follow
// `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` are missing; restore them
// from the upstream file before using this text.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s16_group_mask_tail_store_kernel(
      %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c16 = arith.constant 16 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Source and destination staging pointers are created from the integer
    // values 0 and 4096 (presumably UB byte offsets - TODO confirm).
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Copy in the input (size operands 512) and the current contents of the
    // destination (size operands 32); sizes are presumably bytes - TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64
        nburst(%c1_i64, %c512_i64, %c512_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
        nburst(%c1_i64, %c32_i64, %c32_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair naming PIPE_MTE2 -> PIPE_V; presumably orders the copy-in
    // before the vector work below.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Group load with a trailing operand of 16 (presumably the row stride
      // in elements - TODO confirm).
      %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8}
          : !pto.ptr -> !pto.vmi.vreg<128xf32>
      // Group mask built from the constant 12 with group_size = 16; the
      // golden script for this case sums the first 12 lanes of each row.
      %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16}
          : index -> !pto.vmi.mask<128xpred>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
          : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
          -> !pto.vmi.vreg<128xf32>
      // Store with a trailing operand of 1 (presumably the per-group stride
      // in elements - TODO confirm).
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8}
          : !pto.vmi.vreg<128xf32>, !pto.ptr
    }

    // Flag pair naming PIPE_V -> PIPE_MTE3, then copy %ub_dst to %dst_gm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
        nburst(%c1_i64, %c32_i64, %c32_i64)
        : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// Define __VEC_SCOPE__ as an empty macro when the toolchain has not done so.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name is missing in this copy of the patch; the
// struct below uses uint16_t, so presumably a fixed-width integer header was
// included here - restore from the upstream file.
#include 
// Local definition of MrgSortExecutedNumList, compiled only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s16_group_mask_tail_store_kernel(__gm__ float *src,
                                                  __gm__ float *dst);

// Host-side wrapper: launches the kernel with the configuration
// <<<1, nullptr, stream>>> and forwards both pointers cast to __gm__ float *.
void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src,
                                                             float *dst,
                                                             void *stream) {
  vmi_group_reduce_s16_group_mask_tail_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "acl/acl.h"
#include "test_common.h"
// NOTE(review): the two header names below are missing in this copy of the
// patch; the code uses std::fprintf and std::getenv/std::atoi, so presumably
// the C stdio and stdlib headers - restore from the upstream file.
#include 
#include 

using namespace PtoTestCommon;

// Evaluate an ACL call; on any result other than ACL_SUCCESS print the failing
// expression, error code and source location to stderr, set rc = 1 and jump to
// the cleanup label. Requires `rc` and `cleanup` in the enclosing function.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src,
                                                             float *dst,
                                                             void *stream);

// Host driver: reads ./v1.bin (input) and ./v2.bin (initial output), runs the
// kernel once, and writes the device output back to ./v2.bin.
// Returns 0 on success, 1 if any ACL_CHECK failed.
int main() {
  // 8 rows of 16 floats in, one float per row out.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 16;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // Every local is declared before the first ACL_CHECK, so `goto cleanup`
  // never jumps over an initialization.
  float *srcHost = nullptr;
  float *dstHost = nullptr;
  float *srcDevice = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override from the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): the ReadFile results are not checked here.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(srcDevice, dstDevice,
                                                          stream);
  // Wait for the stream before copying the result back.
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on every ACL_CHECK failure. The frees are
  // unconditional (pointers may still be nullptr); stream, device and runtime
  // teardown are guarded by what was actually set up.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` with ``golden_v2.bin`` as float32 arrays.

    Both files are read from the current working directory. Prints
    ``[INFO] compare passed`` when the arrays have the same shape and every
    element is close (``atol=1e-4``, ``rtol=1e-4``).

    Raises:
        SystemExit: with code 2 when the shapes differ or any element is
            outside the tolerance.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shapes before comparing element-wise: np.isclose on arrays
    # that cannot be broadcast raises ValueError, which would hide the real
    # failure behind a traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed: shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        # Report the first offending element.
        idx = int(np.flatnonzero(~close)[0])
        print(
            f"[ERROR] compare failed idx={idx} "
            f"golden={golden[idx]} output={output[idx]}"
        )
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
ACTIVE = 12
ROW_STRIDE = 24
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial output and golden files for this case.

    The input holds ROWS rows spaced ROW_STRIDE elements apart and is
    pre-filled with 99.0. Within each row the first ACTIVE lanes carry ramp
    values and lanes ACTIVE..GROUP_SIZE carry filler values; the rest keeps
    the 99.0 padding. The golden value of a row is the float32 sum of its
    first ACTIVE lanes.

    Files written into ``output_dir`` (created if missing): ``v1.bin``
    (input), ``v2.bin`` (ROWS sentinel values) and ``golden_v2.bin``
    (expected output), all float32.
    """
    active_ramp = np.linspace(-0.625, 0.5, ACTIVE, dtype=np.float32)
    filler_ramp = np.linspace(31.0, 35.0, GROUP_SIZE - ACTIVE, dtype=np.float32)

    src = np.full(ROWS * ROW_STRIDE, np.float32(99.0), dtype=np.float32)
    golden = np.empty(ROWS, dtype=np.float32)
    # A (ROWS, ROW_STRIDE) view of the flat buffer: writing a row view
    # updates ``src`` at offset row * ROW_STRIDE.
    strided_rows = src.reshape(ROWS, ROW_STRIDE)
    for row_index, lanes in enumerate(strided_rows):
        shift = np.float32(row_index)
        lanes[:ACTIVE] = active_ramp + shift * np.float32(0.03125)
        lanes[ACTIVE:GROUP_SIZE] = filler_ramp + shift
        golden[row_index] = np.sum(lanes[:ACTIVE], dtype=np.float32)

    sentinel_dst = np.full(ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    src.tofile(output_dir / "v1.bin")
    sentinel_dst.tofile(output_dir / "v2.bin")
    golden.tofile(output_dir / "golden_v2.bin")


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c24 = arith.constant 24 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c768_i64 = arith.constant 768 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64 + nburst(%c1_i64, %c768_i64, %c768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, 
%c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp new file mode 100644 index 0000000000..ef2e2aaef2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_stride_group_mask_tail_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + float *src, float *dst, void *stream) { + vmi_group_reduce_s16_stride_group_mask_tail_store_kernel<<<1, nullptr, + stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp new file mode 100644 index 0000000000..4a6af8cac7 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + float *src, float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 24; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, 
dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare ``v2.bin`` against ``golden_v2.bin`` as float16 arrays.
+
+    Both files are read from the current working directory. The check
+    passes only when the arrays have the same length and every element
+    compares equal (no tolerance). On any mismatch an ``[ERROR]`` line is
+    printed and the process exits with status 2.
+    """
+    golden = np.fromfile("golden_v2.bin", dtype=np.float16)
+    output = np.fromfile("v2.bin", dtype=np.float16)
+    # Check the shapes first: an element-wise comparison of arrays with
+    # different lengths cannot produce a usable mismatch index.
+    if golden.shape != output.shape:
+        print(
+            "[ERROR] compare failed idx=-1 golden=n/a output=n/a "
+            f"(shape mismatch: golden={golden.shape} output={output.shape})"
+        )
+        sys.exit(2)
+    # Use the same value comparison for the verdict and for the reported
+    # index, so a failing run always names the offending element.
+    mismatch = golden != output
+    if mismatch.any():
+        idx = int(np.nonzero(mismatch)[0][0])
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py
new file mode 100644
index 0000000000..2010556d20
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+# Data layout and defaults used by generate().
+ROWS = 8
+GROUP_SIZE = 16
+ELEMS = ROWS * GROUP_SIZE
+SEED = 29
+SENTINEL = np.float16(-17.5)
+
+
+def generate(output_dir: Path, seed: int) -> None:
+    """Write the input, initial output and expected result files.
+
+    ``v1.bin`` holds ``ELEMS`` float32 values drawn uniformly from
+    [-2.0, 2.0) with a generator seeded by ``seed``. ``v2.bin`` holds
+    ``ELEMS`` float16 copies of ``SENTINEL``. ``golden_v2.bin`` holds
+    ``ELEMS`` float16 values: every entry of a ``GROUP_SIZE``-wide row
+    is that row's float32 sum converted to float16. ``output_dir`` is
+    created if it does not exist.
+    """
+    src = np.random.default_rng(seed).uniform(-2.0, 2.0, size=ELEMS).astype(np.float32)
+    dst = np.full(ELEMS, SENTINEL, dtype=np.float16)
+    golden = np.full(ELEMS, SENTINEL, dtype=np.float16)
+
+    # Row views share memory with the flat arrays, so writing a row of
+    # golden_rows fills the matching span of golden.
+    golden_rows = golden.reshape(ROWS, GROUP_SIZE)
+    for row_values, row_out in zip(src.reshape(ROWS, GROUP_SIZE), golden_rows):
+        row_out[:] = np.sum(row_values, dtype=np.float32).astype(np.float16)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for name, data in (("v1.bin", src), ("v2.bin", dst), ("golden_v2.bin", golden)):
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` and ``--seed`` and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    cli.add_argument("--seed", type=int, default=SEED)
+    options = cli.parse_args()
+    generate(options.output_dir, options.seed)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto
new file mode 100644
index 0000000000..29193f5d6b
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto
@@ -0,0 +1,54 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_truncf_broadcast_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %rows, %ub_dst[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp 
b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp new file mode 100644 index 0000000000..21b6e43c3d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_truncf_broadcast_store_kernel(__gm__ float *src, + __gm__ half *dst); + +void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src, + uint16_t *dst, + void *stream) { + vmi_group_reduce_s16_truncf_broadcast_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git 
a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp new file mode 100644 index 0000000000..13fe482440 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + 
deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(srcDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.
+
+    Both files are read from the current working directory. The check
+    passes when the two arrays have the same length and every element
+    matches within ``atol=1e-4`` / ``rtol=1e-4``. On any mismatch an
+    ``[ERROR]`` line is printed and the process exits with status 2.
+    """
+    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
+    output = np.fromfile("v2.bin", dtype=np.float32)
+    # Report a size mismatch on its own: np.isclose raises on arrays that
+    # cannot be broadcast together, so it must not run in that case.
+    if golden.shape != output.shape:
+        print(
+            "[ERROR] compare failed idx=-1 golden=n/a output=n/a "
+            f"(shape mismatch: golden={golden.shape} output={output.shape})"
+        )
+        sys.exit(2)
+    # One tolerance comparison serves both the verdict and the index.
+    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
+    if not np.all(close):
+        idx = int(np.nonzero(~close)[0][0])
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py
new file mode 100644
index 0000000000..1614628a0b
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+# Data layout and constants used by generate().
+ROWS = 8
+GROUP_SIZE = 32
+INPUT_ELEMS = ROWS * GROUP_SIZE
+BIAS = np.float32(0.25)
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write the input, initial output and expected result files.
+
+    ``v1.bin`` holds ``ROWS`` rows of ``GROUP_SIZE`` float32 values, each
+    row being a fixed ramp shifted by ``row * 0.03125``. ``golden_v2.bin``
+    holds, per row, the float32 sum of the row after ``BIAS`` is added to
+    every element. ``v2.bin`` is ``ROWS`` copies of ``SENTINEL``.
+    ``output_dir`` is created if it does not exist.
+    """
+    row_step = np.float32(0.03125)
+    base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32)
+
+    src = np.empty(INPUT_ELEMS, dtype=np.float32)
+    golden = np.empty(ROWS, dtype=np.float32)
+    # Row view sharing memory with src; every row is assigned below.
+    src_rows = src.reshape(ROWS, GROUP_SIZE)
+    for row in range(ROWS):
+        values = base_row + np.float32(row) * row_step
+        src_rows[row] = values
+        golden[row] = (values + BIAS).sum(dtype=np.float32)
+    dst = np.full(ROWS, SENTINEL, dtype=np.float32)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for name, data in (("v1.bin", src), ("v2.bin", dst), ("golden_v2.bin", golden)):
+        data.tofile(output_dir / name)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` (default: current directory) and generate the files."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto
new file mode 100644
index 0000000000..d21fb5efd2
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto
@@ -0,0 +1,54 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_add_bias_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %bias = arith.constant 2.500000e-01 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %biased = pto.vmi.addf %x, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %biased, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : 
!pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp new file mode 100644 index 0000000000..b5526b9b23 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_add_bias_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_add_bias_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp new file mode 100644 index 0000000000..5c85668ceb --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_add_bias_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, 
ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 arrays.
+
+    Both files are read from the current working directory. The check
+    passes when the two arrays have the same length and every element
+    matches within ``atol=1e-4`` / ``rtol=1e-4``. On any mismatch an
+    ``[ERROR]`` line is printed and the process exits with status 2.
+    """
+    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
+    output = np.fromfile("v2.bin", dtype=np.float32)
+    # Report a size mismatch on its own: np.isclose raises on arrays that
+    # cannot be broadcast together, so it must not run in that case.
+    if golden.shape != output.shape:
+        print(
+            "[ERROR] compare failed idx=-1 golden=n/a output=n/a "
+            f"(shape mismatch: golden={golden.shape} output={output.shape})"
+        )
+        sys.exit(2)
+    # One tolerance comparison serves both the verdict and the index.
+    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
+    if not np.all(close):
+        idx = int(np.nonzero(~close)[0][0])
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py
new file mode 100644
index 0000000000..aef1ece1b4
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 32
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` holds ROWS * GROUP_SIZE float32 inputs, ``v2.bin`` holds ROWS
    sentinel values, and ``golden_v2.bin`` holds, for each row,
    ``sum(values * sum(values))`` accumulated in float32.
    ``output_dir`` is created if it does not exist.
    """
    row_step = np.float32(0.0625)
    row_template = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32)

    src = np.empty(INPUT_ELEMS, dtype=np.float32)
    golden = np.empty(ROWS, dtype=np.float32)

    # The reshaped rows are views, so writing through them fills ``src``.
    for row_idx, src_row in enumerate(src.reshape(ROWS, GROUP_SIZE)):
        row_values = row_template + np.float32(row_idx) * row_step
        src_row[:] = row_values
        row_sum = np.sum(row_values, dtype=np.float32)
        golden[row_idx] = np.sum(row_values * row_sum, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = {
        "v1.bin": src,
        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": golden,
    }
    for file_name, data in payloads.items():
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} 
diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..e8decb88f5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s32_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..eba17dbdd0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies 
Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, 
srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare the kernel output ``v2.bin`` against ``golden_v2.bin``.

    Both files are read from the current working directory as flat float32
    arrays. Prints ``[INFO] compare passed`` when the arrays have the same
    shape and every element agrees within ``atol=1e-4`` / ``rtol=1e-4``.
    Otherwise prints an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shape before any element-wise comparison: np.isclose raises
    # ValueError on non-broadcastable shapes, which would turn the intended
    # exit status 2 into an uncaught traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed idx=-1 golden=n/a output=n/a "
            f"(shape mismatch: golden={golden.shape} output={output.shape})"
        )
        sys.exit(2)

    # Evaluate the tolerance test once and reuse it for the verdict and for
    # locating the first offending element.
    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 32
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` holds ROWS * GROUP_SIZE float32 inputs, ``v2.bin`` holds ROWS
    sentinel values, and ``golden_v2.bin`` holds the float32 sum of each
    GROUP_SIZE-element row. ``output_dir`` is created if it does not exist.
    """
    row_step = np.float32(0.03125)
    row_template = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32)

    src = np.empty(INPUT_ELEMS, dtype=np.float32)
    golden = np.empty(ROWS, dtype=np.float32)

    # The reshaped rows are views, so writing through them fills ``src``.
    for row_idx, src_row in enumerate(src.reshape(ROWS, GROUP_SIZE)):
        row_values = row_template + np.float32(row_idx) * row_step
        src_row[:] = row_values
        golden[row_idx] = np.sum(row_values, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = {
        "v1.bin": src,
        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": golden,
    }
    for file_name, data in payloads.items():
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_cf_join_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %zero = arith.constant 0.000000e+00 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond = arith.cmpi eq, %c0, %c0 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = scf.if %cond -> (!pto.vmi.vreg<256xf32>) { + %then_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + scf.yield %then_x : !pto.vmi.vreg<256xf32> + } else { + %else_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %else_y = pto.vmi.addf %else_x, %zero_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.yield %else_y : !pto.vmi.vreg<256xf32> + } + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, 
i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp new file mode 100644 index 0000000000..4204a6ca52 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_cf_join_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp new file mode 100644 index 0000000000..a504036a2e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_cf_join_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare the kernel output ``v2.bin`` against ``golden_v2.bin``.

    Both files are read from the current working directory as flat float32
    arrays. Prints ``[INFO] compare passed`` when the arrays have the same
    shape and every element agrees within ``atol=1e-4`` / ``rtol=1e-4``.
    Otherwise prints an ``[ERROR]`` line and exits with status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shape before any element-wise comparison: np.isclose raises
    # ValueError on non-broadcastable shapes, which would turn the intended
    # exit status 2 into an uncaught traceback.
    if golden.shape != output.shape:
        print(
            "[ERROR] compare failed idx=-1 golden=n/a output=n/a "
            f"(shape mismatch: golden={golden.shape} output={output.shape})"
        )
        sys.exit(2)

    # Evaluate the tolerance test once and reuse it for the verdict and for
    # locating the first offending element.
    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        idx = int(np.nonzero(~close)[0][0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ROWS = 16
GROUP_SIZE = 32
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, pre-filled output and golden files for this case.

    ``v1.bin`` holds ROWS * GROUP_SIZE float32 inputs, ``v2.bin`` holds ROWS
    sentinel values, and ``golden_v2.bin`` holds the float32 sum of each
    GROUP_SIZE-element row. ``output_dir`` is created if it does not exist.
    """
    row_step = np.float32(0.03125)
    row_template = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32)

    src = np.empty(INPUT_ELEMS, dtype=np.float32)
    golden = np.empty(ROWS, dtype=np.float32)

    # The reshaped rows are views, so writing through them fills ``src``.
    for row_idx, src_row in enumerate(src.reshape(ROWS, GROUP_SIZE)):
        row_values = row_template + np.float32(row_idx) * row_step
        src_row[:] = row_values
        golden[row_idx] = np.sum(row_values, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    payloads = {
        "v1.bin": src,
        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": golden,
    }
    for file_name, data in payloads.items():
        data.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_multitile_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 16} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp new file mode 100644 index 0000000000..88c109d7d0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_multitile_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_multitile_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp new file mode 100644 index 0000000000..f30ea2a367 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_multitile_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py new file mode 100644 index 0000000000..8c5fc67aca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` in the current directory.

    Both files are read as flat float32 arrays. The comparison passes when the
    arrays have the same shape and every element pair is within
    ``atol=1e-4`` / ``rtol=1e-4``.

    Prints ``[INFO] compare passed`` on success. On failure prints an
    ``[ERROR]`` line and terminates the process with exit status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Check the shape before any element-wise comparison: np.isclose raises
    # ValueError for non-broadcastable shapes, which would turn a truncated or
    # oversized output file into a traceback instead of a compare failure.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not np.all(close):
        diff = np.nonzero(~close)[0]
        # Report only the first mismatching index to keep the log short.
        idx = int(diff[0]) if diff.size else -1
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
        sys.exit(2)
    print("[INFO] compare passed")
PHYSICAL_ROWS = 8
ACTIVE_ROWS = 6
GROUP_SIZE = 32
INPUT_ELEMS = PHYSICAL_ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial-output and golden files for this case.

    ``v1.bin`` holds ``PHYSICAL_ROWS`` rows of ``GROUP_SIZE`` float32 values,
    ``v2.bin`` holds ``PHYSICAL_ROWS`` sentinels, and ``golden_v2.bin`` holds
    the float32 row sums for the first ``ACTIVE_ROWS`` rows with the sentinel
    left in the remaining slots. ``output_dir`` is created if needed.
    """
    ramp = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.0625)
    # Each row is the shared ramp shifted by a per-row float32 offset.
    rows = [ramp + np.float32(index) * row_step for index in range(PHYSICAL_ROWS)]

    expected = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32)
    # Only the active rows get a reduction; the tail keeps the sentinel.
    for index, row_values in enumerate(rows[:ACTIVE_ROWS]):
        expected[index] = np.sum(row_values, dtype=np.float32)

    initial_dst = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    artifacts = {
        "v1.bin": np.concatenate(rows),
        "v2.bin": initial_dst,
        "golden_v2.bin": expected,
    }
    for file_name, payload in artifacts.items():
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: masked grouped add-reduction over a 192-lane f32 vector
// (6 groups, i.e. 32 lanes per group), followed by a grouped store.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without any parameter list — confirm against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s32_tail_full_tile_store_kernel(%src_gm: !pto.ptr,
                                                             %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c192 = arith.constant 192 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Local buffers are addressed by casting the integers 0 and 4096 to pointers.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Stage both buffers from the *_gm arguments into the ub_* buffers.
    // Presumably 1024 and 32 are byte counts (256 and 8 f32 values) — TODO confirm
    // against the pto.mte_gm_ub definition.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfers before the vector work below — confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      // Mask created from the bound 192, matching the 192-lane vector type.
      %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred>
      // The load yields 192 lanes but is annotated with full_read_elems = 256.
      %x = pto.vmi.load %ub_src[%c0] {full_read_elems = 256}
        : !pto.ptr -> !pto.vmi.vreg<192xf32>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc}
        : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred>
        -> !pto.vmi.vreg<192xf32>
      // %c1 is presumably the per-group destination stride in elements — confirm.
      pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 6}
        : !pto.vmi.vreg<192xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back — confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Copy the destination buffer back out to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Make sure __VEC_SCOPE__ is defined (as empty) before anything else is included.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name after this #include is missing in this view;
// presumably a fixed-width integer header for uint16_t below — confirm.
#include
// Fallback definition of MrgSortExecutedNumList, provided only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s32_tail_full_tile_store_kernel(__gm__ float *src,
                                                 __gm__ float *dst);

// Host-callable wrapper: launches the kernel with the launch configuration
// <<<1, nullptr, stream>>> and forwards `src`/`dst` cast to `__gm__ float *`.
// The call does not wait for completion itself.
void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src,
                                                            float *dst,
                                                            void *stream) {
  vmi_group_reduce_s32_tail_full_tile_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
// Evaluate an ACL call once; on any result other than ACL_SUCCESS print the
// stringified expression, the numeric error code and the source location to
// stderr, set `rc` to 1 and jump to the `cleanup` label.
// Requires an `int rc` variable and a `cleanup:` label in the enclosing
// function. The message embeds `#expr` and `__LINE__`, so the exact text and
// position of each wrapped call is part of the diagnostic output.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Host-side launcher for the kernel under test; only declared here.
void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src,
                                                            float *dst,
                                                            void *stream);

// Host driver: reads ./v1.bin (input) and ./v2.bin (initial output contents),
// uploads both, launches the kernel, synchronises the stream, downloads the
// output buffer and overwrites ./v2.bin with it.
// Returns 0 when every ACL_CHECK-wrapped call succeeded, 1 otherwise.
int main() {
  // Buffer geometry: kPhysicalRows * kGroupSize input floats and
  // kPhysicalRows output floats.
  constexpr size_t kPhysicalRows = 8;
  constexpr size_t kGroupSize = 32;
  constexpr size_t kInputElems = kPhysicalRows * kGroupSize;
  constexpr size_t kOutputElems = kPhysicalRows;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // All locals are initialised before the first ACL_CHECK so `goto cleanup`
  // never skips an initialisation and cleanup can see what was acquired.
  float *srcHost = nullptr;
  float *srcDevice = nullptr;
  float *dstHost = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override via the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): ReadFile/WriteFile are not defined in this file (presumably
  // test_common.h); their results are not checked here — confirm intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  // The output buffer is uploaded too, so its initial contents come from v2.bin.
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(srcDevice, dstDevice,
                                                         stream);
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on failure. The free calls run unconditionally
  // (pointers may still be nullptr) and cleanup return values are ignored.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
def main() -> None:
    """Compare ``v2.bin`` against ``golden_v2.bin`` in the current directory.

    Both files are read as flat float32 arrays and must match in shape and,
    element-wise, within ``atol=1e-4`` / ``rtol=1e-4``.

    Prints ``[INFO] compare passed`` on success. On failure prints an
    ``[ERROR]`` line and terminates the process with exit status 2.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # A shape mismatch has to be reported on its own: running np.isclose on
    # arrays of different, non-broadcastable sizes raises ValueError, so the
    # first-mismatch lookup below must never see such inputs.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    # Compute the element-wise result once and reuse it for pass/fail and for
    # locating the first mismatching index.
    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        diff = np.nonzero(~close)[0]
        idx = int(diff[0]) if diff.size else -1
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}")
        sys.exit(2)
    print("[INFO] compare passed")
ROWS = 8
GROUP_SIZE = 64
INPUT_ELEMS = ROWS * GROUP_SIZE
OUTPUT_STRIDE = 8
OUTPUT_ELEMS = ROWS * OUTPUT_STRIDE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the input, initial-output and golden files for this case.

    ``v1.bin`` holds ``ROWS`` rows of ``GROUP_SIZE`` float32 values and
    ``v2.bin`` holds ``OUTPUT_ELEMS`` sentinels. In ``golden_v2.bin`` slot
    ``row * OUTPUT_STRIDE`` holds ``sum(values * sum(values))`` for that row
    (float32 accumulation); every other slot keeps the sentinel.
    ``output_dir`` is created if needed.
    """
    ramp = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.03125)
    rows = [ramp + np.float32(index) * row_step for index in range(ROWS)]

    expected = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32)
    for index, row_values in enumerate(rows):
        # First reduce the row, then reduce the row scaled by that reduction.
        row_total = np.sum(row_values, dtype=np.float32)
        expected[index * OUTPUT_STRIDE] = np.sum(row_values * row_total, dtype=np.float32)

    initial_dst = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)
    artifacts = {
        "v1.bin": np.concatenate(rows),
        "v2.bin": initial_dst,
        "golden_v2.bin": expected,
    }
    for file_name, payload in artifacts.items():
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: grouped add-reduction over a 512-lane f32 vector (8 groups,
// i.e. 64 lanes per group), broadcast of the result back over the groups,
// element-wise multiply with the input, a second grouped reduction, and a
// grouped store.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without any parameter list — confirm against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s64_broadcast_reduce_store_kernel(%src_gm: !pto.ptr,
                                                               %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c512 = arith.constant 512 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Local buffers are addressed by casting the integers 0 and 4096 to pointers.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Stage both buffers from the *_gm arguments into the ub_* buffers.
    // Presumably 2048 and 256 are byte counts (512 and 64 f32 values) — TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfers before the vector work below — confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<512xf32>
      %b = pto.vmi.group_broadcast %sum {num_groups = 8}
        : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32>
      // Scale every input lane by %b, then reduce each group again with the same mask.
      %y = pto.vmi.mulf %x, %b
        : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32>
        -> !pto.vmi.vreg<512xf32>
      %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<512xf32>
      // %c8 is presumably the per-group destination stride in elements — confirm.
      pto.vmi.group_store %ysum, %ub_dst[%c0], %c8 {num_groups = 8}
        : !pto.vmi.vreg<512xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back — confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Copy the destination buffer back out to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Make sure __VEC_SCOPE__ is defined (as empty) before anything else is included.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name after this #include is missing in this view;
// presumably a fixed-width integer header for uint16_t below — confirm.
#include
// Fallback definition of MrgSortExecutedNumList, provided only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s64_broadcast_reduce_store_kernel(__gm__ float *src,
                                                   __gm__ float *dst);

// Host-callable wrapper: launches the kernel with the launch configuration
// <<<1, nullptr, stream>>> and forwards `src`/`dst` cast to `__gm__ float *`.
// The call does not wait for completion itself.
void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream) {
  vmi_group_reduce_s64_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)dst);
}
// Evaluate an ACL call once; on any result other than ACL_SUCCESS print the
// stringified expression, the numeric error code and the source location to
// stderr, set `rc` to 1 and jump to the `cleanup` label.
// Requires an `int rc` variable and a `cleanup:` label in the enclosing
// function. The message embeds `#expr` and `__LINE__`, so the exact text and
// position of each wrapped call is part of the diagnostic output.
#define ACL_CHECK(expr)                                                        \
  do {                                                                         \
    const aclError _ret = (expr);                                              \
    if (_ret != ACL_SUCCESS) {                                                 \
      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
                   (int)_ret, __FILE__, __LINE__);                             \
      rc = 1;                                                                  \
      goto cleanup;                                                            \
    }                                                                          \
  } while (0)

// Host-side launcher for the kernel under test; only declared here.
void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src,
                                                              float *dst,
                                                              void *stream);

// Host driver: reads ./v1.bin (input) and ./v2.bin (initial output contents),
// uploads both, launches the kernel, synchronises the stream, downloads the
// output buffer and overwrites ./v2.bin with it.
// Returns 0 when every ACL_CHECK-wrapped call succeeded, 1 otherwise.
int main() {
  // Buffer geometry: kRows * kGroupSize input floats and
  // kRows * kOutputStride output floats.
  constexpr size_t kRows = 8;
  constexpr size_t kGroupSize = 64;
  constexpr size_t kOutputStride = 8;
  constexpr size_t kInputElems = kRows * kGroupSize;
  constexpr size_t kOutputElems = kRows * kOutputStride;
  size_t srcBytes = kInputElems * sizeof(float);
  size_t dstBytes = kOutputElems * sizeof(float);
  // All locals are initialised before the first ACL_CHECK so `goto cleanup`
  // never skips an initialisation and cleanup can see what was acquired.
  float *srcHost = nullptr;
  float *srcDevice = nullptr;
  float *dstHost = nullptr;
  float *dstDevice = nullptr;
  int rc = 0;
  bool aclInited = false;
  bool deviceSet = false;
  int deviceId = 0;
  aclrtStream stream = nullptr;

  ACL_CHECK(aclInit(nullptr));
  aclInited = true;
  // Optional device override via the ACL_DEVICE_ID environment variable.
  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
    deviceId = std::atoi(envDevice);
  ACL_CHECK(aclrtSetDevice(deviceId));
  deviceSet = true;
  ACL_CHECK(aclrtCreateStream(&stream));
  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));

  // NOTE(review): ReadFile/WriteFile are not defined in this file (presumably
  // test_common.h); their results are not checked here — confirm intended.
  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  // The output buffer is uploaded too, so its initial contents come from v2.bin.
  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
  LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(srcDevice, dstDevice,
                                                           stream);
  ACL_CHECK(aclrtSynchronizeStream(stream));
  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
  WriteFile("./v2.bin", dstHost, dstBytes);

cleanup:
  // Reached on success and on failure. The free calls run unconditionally
  // (pointers may still be nullptr) and cleanup return values are ignored.
  aclrtFree(srcDevice);
  aclrtFree(dstDevice);
  aclrtFreeHost(srcHost);
  aclrtFreeHost(dstHost);
  if (stream)
    aclrtDestroyStream(stream);
  if (deviceSet)
    aclrtResetDevice(deviceId);
  if (aclInited)
    aclFinalize();
  return rc;
}
def main() -> None:
    """Compare ``v3.bin`` against ``golden_v3.bin`` in the current directory.

    Both files are read as flat float32 arrays. Prints
    ``[INFO] compare passed`` when the shapes agree and all elements are
    within ``atol=1e-4`` / ``rtol=1e-4``; otherwise prints an ``[ERROR]`` line
    (shape mismatch, or the first differing index) and exits with status 2.
    """
    expected = np.fromfile("golden_v3.bin", dtype=np.float32)
    actual = np.fromfile("v3.bin", dtype=np.float32)

    # Guard clause: sizes must agree before any element-wise comparison.
    if expected.shape != actual.shape:
        print(f"[ERROR] compare failed: shape golden={expected.shape} output={actual.shape}")
        sys.exit(2)

    within_tol = np.isclose(expected, actual, atol=1e-4, rtol=1e-4)
    if within_tol.all():
        print("[INFO] compare passed")
        return

    mismatches = np.nonzero(~within_tol)[0]
    idx = int(mismatches[0]) if mismatches.size else -1
    golden_text = expected[idx] if idx >= 0 else 'n/a'
    output_text = actual[idx] if idx >= 0 else 'n/a'
    print(
        f"[ERROR] compare failed idx={idx} "
        f"golden={golden_text} "
        f"output={output_text}"
    )
    sys.exit(2)
ROWS = 8
GROUP_SIZE = 64
RHS_STRIDE = 8
OUTPUT_STRIDE = 8
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the two inputs, the initial output and the golden file.

    ``v1.bin`` holds ``ROWS`` rows of ``GROUP_SIZE`` float32 values,
    ``v2.bin`` holds ``ROWS * RHS_STRIDE`` float32 right-hand-side values and
    ``v3.bin`` holds ``ROWS * OUTPUT_STRIDE`` sentinels. In ``golden_v3.bin``
    slot ``row * OUTPUT_STRIDE`` holds the float32 row sum plus
    ``rhs[row * RHS_STRIDE]``; every other slot keeps the sentinel.
    ``output_dir`` is created if needed.
    """
    ramp = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.03125)
    lhs = np.stack([ramp + np.float32(index) * row_step for index in range(ROWS)])

    rhs = np.linspace(-0.75, 0.75, ROWS * RHS_STRIDE, dtype=np.float32)

    out_elems = ROWS * OUTPUT_STRIDE
    initial_dst = np.full(out_elems, SENTINEL, dtype=np.float32)
    expected = np.full(out_elems, SENTINEL, dtype=np.float32)
    for index in range(ROWS):
        # Only the first slot of each output stride carries a result.
        row_total = np.sum(lhs[index, :], dtype=np.float32)
        expected[index * OUTPUT_STRIDE] = row_total + rhs[index * RHS_STRIDE]

    output_dir.mkdir(parents=True, exist_ok=True)
    artifacts = {
        "v1.bin": lhs.reshape(-1),
        "v2.bin": rhs,
        "v3.bin": initial_dst,
        "golden_v3.bin": expected.astype(np.float32),
    }
    for file_name, payload in artifacts.items():
        payload.tofile(output_dir / file_name)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()
// Test kernel: grouped add-reduction over a 512-lane f32 vector (8 groups,
// i.e. 64 lanes per group), added to a vector obtained with group_slot_load,
// followed by a grouped store.
// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear here
// without any parameter list — confirm against the checked-in file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_group_reduce_s64_slot_add_store_kernel(%src_gm: !pto.ptr,
                                                       %rhs_gm: !pto.ptr,
                                                       %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c512 = arith.constant 512 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // NOTE(review): %c64_i64 has no use inside this function.
    %c64_i64 = arith.constant 64 : i64
    %c256_i64 = arith.constant 256 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64

    // Local buffers are addressed by casting the integers 0, 4096 and 8192 to pointers.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c8192_i64 : i64 -> !pto.ptr

    // Stage all three buffers from the *_gm arguments into the ub_* buffers.
    // Presumably 2048 and 256 are byte counts (512 and 64 f32 values) — TODO confirm.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
      nburst(%c1_i64, %c2048_i64, %c2048_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Presumably orders the MTE2 transfers before the vector work below — confirm.
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
      // %c8 is presumably the per-group source stride in elements — confirm.
      %rhs = pto.vmi.group_slot_load %ub_rhs[%c0], %c8 {num_groups = 8}
        : !pto.ptr -> !pto.vmi.vreg<512xf32>
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
        -> !pto.vmi.vreg<512xf32>
      %out = pto.vmi.addf %sum, %rhs
        : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32>
        -> !pto.vmi.vreg<512xf32>
      // %c8 is presumably the per-group destination stride in elements — confirm.
      pto.vmi.group_store %out, %ub_dst[%c0], %c8 {num_groups = 8}
        : !pto.vmi.vreg<512xf32>, !pto.ptr
    }

    // Presumably orders the vector work before the MTE3 write-back — confirm.
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    // Copy the destination buffer back out to %dst_gm.
    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
// Make sure __VEC_SCOPE__ is defined (as empty) before anything else is included.
#ifndef __VEC_SCOPE__
#define __VEC_SCOPE__
#endif
// NOTE(review): the header name after this #include is missing in this view;
// presumably a fixed-width integer header for uint16_t below — confirm.
#include
// Fallback definition of MrgSortExecutedNumList, provided only when neither
// __CCE_AICORE__ nor TMRGSORT_HPP is defined.
#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
struct MrgSortExecutedNumList {
  uint16_t mrgSortList0;
  uint16_t mrgSortList1;
  uint16_t mrgSortList2;
  uint16_t mrgSortList3;
};
#endif
// The ACL header is skipped when __CPU_SIM is defined.
#ifndef __CPU_SIM
#include "acl/acl.h"
#endif

// Device kernel entry point; declared here, defined elsewhere.
extern "C" __global__ [aicore] void
vmi_group_reduce_s64_slot_add_store_kernel(__gm__ float *src,
                                           __gm__ float *rhs,
                                           __gm__ float *dst);

// Host-callable wrapper: launches the kernel with the launch configuration
// <<<1, nullptr, stream>>> and forwards `src`/`rhs`/`dst` cast to
// `__gm__ float *`. The call does not wait for completion itself.
void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs,
                                                      float *dst,
                                                      void *stream) {
  vmi_group_reduce_s64_slot_add_store_kernel<<<1, nullptr, stream>>>(
      (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dst);
}
+
+#include "acl/acl.h"
+#include "test_common.h"
+// std::fprintf comes from <cstdio>; std::getenv and std::atoi from <cstdlib>.
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// Evaluates an ACL call; on failure prints the expression, error code and
+// location, sets the enclosing function's `rc` to 1 and jumps to its
+// `cleanup` label. Usable only where both `rc` and `cleanup` exist.
+#define ACL_CHECK(expr)                                                        \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs,
+                                                      float *dst,
+                                                      void *stream);
+
+// Test driver: reads ./v1.bin, ./v2.bin and ./v3.bin, copies them to the
+// device, launches the kernel, and writes the destination buffer back to
+// ./v3.bin. Returns 0 on success and 1 if any ACL_CHECK call fails.
+int main() {
+  constexpr size_t kRows = 8;
+  constexpr size_t kGroupSize = 64;
+  constexpr size_t kRhsStride = 8;
+  constexpr size_t kOutputStride = 8;
+  constexpr size_t kInputElems = kRows * kGroupSize;
+  constexpr size_t kRhsElems = kRows * kRhsStride;
+  constexpr size_t kOutputElems = kRows * kOutputStride;
+  size_t srcBytes = kInputElems * sizeof(float);
+  size_t rhsBytes = kRhsElems * sizeof(float);
+  size_t dstBytes = kOutputElems * sizeof(float);
+  // Every local is declared before the first ACL_CHECK so that
+  // `goto cleanup` never jumps over an initialization.
+  float *srcHost = nullptr;
+  float *rhsHost = nullptr;
+  float *dstHost = nullptr;
+  float *srcDevice = nullptr;
+  float *rhsDevice = nullptr;
+  float *dstDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  // ACL_DEVICE_ID, when set, overrides the default device 0.
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // NOTE(review): whatever status ReadFile/WriteFile report is not checked here.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes);
+  ReadFile("./v3.bin", dstBytes, dstHost, dstBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_reduce_s64_slot_add_store_kernel(srcDevice, rhsDevice,
+                                                   dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v3.bin", dstHost, dstBytes);
+
+cleanup:
+  // Reached on both success and failure; the pointers may still be null here.
+  aclrtFree(srcDevice);
+  aclrtFree(rhsDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(rhsHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py
new file mode 100644
index 0000000000..17b5e600cc
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare ``v2.bin`` against ``golden_v2.bin`` as float32 data.
+
+    Prints ``[INFO] compare passed`` when both files hold the same number of
+    elements and every element agrees within ``atol=1e-4`` / ``rtol=1e-4``.
+    Otherwise prints an ``[ERROR]`` line and exits with status 2.
+    """
+    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
+    output = np.fromfile("v2.bin", dtype=np.float32)
+    # Check the shapes before calling np.isclose: arrays of different length
+    # either fail to broadcast (ValueError) or, when one has a single
+    # element, broadcast silently and hide the size mismatch.
+    if golden.shape != output.shape:
+        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
+        sys.exit(2)
+    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
+    if not np.all(close):
+        # At least one element is out of tolerance, so the index list is non-empty.
+        idx = int(np.nonzero(~close)[0][0])
+        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py
new file mode 100644
index 0000000000..83ac2d015e
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 6
+GROUP_SIZE = 64
+OUTPUT_STRIDE = 8
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write ``v1.bin``, ``v2.bin`` and ``golden_v2.bin`` into *output_dir*.
+
+    ``v1.bin`` holds a ROWS x GROUP_SIZE float32 matrix whose rows are one
+    linspace ramp shifted by ``row * 0.046875``. ``v2.bin`` is the
+    sentinel-filled output buffer. ``golden_v2.bin`` is the same buffer with
+    each row's float32 sum placed at index ``row * OUTPUT_STRIDE``.
+    """
+    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
+    row_shift = np.arange(ROWS, dtype=np.float32) * np.float32(0.046875)
+    src = ramp[np.newaxis, :] + row_shift[:, np.newaxis]
+
+    dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32)
+    golden = dst.copy()
+    # One reduction per row, each over a contiguous row, as float32.
+    golden[::OUTPUT_STRIDE] = [np.sum(row, dtype=np.float32) for row in src]
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for data, filename in ((src.ravel(), "v1.bin"), (dst, "v2.bin"), (golden, "golden_v2.bin")):
+        data.tofile(output_dir / filename)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` (default ``.``) and generate the files there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto
new file mode 100644
index 0000000000..5167c9198a
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto
@@ -0,0 +1,52 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+// VMI test kernel: loads a 384-lane f32 vector, applies group_reduce_addf
+// with num_groups = 6 under a 384-lane mask, and writes the result with
+// group_store (num_groups = 6, %c8 operand).
+// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear without
+// parameters in this text -- confirm nothing was dropped.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_reduce_s64_tail_store_kernel(%src_gm: !pto.ptr,
+                                                    %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c384 = arith.constant 384 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    // 192 = 6 * 8 * 4 and 1536 = 384 * 4; presumably byte counts of f32 data
+    // -- TODO confirm the unit expected by the mte ops.
+    %c192_i64 = arith.constant 192 : i64
+    %c1536_i64 = arith.constant 1536 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+
+    // Working pointers are formed from the integer constants 0 and 4096.
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+
+    // Both the source and the current destination contents are copied in, so
+    // destination elements the store does not touch keep their input values.
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1536_i64
+      nburst(%c1_i64, %c1536_i64, %c1536_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c192_i64
+      nburst(%c1_i64, %c192_i64, %c192_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    // Set then wait on EVENT_ID0 between PIPE_MTE2 and PIPE_V before the vector work.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred>
+      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<384xf32>
+      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc}
+        : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred>
+        -> !pto.vmi.vreg<384xf32>
+      pto.vmi.group_store %sum, %ub_dst[%c0], %c8 {num_groups = 6}
+        : !pto.vmi.vreg<384xf32>, !pto.ptr
+    }
+
+    // Set then wait on EVENT_ID0 between PIPE_V and PIPE_MTE3 before copying out.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c192_i64
+      nburst(%c1_i64, %c192_i64, %c192_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp
new file mode 100644
index 0000000000..afdf98b76d
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp
@@ -0,0 +1,32 @@
+// Copyright (c) 2026 Huawei
Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+// Host-side launcher for vmi_group_reduce_s64_tail_store_kernel.
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+// NOTE(review): this #include carries no header name in this text -- confirm
+// which header is intended before building.
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+// Only defined when neither __CCE_AICORE__ nor TMRGSORT_HPP is defined.
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+// acl/acl.h is skipped when __CPU_SIM is defined.
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+// Device kernel entry point; its definition is not part of this file.
+extern "C" __global__ [aicore] void
+vmi_group_reduce_s64_tail_store_kernel(__gm__ float *src, __gm__ float *dst);
+
+// Launches the kernel with <<<1, nullptr, stream>>>, passing both float
+// pointers cast to __gm__ float pointers.
+void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst,
+                                                  void *stream) {
+  vmi_group_reduce_s64_tail_store_kernel<<<1, nullptr, stream>>>(
+      (__gm__ float *)src, (__gm__ float *)dst);
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp
new file mode 100644
index 0000000000..3223b3561b
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp
@@ -0,0 +1,81 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+// std::fprintf comes from <cstdio>; std::getenv and std::atoi from <cstdlib>.
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// Evaluates an ACL call; on failure prints the expression, error code and
+// location, sets the enclosing function's `rc` to 1 and jumps to its
+// `cleanup` label. Usable only where both `rc` and `cleanup` exist.
+#define ACL_CHECK(expr)                                                        \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst,
+                                                  void *stream);
+
+// Test driver: reads ./v1.bin and ./v2.bin, copies them to the device,
+// launches the kernel, and writes the destination buffer back to ./v2.bin.
+// Returns 0 on success and 1 if any ACL_CHECK call fails.
+int main() {
+  constexpr size_t kRows = 6;
+  constexpr size_t kGroupSize = 64;
+  constexpr size_t kOutputStride = 8;
+  constexpr size_t kInputElems = kRows * kGroupSize;
+  constexpr size_t kOutputElems = kRows * kOutputStride;
+  size_t srcBytes = kInputElems * sizeof(float);
+  size_t dstBytes = kOutputElems * sizeof(float);
+  // Every local is declared before the first ACL_CHECK so that
+  // `goto cleanup` never jumps over an initialization.
+  float *srcHost = nullptr;
+  float *dstHost = nullptr;
+  float *srcDevice = nullptr;
+  float *dstDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  // ACL_DEVICE_ID, when set, overrides the default device 0.
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // NOTE(review): whatever status ReadFile/WriteFile report is not checked here.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_reduce_s64_tail_store_kernel(srcDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, dstBytes);
+
+cleanup:
+  // Reached on both success and failure; the pointers may still be null here.
+  aclrtFree(srcDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py
new file mode 100644
index 0000000000..cce2c778b9
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:
+    """Compare ``v2.bin`` against ``golden_v2.bin`` as float16 data.
+
+    Prints ``[INFO] compare passed`` when both files hold the same number of
+    elements and ``np.array_equal`` holds. Otherwise prints an ``[ERROR]``
+    line and exits with status 2. The reported index is the first element
+    whose 16-bit pattern differs; ``idx=-1`` / ``n/a`` is printed when the
+    arrays compare unequal although no bit pattern differs (e.g. NaNs).
+    """
+    golden = np.fromfile("golden_v2.bin", dtype=np.float16)
+    output = np.fromfile("v2.bin", dtype=np.float16)
+    # Report a size mismatch explicitly: the element-wise `!=` below needs
+    # equal shapes and raises on recent NumPy when they differ.
+    if golden.shape != output.shape:
+        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
+        sys.exit(2)
+    if np.array_equal(golden, output):
+        print("[INFO] compare passed")
+        return
+    diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0]
+    idx = int(diff[0]) if diff.size else -1
+    print(
+        f"[ERROR] compare failed idx={idx} "
+        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
+        f"output={output[idx] if idx >= 0 else 'n/a'}"
+    )
+    sys.exit(2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py
new file mode 100644
index 0000000000..62b6de2d6e
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+GROUP_SIZE = 64
+OUTPUT_STRIDE = 16
+SENTINEL = np.float16(-17.5)
+
+
+def generate(output_dir: Path) -> None:
+    """Write ``v1.bin``, ``v2.bin`` and ``golden_v2.bin`` into *output_dir*.
+
+    ``v1.bin`` holds a ROWS x GROUP_SIZE float32 matrix whose rows are one
+    linspace ramp shifted by ``row * 0.046875``. ``v2.bin`` is the
+    sentinel-filled float16 output buffer. ``golden_v2.bin`` is the same
+    buffer with each row's float32 sum, converted to float16, placed at index
+    ``row * OUTPUT_STRIDE``.
+    """
+    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
+    row_shift = np.arange(ROWS, dtype=np.float32) * np.float32(0.046875)
+    src = ramp[np.newaxis, :] + row_shift[:, np.newaxis]
+
+    dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float16)
+    golden = dst.copy()
+    # Sum each contiguous row in float32 first, then narrow to float16.
+    golden[::OUTPUT_STRIDE] = [
+        np.float16(np.sum(row, dtype=np.float32)) for row in src
+    ]
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for data, filename in ((src.ravel(), "v1.bin"), (dst, "v2.bin"), (golden, "golden_v2.bin")):
+        data.tofile(output_dir / filename)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` (default ``.``) and generate the files there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto
new file mode 100644
index 0000000000..6436738080
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto
@@ -0,0 +1,54 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+// VMI test kernel: loads a 512-lane f32 vector, applies group_reduce_addf
+// with num_groups = 8, narrows the result to f16 with truncf, and writes it
+// with group_store (num_groups = 8, %c16 operand).
+// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear without
+// parameters in this text -- confirm nothing was dropped.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_reduce_s64_truncf_store_kernel(%src_gm: !pto.ptr,
+                                                      %dst_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c512 = arith.constant 512 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    // 256 = 8 * 16 * 2 and 2048 = 512 * 4; presumably byte counts of the f16
+    // destination and f32 source -- TODO confirm the unit expected by the mte ops.
+    %c256_i64 = arith.constant 256 : i64
+    %c2048_i64 = arith.constant 2048 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+
+    // Working pointers are formed from the integer constants 0 and 4096.
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+
+    // Both the source and the current destination contents are copied in, so
+    // destination elements the store does not touch keep their input values.
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64
+      nburst(%c1_i64, %c2048_i64, %c2048_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64
+      nburst(%c1_i64, %c256_i64, %c256_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    // Set then wait on EVENT_ID0 between PIPE_MTE2 and PIPE_V before the vector work.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred>
+      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32>
+      %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred>
+        -> !pto.vmi.vreg<512xf32>
+      // The reduction stays in f32; only the stored value is f16.
+      %sum16 = pto.vmi.truncf %sum32
+        : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16>
+      pto.vmi.group_store %sum16, %ub_dst[%c0], %c16 {num_groups = 8}
+        : !pto.vmi.vreg<512xf16>, !pto.ptr
+    }
+
+    // Set then wait on EVENT_ID0 between PIPE_V and PIPE_MTE3 before copying out.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64
+      nburst(%c1_i64, %c256_i64, %c256_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp
b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp
new file mode 100644
index 0000000000..bd0c1e4fa2
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+// Host-side launcher for vmi_group_reduce_s64_truncf_store_kernel.
+#ifndef __VEC_SCOPE__
+#define __VEC_SCOPE__
+#endif
+// One-byte placeholder structs, declared only for __CCE_AICORE__ builds with
+// __NPU_ARCH__ == 2201.
+#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
+typedef struct { unsigned char v; } hifloat8_t;
+typedef struct { unsigned char v; } float8_e4m3_t;
+typedef struct { unsigned char v; } float8_e5m2_t;
+typedef struct { unsigned char v; } float8_e8m0_t;
+typedef struct { unsigned char v; } float4_e1m2x2_t;
+typedef struct { unsigned char v; } float4_e2m1x2_t;
+#endif
+// NOTE(review): this #include carries no header name in this text -- confirm
+// which header is intended before building.
+#include
+#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP)
+// Only defined when neither __CCE_AICORE__ nor TMRGSORT_HPP is defined.
+struct MrgSortExecutedNumList {
+  uint16_t mrgSortList0;
+  uint16_t mrgSortList1;
+  uint16_t mrgSortList2;
+  uint16_t mrgSortList3;
+};
+#endif
+// acl/acl.h is skipped when __CPU_SIM is defined.
+#ifndef __CPU_SIM
+#include "acl/acl.h"
+#endif
+
+// Device kernel entry point; its definition is not part of this file.
+extern "C" __global__ [aicore] void
+vmi_group_reduce_s64_truncf_store_kernel(__gm__ float *src, __gm__ half *dst);
+
+// Launches the kernel with <<<1, nullptr, stream>>>. The host passes the
+// destination as uint16_t*, which is cast to the kernel's __gm__ half*.
+void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst,
+                                                    void *stream) {
+  vmi_group_reduce_s64_truncf_store_kernel<<<1, nullptr, stream>>>(
+      (__gm__ float *)src, (__gm__ half *)dst);
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp
b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp
new file mode 100644
index 0000000000..941a7d4622
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp
@@ -0,0 +1,79 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "acl/acl.h"
+#include "test_common.h"
+// uint16_t comes from <cstdint>, std::fprintf from <cstdio>, and
+// std::getenv / std::atoi from <cstdlib>.
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace PtoTestCommon;
+
+// Evaluates an ACL call; on failure prints the expression, error code and
+// location, sets the enclosing function's `rc` to 1 and jumps to its
+// `cleanup` label. Usable only where both `rc` and `cleanup` exist.
+#define ACL_CHECK(expr)                                                        \
+  do {                                                                         \
+    const aclError _ret = (expr);                                              \
+    if (_ret != ACL_SUCCESS) {                                                 \
+      std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr,           \
+                   (int)_ret, __FILE__, __LINE__);                             \
+      rc = 1;                                                                  \
+      goto cleanup;                                                            \
+    }                                                                          \
+  } while (0)
+
+void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst,
+                                                    void *stream);
+
+// Test driver: reads ./v1.bin (float) and ./v2.bin (16-bit elements), copies
+// them to the device, launches the kernel, and writes the destination buffer
+// back to ./v2.bin. Returns 0 on success and 1 if any ACL_CHECK call fails.
+int main() {
+  constexpr size_t kSrcElems = 512;
+  constexpr size_t kDstElems = 128;
+  size_t srcBytes = kSrcElems * sizeof(float);
+  size_t dstBytes = kDstElems * sizeof(uint16_t);
+  // Every local is declared before the first ACL_CHECK so that
+  // `goto cleanup` never jumps over an initialization.
+  float *srcHost = nullptr;
+  float *srcDevice = nullptr;
+  uint16_t *dstHost = nullptr;
+  uint16_t *dstDevice = nullptr;
+  int rc = 0;
+  bool aclInited = false;
+  bool deviceSet = false;
+  int deviceId = 0;
+  aclrtStream stream = nullptr;
+
+  ACL_CHECK(aclInit(nullptr));
+  aclInited = true;
+  // ACL_DEVICE_ID, when set, overrides the default device 0.
+  if (const char *envDevice = std::getenv("ACL_DEVICE_ID"))
+    deviceId = std::atoi(envDevice);
+  ACL_CHECK(aclrtSetDevice(deviceId));
+  deviceSet = true;
+  ACL_CHECK(aclrtCreateStream(&stream));
+  ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes));
+  ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes));
+  ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+  ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST));
+
+  // NOTE(review): whatever status ReadFile/WriteFile report is not checked here.
+  ReadFile("./v1.bin", srcBytes, srcHost, srcBytes);
+  ReadFile("./v2.bin", dstBytes, dstHost, dstBytes);
+  ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE));
+  LaunchVmi_group_reduce_s64_truncf_store_kernel(srcDevice, dstDevice, stream);
+  ACL_CHECK(aclrtSynchronizeStream(stream));
+  ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST));
+  WriteFile("./v2.bin", dstHost, dstBytes);
+
+cleanup:
+  // Reached on both success and failure; the pointers may still be null here.
+  aclrtFree(srcDevice);
+  aclrtFree(dstDevice);
+  aclrtFreeHost(srcHost);
+  aclrtFreeHost(dstHost);
+  if (stream)
+    aclrtDestroyStream(stream);
+  if (deviceSet)
+    aclrtResetDevice(deviceId);
+  if (aclInited)
+    aclFinalize();
+  return rc;
+}
diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags
new file mode 100644
index 0000000000..a79aede1ca
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags
@@ -0,0 +1 @@
+--pto-arch a5 --pto-backend=vpto --enable-vmi
diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py
new file mode 100644
index 0000000000..edcf881e8d
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import sys
+
+import numpy as np
+
+
+def check(name: str, golden_name: str) -> None:
+    """Compare float32 file *name* against *golden_name*.
+
+    Returns silently when both files hold the same number of elements and
+    every element agrees within ``atol=1e-4`` / ``rtol=1e-4``. Otherwise
+    prints an ``[ERROR]`` line (shape mismatch, or the first out-of-tolerance
+    index) and exits with status 2.
+    """
+    expected = np.fromfile(golden_name, dtype=np.float32)
+    actual = np.fromfile(name, dtype=np.float32)
+
+    if expected.shape != actual.shape:
+        print(f"[ERROR] compare failed {name}: shape golden={expected.shape} output={actual.shape}")
+        sys.exit(2)
+
+    within_tol = np.isclose(expected, actual, atol=1e-4, rtol=1e-4)
+    if within_tol.all():
+        return
+
+    # At least one element is out of tolerance here, so the list is non-empty.
+    idx = int(np.flatnonzero(~within_tol)[0])
+    print(f"[ERROR] compare failed {name} idx={idx} golden={expected[idx]} output={actual[idx]}")
+    sys.exit(2)
+
+
+def main() -> None:
+    """Check both outputs, then report success."""
+    for produced, reference in (("v4.bin", "golden_v4.bin"), ("v5.bin", "golden_v5.bin")):
+        check(produced, reference)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py
new file mode 100644
index 0000000000..7e57da8318
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+S16 = 16
+S32 = 32
+SENTINEL = np.float32(-777.0)
+
+
+def fill_matrix(rows: int, cols: int, base_start: float, row_step: float) -> np.ndarray:
+    """Return a rows x cols float32 matrix.
+
+    Row ``r`` is a linspace from *base_start* to ``base_start + 1.0`` with
+    ``r * row_step`` added to every element.
+    """
+    ramp = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32)
+    row_shift = np.arange(rows, dtype=np.float32) * np.float32(row_step)
+    return ramp[np.newaxis, :] + row_shift[:, np.newaxis]
+
+
+def generate(output_dir: Path) -> None:
+    """Write the inputs, output buffers and golden files into *output_dir*.
+
+    ``v1.bin`` / ``v2.bin`` are the 16- and 32-column sources and ``v3.bin``
+    is the per-row addend. ``v4.bin`` / ``v5.bin`` are sentinel-filled output
+    buffers. ``golden_v4.bin`` / ``golden_v5.bin`` hold each source's row
+    sums plus the addend.
+    """
+    src16 = fill_matrix(ROWS, S16, -0.75, 0.03125)
+    src32 = fill_matrix(ROWS, S32, -0.875, 0.0625)
+    rhs = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32)
+    dst16 = np.full(ROWS, SENTINEL, dtype=np.float32)
+    dst32 = np.full(ROWS, SENTINEL, dtype=np.float32)
+
+    golden16 = src16.sum(axis=1, dtype=np.float32) + rhs
+    golden32 = src32.sum(axis=1, dtype=np.float32) + rhs
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    payloads = (
+        ("v1.bin", src16.ravel()),
+        ("v2.bin", src32.ravel()),
+        ("v3.bin", rhs),
+        ("v4.bin", dst16),
+        ("v5.bin", dst32),
+        ("golden_v4.bin", golden16),
+        ("golden_v5.bin", golden32),
+    )
+    for filename, data in payloads:
+        data.tofile(output_dir / filename)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` (default ``.``) and generate the files there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto
new file mode 100644
index
0000000000..291251e0bf
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto
@@ -0,0 +1,86 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+// VMI test kernel with two independent pipelines (128-lane and 256-lane f32),
+// each doing group_reduce_addf with num_groups = 8, an addf with a
+// group_slot_load of the shared rhs buffer, and a group_store.
+// NOTE(review): `!pto.ptr`, `#pto.kernel_kind` and `#pto.pipe` appear without
+// parameters in this text -- confirm nothing was dropped.
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
+  func.func @vmi_group_reduce_slot_add_store_kernel(%src16_gm: !pto.ptr,
+                                                    %src32_gm: !pto.ptr,
+                                                    %rhs_gm: !pto.ptr,
+                                                    %dst16_gm: !pto.ptr,
+                                                    %dst32_gm: !pto.ptr) attributes {pto.kernel} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    // 32 = 8 * 4, 512 = 128 * 4 and 1024 = 256 * 4; presumably byte counts of
+    // f32 data -- TODO confirm the unit expected by the mte ops.
+    %c32_i64 = arith.constant 32 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c4096_i64 = arith.constant 4096 : i64
+    %c8192_i64 = arith.constant 8192 : i64
+    %c12288_i64 = arith.constant 12288 : i64
+
+    // Working pointers are formed from the constants 0, 1024, 4096, 8192 and 12288.
+    %ub_src16 = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_src32 = pto.castptr %c1024_i64 : i64 -> !pto.ptr
+    %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr
+    %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr
+
+    // Only the two sources and rhs are copied in; the destination buffers are
+    // not preloaded in this kernel.
+    pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64
+      nburst(%c1_i64, %c512_i64, %c512_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64
+      nburst(%c1_i64, %c1024_i64, %c1024_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    // Set then wait on EVENT_ID0 between PIPE_MTE2 and PIPE_V before the vector work.
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      // 128-lane pipeline.
+      %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred>
+      %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32>
+      %rhs16 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8}
+        : !pto.ptr -> !pto.vmi.vreg<128xf32>
+      %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>
+        -> !pto.vmi.vreg<128xf32>
+      %out16 = pto.vmi.addf %sum16, %rhs16
+        : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
+        -> !pto.vmi.vreg<128xf32>
+      pto.vmi.group_store %out16, %ub_dst16[%c0], %c1 {num_groups = 8}
+        : !pto.vmi.vreg<128xf32>, !pto.ptr
+
+      // 256-lane pipeline; reads the same rhs buffer as the 128-lane one.
+      %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
+      %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32>
+      %rhs32 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8}
+        : !pto.ptr -> !pto.vmi.vreg<256xf32>
+      %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc}
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
+        -> !pto.vmi.vreg<256xf32>
+      %out32 = pto.vmi.addf %sum32, %rhs32
+        : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+        -> !pto.vmi.vreg<256xf32>
+      pto.vmi.group_store %out32, %ub_dst32[%c0], %c1 {num_groups = 8}
+        : !pto.vmi.vreg<256xf32>, !pto.ptr
+    }
+
+    // Set then wait on EVENT_ID0 between PIPE_V and PIPE_MTE3 before copying out.
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64
+      nburst(%c1_i64, %c32_i64, %c32_i64)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}
diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp new file mode 100644 index 0000000000..ba7b786e51 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_slot_add_store_kernel(__gm__ float *src16, + __gm__ float *src32, + __gm__ float *rhs, + __gm__ float *dst16, + __gm__ float *dst32); + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream) { + vmi_group_reduce_slot_add_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src16, (__gm__ float *)src32, (__gm__ float *)rhs, + (__gm__ float *)dst16, (__gm__ float *)dst32); +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp new file mode 100644 index 0000000000..111426c192 --- /dev/null +++ 
b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp @@ -0,0 +1,113 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kS16 = 16; + constexpr size_t kS32 = 32; + constexpr size_t kSrc16Elems = kRows * kS16; + constexpr size_t kSrc32Elems = kRows * kS32; + constexpr size_t kRhsElems = kRows; + constexpr size_t kOutputElems = kRows; + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t rhsBytes = kRhsElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src16Host = nullptr; + float *src32Host = nullptr; + float *rhsHost = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *rhsDevice = nullptr; + float *dst16Device = nullptr; + float *dst32Device = 
nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v2.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v3.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v4.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_slot_add_store_kernel( + src16Device, src32Device, rhsDevice, dst16Device, 
dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst16Host, dstBytes); + WriteFile("./v5.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(rhsDevice); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py new file mode 100644 index 0000000000..60aeab3da6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
# See LICENSE in the root of the software repository for the full text of the License.
"""Compare the kernel outputs of the group-slots-cf-join-store case with their goldens."""

import sys

import numpy as np


def check(name: str) -> bool:
    """Compare ``{name}.bin`` with ``golden_{name}.bin`` as float32 arrays.

    Both files are read from the current working directory. The comparison
    uses an absolute and a relative tolerance of 1e-4.

    Args:
        name: Stem of the output file, e.g. ``"v3"``.

    Returns:
        True when both files hold the same number of elements and every
        element is within tolerance. False otherwise, after printing an
        ``[ERROR]`` line describing the first problem found.
    """
    try:
        golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32)
        output = np.fromfile(f"{name}.bin", dtype=np.float32)
    except OSError as err:
        # A missing or unreadable file is a failed comparison, not a crash.
        print(f"[ERROR] compare failed {name} cannot read data: {err}")
        return False

    # Element-wise comparison is only meaningful for equal shapes: np.isclose
    # would raise on a length mismatch, or silently broadcast a 1-element file.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed {name} shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        return False

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if close.all():
        return True

    idx = int(np.flatnonzero(~close)[0])
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    return False


def main() -> None:
    """Check outputs ``v3`` and ``v4``; exit with status 2 if any of them fails."""
    # Evaluate every output before deciding, so each failure is reported.
    results = [check(name) for name in ("v3", "v4")]
    if not all(results):
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shared this source line (header of the next file in the
# patch), preserved verbatim:
#   diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py
#   new file mode 100644
#   index 0000000000..fa1fc04fe6
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py
#   @@ -0,0 +1,53 @@
#   +#!/usr/bin/env python3
#   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +# CANN Open Software License Agreement Version 2.0 (the "License").
#   +# Please refer to the License for details. You may not use this file except in compliance with the License.
#   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + rhs = np.linspace(-0.375, 0.5, ROWS, dtype=np.float32) + dst_reduce = np.full(ROWS, SENTINEL, dtype=np.float32) + dst_slot = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_reduce = np.empty(ROWS, dtype=np.float32) + golden_slot = rhs + rhs + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + golden_reduce[row] = np.sum(values, dtype=np.float32) + rhs[row] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dst_reduce.tofile(output_dir / "v3.bin") + dst_slot.tofile(output_dir / "v4.bin") + golden_reduce.tofile(output_dir / "golden_v3.bin") + golden_slot.astype(np.float32).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto new file mode 100644 index 0000000000..7fcdd382c8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_cf_join_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst_reduce_gm: !pto.ptr, + %dst_slot_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_reduce = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst_slot = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond_true = arith.cmpi eq, %c0, %c0 : index + %cond_false = arith.cmpi ne, %c0, %c0 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + + %reduce_join = scf.if %cond_true -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> 
!pto.vmi.vreg<128xf32> + scf.yield %sum : !pto.vmi.vreg<128xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %slot : !pto.vmi.vreg<128xf32> + } + %bias0 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %reduce_out = pto.vmi.addf %reduce_join, %bias0 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %reduce_out, %ub_dst_reduce[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %slot_join = scf.if %cond_false -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + scf.yield %sum : !pto.vmi.vreg<128xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %slot : !pto.vmi.vreg<128xf32> + } + %bias1 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %slot_out = pto.vmi.addf %slot_join, %bias1 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %slot_out, %ub_dst_slot[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_reduce, %dst_reduce_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst_slot, %dst_slot_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp new file mode 100644 index 
0000000000..add61550a6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_cf_join_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dstReduce, + __gm__ float *dstSlot); + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream) { + vmi_group_slots_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dstReduce, + (__gm__ float *)dstSlot); +} diff --git 
a/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp new file mode 100644 index 0000000000..fb8d6ace69 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t rhsBytes = kOutputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *dstReduceHost = nullptr; + float *dstSlotHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *dstReduceDevice = nullptr; + float *dstSlotDevice = nullptr; + int rc = 0; + bool aclInited 
= false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstReduceHost), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstSlotHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstReduceDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstSlotDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", dstBytes, dstReduceHost, dstBytes); + ReadFile("./v4.bin", dstBytes, dstSlotHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstReduceDevice, dstBytes, dstReduceHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstSlotDevice, dstBytes, dstSlotHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_cf_join_store_kernel(srcDevice, rhsDevice, + dstReduceDevice, dstSlotDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstReduceHost, dstBytes, dstReduceDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dstSlotHost, dstBytes, dstSlotDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstReduceHost, dstBytes); + WriteFile("./v4.bin", 
dstSlotHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(dstReduceDevice); + aclrtFree(dstSlotDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dstReduceHost); + aclrtFreeHost(dstSlotHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py new file mode 100644 index 0000000000..49180d97de --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
"""Compare the kernel outputs of the group-slots-fanout-store-broadcast case with their goldens."""

import sys

import numpy as np


def check(name: str) -> bool:
    """Compare ``{name}.bin`` with ``golden_{name}.bin`` as float32 arrays.

    Both files are read from the current working directory. The comparison
    uses an absolute and a relative tolerance of 1e-4.

    Args:
        name: Stem of the output file, e.g. ``"v2"``.

    Returns:
        True when both files hold the same number of elements and every
        element is within tolerance. False otherwise, after printing an
        ``[ERROR]`` line describing the first problem found.
    """
    try:
        golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32)
        output = np.fromfile(f"{name}.bin", dtype=np.float32)
    except OSError as err:
        # A missing or unreadable file is a failed comparison, not a crash.
        print(f"[ERROR] compare failed {name} cannot read data: {err}")
        return False

    # Element-wise comparison is only meaningful for equal shapes: np.isclose
    # would raise on a length mismatch, or silently broadcast a 1-element file.
    if golden.shape != output.shape:
        print(
            f"[ERROR] compare failed {name} shape mismatch "
            f"golden={golden.shape} output={output.shape}"
        )
        return False

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if close.all():
        return True

    idx = int(np.flatnonzero(~close)[0])
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={golden[idx]} "
        f"output={output[idx]}"
    )
    return False


def main() -> None:
    """Check outputs ``v2`` and ``v3``; exit with status 2 if any of them fails."""
    # Evaluate every output before deciding, so each failure is reported.
    results = [check(name) for name in ("v2", "v3")]
    if not all(results):
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shared this source line (header of the next file in the
# patch), preserved verbatim:
#   diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py
#   new file mode 100644
#   index 0000000000..146d0d1fd2
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py
#   @@ -0,0 +1,53 @@
#   +#!/usr/bin/env python3
#   +# Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +# CANN Open Software License Agreement Version 2.0 (the "License").
#   +# Please refer to the License for details. You may not use this file except in compliance with the License.
#   +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
#   +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
#   +# See LICENSE in the root of the software repository for the full text of the License.
"""Generate input, placeholder-output and golden files for group-slots-fanout-store-broadcast."""

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
GROUP_SIZE = 16
INPUT_ELEMS = ROWS * GROUP_SIZE
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write the ``.bin`` files of this case into ``output_dir``.

    Files written (all float32):
      * ``v1.bin``: ROWS groups of GROUP_SIZE source values.
      * ``v2.bin`` / ``v3.bin``: output placeholders filled with SENTINEL.
      * ``golden_v2.bin``: float32 sum of each group.
      * ``golden_v3.bin``: float32 sum of each group after multiplying every
        element by that group's sum.

    Args:
        output_dir: Destination directory; created if it does not exist.
    """
    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32)
    row_step = np.float32(0.125)

    flat_src = np.empty(INPUT_ELEMS, dtype=np.float32)
    sum_expected = np.empty(ROWS, dtype=np.float32)
    out_expected = np.empty(ROWS, dtype=np.float32)
    # Each row of the reshaped view aliases one group of flat_src.
    for index, group in enumerate(flat_src.reshape(ROWS, GROUP_SIZE)):
        group[:] = ramp + np.float32(index) * row_step
        total = np.sum(group, dtype=np.float32)
        sum_expected[index] = total
        out_expected[index] = np.sum(group * total, dtype=np.float32)

    payloads = {
        "v1.bin": flat_src,
        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": sum_expected,
        "golden_v3.bin": out_expected,
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, array in payloads.items():
        array.tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` (default: current directory) and generate the files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    generate(cli.parse_args().output_dir)


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# Patch text that shared this source line (header of the next file in the
# patch), preserved verbatim:
#   diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto
#   new file mode 100644
#   index 0000000000..0660b1e0a3
#   --- /dev/null
#   +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto
#   @@ -0,0 +1,71 @@
#   +// Copyright (c) 2026 Huawei Technologies Co., Ltd.
#   +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
#   +// CANN Open Software License Agreement Version 2.0 (the "License").
#   +// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_fanout_store_broadcast_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %b = pto.vmi.group_broadcast %sum 
{num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %ub_out[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp new file mode 100644 index 0000000000..9a0667aae1 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_fanout_store_broadcast_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *out); + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream) { + vmi_group_slots_fanout_store_broadcast_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)out); +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp new file mode 100644 index 0000000000..f7b0fee4b8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kOutElems = 8; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kOutElems * sizeof(float); + size_t outBytes = kOutElems * sizeof(float); + float *srcHost = nullptr; + float *sumHost = nullptr; + float *outHost = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, 
srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_fanout_store_broadcast_kernel(srcDevice, sumDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py new file mode 100644 index 0000000000..be861f3da8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py new file mode 100644 index 0000000000..a62c83071c --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + init = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32) + base = np.linspace(-0.75, 0.25, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = init + np.float32(2.0) * np.sum(src, axis=1, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + init.tofile(output_dir / "v1.bin") + src.reshape(-1).tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto new file mode 100644 index 0000000000..8ae0c03444 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_scf_for_store_kernel(%init_gm: !pto.ptr, + %src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_init = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %init_gm, %ub_init, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %acc0 = pto.vmi.group_slot_load %ub_init[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.group_store %acc, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + 
pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp new file mode 100644 index 0000000000..6837a88fd4 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_scf_for_store_kernel(__gm__ float *init, __gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream) { + vmi_group_slots_scf_for_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)init, (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp new file mode 100644 index 0000000000..555d105f43 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 16; + constexpr size_t kInitElems = kRows; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kDstElems = kRows; + size_t initBytes = kInitElems * sizeof(float); + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(float); + float *initHost = nullptr; + float *srcHost = nullptr; + float *dstHost = nullptr; + float *initDevice = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&initHost), initBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&initDevice, initBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", initBytes, initHost, initBytes); + ReadFile("./v2.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(initDevice, initBytes, 
initHost, initBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_scf_for_store_kernel(initDevice, srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(initDevice); + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(initHost); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py new file mode 100644 index 0000000000..24d554e100 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_f32() -> bool: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-5, rtol=1e-5): + return True + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v2 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_f16() -> bool: + golden = np.fromfile("golden_v3.bin", dtype=np.float16) + output = np.fromfile("v3.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check_f32() or not check_f16(): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py new file mode 100644 index 0000000000..6a28077ea8 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 128 +ACTIVE = 96 +SEED = 29 +SENTINEL32 = np.float32(-901.25) +SENTINEL16 = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + out32 = np.full(ELEMS, SENTINEL32, dtype=np.float32) + out16 = np.full(ELEMS, SENTINEL16, dtype=np.float16) + golden32 = out32.copy() + golden16 = out16.copy() + golden32[:ACTIVE] = src[:ACTIVE] + golden16[:ACTIVE] = src[:ACTIVE].astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out32.tofile(output_dir / "v2.bin") + out16.tofile(output_dir / "v3.bin") + golden32.tofile(output_dir / "golden_v2.bin") + golden16.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto new file mode 100644 index 0000000000..f9362793ec --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto @@ -0,0 
+1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_mask_granularity_f32_f16_store_kernel(%src_gm: !pto.ptr, + %out32_gm: !pto.ptr, + %out16_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c96 = arith.constant 96 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out32 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out32_gm, %ub_out32, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out16_gm, %ub_out16, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = 
pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %ub_out32[%c0], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %ub_out16[%c0], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out32, %out32_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out16, %out16_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp new file mode 100644 index 0000000000..de0c069797 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_granularity_f32_f16_store_kernel(__gm__ float *src, + __gm__ float *out32, + __gm__ half *out16); + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream) { + vmi_mask_granularity_f32_f16_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)out32, (__gm__ half *)out16); +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp new file mode 100644 index 0000000000..2a65d8c46d --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t out32Bytes = kElems * sizeof(float); + size_t out16Bytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *out32Host = nullptr; + uint16_t *out16Host = nullptr; + float *srcDevice = nullptr; + float *out32Device = nullptr; + uint16_t *out16Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out32Host), out32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out16Host), out16Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out32Device, out32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out16Device, out16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", out32Bytes, out32Host, out32Bytes); + ReadFile("./v3.bin", out16Bytes, out16Host, out16Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, 
srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out32Device, out32Bytes, out32Host, out32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out16Device, out16Bytes, out16Host, out16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_granularity_f32_f16_store_kernel(srcDevice, out32Device, + out16Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(out32Host, out32Bytes, out32Device, out32Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out16Host, out16Bytes, out16Device, out16Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", out32Host, out32Bytes); + WriteFile("./v3.bin", out16Host, out16Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(out32Device); + aclrtFree(out16Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(out32Host); + aclrtFreeHost(out16Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-select-store/compare.py b/test/vpto/cases/vmi/mask-select-store/compare.py new file mode 100644 index 0000000000..b9e3290e76 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v3", "v4"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-select-store/golden.py b/test/vpto/cases/vmi/mask-select-store/golden.py new file mode 100644 index 0000000000..19ce1ebe2c --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
import argparse
from pathlib import Path

import numpy as np

ELEMS = 64
ACTIVE = 48
SEED = 29
SENTINEL = np.float32(-901.25)


def generate(output_dir: Path, seed: int) -> None:
    """Write inputs and golden outputs of the mask-select-store case.

    ``v1.bin``/``v2.bin`` hold the two random float32 operands, ``v3.bin`` and
    ``v4.bin`` are output buffers pre-filled with ``SENTINEL``.  The goldens
    carry ``v1 + v2`` in the first ``ACTIVE`` elements; past that point
    ``golden_v3.bin`` keeps ``v1`` and ``golden_v4.bin`` keeps ``SENTINEL``.
    """
    rng = np.random.default_rng(seed)
    # Draw order (lhs, then rhs) is part of the contract for a given seed.
    lhs = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32)
    rhs = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32)
    total = (lhs + rhs).astype(np.float32)
    prefilled = np.full(ELEMS, SENTINEL, dtype=np.float32)
    in_prefix = np.arange(ELEMS) < ACTIVE

    blobs = {
        "v1.bin": lhs,
        "v2.bin": rhs,
        "v3.bin": prefilled,
        "v4.bin": prefilled,
        "golden_v3.bin": np.where(in_prefix, total, lhs),
        "golden_v4.bin": np.where(in_prefix, total, prefilled),
    }
    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, payload in blobs.items():
        payload.tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` and ``--seed`` and generate the case data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    opts = cli.parse_args()
    generate(opts.output_dir, opts.seed)


if __name__ == "__main__":
    main()
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// mask-select-store: loads two 64xf32 vectors, adds them, then writes the sum
// twice: once through a select against the first operand followed by a plain
// store, and once through a masked store.
// NOTE(review): `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` carry no `<...>`
// parameters in this text; they look truncated — confirm against the
// original file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_mask_select_store_kernel(%src_gm: !pto.ptr,
                                          %rhs_gm: !pto.ptr,
                                          %dense_gm: !pto.ptr,
                                          %masked_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    // Operand of create_mask; golden.py for this case uses ACTIVE = 48.
    %c48 = arith.constant 48 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 256 = 64 elements * 4 bytes; used as the size operand of every copy.
    %c256_i64 = arith.constant 256 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %c8192_i64 = arith.constant 8192 : i64
    %c12288_i64 = arith.constant 12288 : i64

    // Four on-chip buffers at fixed integer addresses 0 / 4096 / 8192 / 12288.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr
    %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr
    %ub_masked = pto.castptr %c12288_i64 : i64 -> !pto.ptr

    // Copy all four global buffers in, including the two output buffers, so
    // elements the kernel does not overwrite keep their pre-filled contents.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
    pto.mte_gm_ub %masked_gm, %ub_masked, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between the inbound copies (PIPE_MTE2) and the vector scope (PIPE_V).
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32>
      %rhs = pto.vmi.load %ub_rhs[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32>
      %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred>
      %sum = pto.vmi.addf %x, %rhs
        : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>
        -> !pto.vmi.vreg<64xf32>
      // Dense output: %sum where the mask is set, %x elsewhere.
      %passthrough = pto.vmi.select %mask, %sum, %x
        : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>,
          !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32>
      pto.vmi.store %passthrough, %ub_dense[%c0]
        : !pto.vmi.vreg<64xf32>, !pto.ptr
      // Masked output: the same mask gates the store of %sum.
      pto.vmi.masked_store %sum, %ub_masked[%c0], %mask
        : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred>
    }

    // Flag pair between the vector scope (PIPE_V) and the outbound copies (PIPE_MTE3).
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_masked, %masked_gm, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_select_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dense, __gm__ float *masked); + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream) { + vmi_mask_select_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dense, + (__gm__ float *)masked); +} diff --git a/test/vpto/cases/vmi/mask-select-store/main.cpp b/test/vpto/cases/vmi/mask-select-store/main.cpp new file mode 100644 index 0000000000..07648040d0 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream); + +int main() { + constexpr size_t kElems = 64; + size_t srcBytes = kElems * sizeof(float); + size_t rhsBytes = kElems * sizeof(float); + size_t denseBytes = kElems * sizeof(float); + size_t maskedBytes = kElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *denseHost = nullptr; + float *maskedHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *denseDevice = nullptr; + float *maskedDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&maskedHost), maskedBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&maskedDevice, maskedBytes, 
ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", denseBytes, denseHost, denseBytes); + ReadFile("./v4.bin", maskedBytes, maskedHost, maskedBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(maskedDevice, maskedBytes, maskedHost, maskedBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_select_store_kernel(srcDevice, rhsDevice, denseDevice, + maskedDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(maskedHost, maskedBytes, maskedDevice, maskedBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", denseHost, denseBytes); + WriteFile("./v4.bin", maskedHost, maskedBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(denseDevice); + aclrtFree(maskedDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(denseHost); + aclrtFreeHost(maskedHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-select-store/ptoas.flags b/test/vpto/cases/vmi/mask-select-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py b/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ 
#!/usr/bin/env python3
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def check(name: str, golden_name: str) -> None:
    """Exit with status 2 unless float32 file *name* matches *golden_name*.

    A shape difference is reported first; otherwise the first element outside
    atol=rtol=1e-4 is printed.  Returns normally when the files agree.
    """
    expected = np.fromfile(golden_name, dtype=np.float32)
    actual = np.fromfile(name, dtype=np.float32)

    if expected.shape != actual.shape:
        print(f"[ERROR] compare failed {name}: shape golden={expected.shape} output={actual.shape}")
        sys.exit(2)

    close = np.isclose(expected, actual, atol=1e-4, rtol=1e-4)
    if close.all():
        return

    mismatches = np.flatnonzero(~close)
    idx = int(mismatches[0]) if mismatches.size else -1
    shown_golden = expected[idx] if idx >= 0 else 'n/a'
    shown_output = actual[idx] if idx >= 0 else 'n/a'
    print(
        f"[ERROR] compare failed {name} idx={idx} "
        f"golden={shown_golden} "
        f"output={shown_output}"
    )
    sys.exit(2)


def main() -> None:
    """Check both outputs of the case, then report success."""
    for stem in ("v2", "v3"):
        check(f"{stem}.bin", f"golden_{stem}.bin")
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import argparse
from pathlib import Path

import numpy as np

ROWS = 8
COLS = 32
SENTINEL = np.float32(-777.0)


def generate(output_dir: Path) -> None:
    """Write inputs and golden outputs of the masked-load-dense-group-users case.

    ``v1.bin`` is a ROWS x COLS float32 matrix (a linear ramp per row, shifted
    by 0.03125 per row index).  ``v2.bin`` and ``v3.bin`` are output buffers
    pre-filled with ``SENTINEL``.  ``golden_v2.bin`` equals the input and
    ``golden_v3.bin`` holds the float32 sum of each row.
    """
    ramp = np.linspace(-0.875, 0.625, COLS, dtype=np.float32)
    row_shift = np.arange(ROWS, dtype=np.float32) * np.float32(0.03125)
    matrix = ramp[np.newaxis, :] + row_shift[:, np.newaxis]
    row_sums = matrix.sum(axis=1, dtype=np.float32).astype(np.float32)

    blobs = {
        "v1.bin": matrix.reshape(-1),
        "v2.bin": np.full(ROWS * COLS, SENTINEL, dtype=np.float32),
        "v3.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": matrix.reshape(-1).astype(np.float32),
        "golden_v3.bin": row_sums,
    }
    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, payload in blobs.items():
        payload.tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` and generate the case data there."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    opts = cli.parse_args()
    generate(opts.output_dir)


if __name__ == "__main__":
    main()
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// masked-load-dense-group-users: one masked load of 256xf32 feeds two users,
// a plain store of the loaded vector and an 8-group add-reduction whose
// result goes through group_store.
// NOTE(review): `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` carry no `<...>`
// parameters in this text; they look truncated — confirm against the
// original file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_masked_load_dense_group_users_kernel(%src_gm: !pto.ptr,
                                                      %copy_gm: !pto.ptr,
                                                      %sum_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Operand of create_mask; equals the lane count of the 256-wide vector types.
    %c256 = arith.constant 256 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 32 = 8 sums * 4 bytes; 1024 = 256 elements * 4 bytes.
    %c32_i64 = arith.constant 32 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2048_i64 = arith.constant 2048 : i64
    %c4096_i64 = arith.constant 4096 : i64
    %zero = arith.constant 0.000000e+00 : f32

    // Three on-chip buffers at fixed integer addresses 0 / 2048 / 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr
    %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    // Only the source buffer is copied in; the two outputs are written whole
    // by the copies at the end.
    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between the inbound copy (PIPE_MTE2) and the vector scope (PIPE_V).
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred>
      // Third operand of the masked load below (a broadcast zero vector).
      %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32>
      %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec
        : !pto.ptr, !pto.vmi.mask<256xpred>,
          !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
      // User 1: dense copy of the loaded vector.
      pto.vmi.store %x, %ub_copy[%c0]
        : !pto.vmi.vreg<256xf32>, !pto.ptr
      // User 2: add-reduction with num_groups = 8 (golden.py sums 8 rows of 32).
      %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc}
        : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>
        -> !pto.vmi.vreg<256xf32>
      pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8}
        : !pto.vmi.vreg<256xf32>, !pto.ptr
    }

    // Flag pair between the vector scope (PIPE_V) and the outbound copies (PIPE_MTE3).
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64
      nburst(%c1_i64, %c1024_i64, %c1024_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64
      nburst(%c1_i64, %c32_i64, %c32_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_masked_load_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_masked_load_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp new file mode 100644 index 0000000000..089794a818 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + 
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_dense_group_users_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import sys

import numpy as np


def main() -> None:
    """Compare ``v2.bin`` with ``golden_v2.bin`` in the current directory.

    Both files are read as float32.  The process exits with status 2 when the
    shapes differ or any element is outside atol=rtol=1e-4; otherwise it
    prints a success line.
    """
    golden = np.fromfile("golden_v2.bin", dtype=np.float32)
    output = np.fromfile("v2.bin", dtype=np.float32)

    # Report a size mismatch on its own: np.isclose cannot broadcast two 1-D
    # arrays of different lengths and would raise instead of failing the
    # comparison cleanly.
    if golden.shape != output.shape:
        print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}")
        sys.exit(2)

    close = np.isclose(golden, output, atol=1e-4, rtol=1e-4)
    if not close.all():
        idx = int(np.flatnonzero(~close)[0])
        print(f"[ERROR] compare failed idx={idx} golden={golden[idx]} output={output[idx]}")
        sys.exit(2)
    print("[INFO] compare passed")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

import numpy as np

ELEMS = 128
SEED = 37
SENTINEL = np.float32(-123.25)


def generate(output_dir: Path, seed: int) -> None:
    """Write inputs and the golden output of the scf-for-loop-carried-store case.

    ``v1.bin`` holds ``ELEMS`` random float16 values, ``v2.bin`` is the float32
    output buffer pre-filled with ``SENTINEL``, and ``golden_v2.bin`` is the
    input widened to float32 and multiplied by 4.
    """
    halves = (
        np.random.default_rng(seed)
        .uniform(-4.0, 4.0, size=ELEMS)
        .astype(np.float16)
    )
    blobs = {
        "v1.bin": halves,
        "v2.bin": np.full(ELEMS, SENTINEL, dtype=np.float32),
        "golden_v2.bin": halves.astype(np.float32) * np.float32(4.0),
    }
    output_dir.mkdir(parents=True, exist_ok=True)
    for filename, payload in blobs.items():
        payload.tofile(output_dir / filename)


def main() -> None:
    """Parse ``--output-dir`` and ``--seed`` and generate the case data."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", type=Path, default=Path("."))
    cli.add_argument("--seed", type=int, default=SEED)
    opts = cli.parse_args()
    generate(opts.output_dir, opts.seed)


if __name__ == "__main__":
    main()
// scf-for-loop-carried-store: loads 128xf16, widens it to f32, carries the
// vector through a two-iteration scf.for that adds it to itself, and stores
// the final value.
// NOTE(review): `#pto.kernel_kind`, `!pto.ptr` and `#pto.pipe` carry no `<...>`
// parameters in this text; they look truncated — confirm against the
// original file.
module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {
  func.func @vmi_scf_for_loop_carried_store_kernel(%src_gm: !pto.ptr,
                                                   %dst_gm: !pto.ptr) attributes {pto.kernel} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Upper bound of the loop below: iterations 0 and 1.
    %c2 = arith.constant 2 : index
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    // 256 = 128 f16 elements * 2 bytes; 512 = 128 f32 elements * 4 bytes.
    %c256_i64 = arith.constant 256 : i64
    %c512_i64 = arith.constant 512 : i64
    %c4096_i64 = arith.constant 4096 : i64

    // Two on-chip buffers at fixed integer addresses 0 / 4096.
    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
    %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr

    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64
      nburst(%c1_i64, %c256_i64, %c256_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64

    // Flag pair between the inbound copy (PIPE_MTE2) and the vector scope (PIPE_V).
    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]

    pto.vecscope {
      %packed = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16>
      %init = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32>
      // %acc is the loop-carried vector; each iteration yields %acc + %acc, so
      // two iterations give 4 * %init (golden.py multiplies the input by 4).
      %result = scf.for %i = %c0 to %c2 step %c1
          iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) {
        %next = pto.vmi.addf %acc, %acc
          : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>
          -> !pto.vmi.vreg<128xf32>
        scf.yield %next : !pto.vmi.vreg<128xf32>
      }
      pto.vmi.store %result, %ub_dst[%c0]
        : !pto.vmi.vreg<128xf32>, !pto.ptr
    }

    // Flag pair between the vector scope (PIPE_V) and the outbound copy (PIPE_MTE3).
    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
    pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64
      nburst(%c1_i64, %c512_i64, %c512_i64)
      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
    pto.barrier #pto.pipe
    return
  }
}
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_scf_for_loop_carried_store_kernel(__gm__ uint16_t *src, __gm__ float *dst); + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_scf_for_loop_carried_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp new file mode 100644 index 0000000000..f45b070260 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + 
LaunchVmi_scf_for_loop_carried_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py new file mode 100644 index 0000000000..c964405de5 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def check(name: str, atol: float, rtol: float) -> bool:
+    """Compare ``{name}.bin`` with ``golden_{name}.bin`` as float32 data.
+
+    Returns True when both files hold the same number of elements and every
+    element is within ``atol``/``rtol``. Otherwise prints one ``[ERROR]``
+    line describing the first problem found and returns False.
+    """
+    golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32)
+    output = np.fromfile(f"{name}.bin", dtype=np.float32)
+    # Report a size mismatch explicitly: np.isclose would raise on
+    # non-broadcastable shapes, or silently broadcast a one-element file.
+    if golden.shape != output.shape:
+        print(
+            f"[ERROR] compare failed {name}: "
+            f"shape golden={golden.shape} output={output.shape}"
+        )
+        return False
+    close = np.isclose(golden, output, atol=atol, rtol=rtol)
+    if close.all():
+        return True
+    # At least one element is out of tolerance here; report the first one.
+    idx = int(np.nonzero(~close)[0][0])
+    print(
+        f"[ERROR] compare failed {name} idx={idx} "
+        f"golden={golden[idx]} output={output[idx]}"
+    )
+    return False
+
+
+def main() -> None:
+    """Check v2 with a 1e-4 tolerance and v3 exactly; exit 2 on failure."""
+    if not check("v2", 1e-4, 1e-4) or not check("v3", 0.0, 0.0):
+        sys.exit(2)
+    print("[INFO] compare passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py
new file mode 100644
index 0000000000..b41d0e8681
--- /dev/null
+++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8
+GROUP_SIZE = 16
+ELEMS = ROWS * GROUP_SIZE
+SENTINEL = np.float32(-777.0)
+
+
+def generate(output_dir: Path) -> None:
+    """Write the f16 input, sentinel-filled outputs and golden files."""
+    ramp = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float16)
+    src = np.empty((ROWS, GROUP_SIZE), dtype=np.float16)
+    for r in range(ROWS):
+        src[r, :] = ramp + np.float16(r * 0.125)
+    widened = src.astype(np.float32)
+
+    # Row sums are accumulated in f32, one row at a time.
+    row_sums = [np.sum(widened[r, :], dtype=np.float32) for r in range(ROWS)]
+    payloads = {
+        "v1.bin": src.reshape(-1),
+        "v2.bin": np.full(ROWS, SENTINEL, dtype=np.float32),
+        "v3.bin": np.full(ELEMS, SENTINEL, dtype=np.float32),
+        "golden_v2.bin": np.array(row_sums, dtype=np.float32),
+        "golden_v3.bin": widened.reshape(-1),
+    }
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for file_name, data in payloads.items():
+        data.tofile(output_dir / file_name)
+
+
+def main() -> None:
+    """Parse ``--output-dir`` and generate the case data there."""
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--output-dir", type=Path, default=Path("."))
+    generate(cli.parse_args().output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto
new file mode 100644
index 0000000000..9f3dfeabb4
--- /dev/null
+++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_widen_f16_to_f32_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %dense_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + 
pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %x32, %ub_dense[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dense, %dense_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp new file mode 100644 index 0000000000..b0ee12da2b --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_widen_f16_to_f32_store_reduce_kernel(__gm__ half *src, __gm__ float *sum, + __gm__ float *dense); + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, + void *stream) { + vmi_widen_f16_to_f32_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)sum, (__gm__ float *)dense); +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp new file mode 100644 index 0000000000..96a4a102f8 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kSumElems = 8; + constexpr size_t kDenseElems = 128; + size_t srcBytes = kSrcElems * sizeof(uint16_t); + size_t sumBytes = kSumElems * sizeof(float); + size_t denseBytes = kDenseElems * sizeof(float); + uint16_t *srcHost = nullptr; + float *sumHost = nullptr; + float *denseHost = nullptr; + uint16_t *srcDevice = nullptr; + float *sumDevice = nullptr; + float *denseDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", denseBytes, denseHost, 
denseBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_widen_f16_to_f32_store_reduce_kernel(srcDevice, sumDevice, + denseDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", denseHost, denseBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(denseDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(denseHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From d225422aaca2d2bd3db465c49de3b9a5532eb170 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:45:23 +0800 Subject: [PATCH 05/31] Support S32 partial grouped mask lowering --- .../vmi-layout-assignment-implementation.md | 38 ++++---- .../vmi-layout-assignment-lowering-design.md | 9 +- docs/designs/vmi-layout-lowering-cases.md | 66 ++++++------- lib/PTO/Transforms/VMILayoutAssignment.cpp | 22 ----- lib/PTO/Transforms/VMIToVPTO.cpp | 59 ++++++++++- ..._assignment_masked_load_group_tail_s32.pto | 21 +++- .../compare.py | 40 ++++++++ .../golden.py | 51 ++++++++++ 
.../kernel.pto | 62 ++++++++++++ .../launch.cpp | 33 +++++++ .../main.cpp | 97 +++++++++++++++++++ .../ptoas.flags | 1 + 12 files changed, 413 insertions(+), 86 deletions(-) create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index e6b39dd984..12fedece13 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -758,8 +758,11 @@ Current audit result: ```text 3.44 partial S=32 create_group_mask: - decision moved to vmi-layout-assignment. vmi-to-vpto no longer walks from - group_reduce_addf to the mask defining op to reject the plan. + assignment writes explicit contiguous and deinterleaved mask values. When + lowering the deinterleaved create_group_mask itself, vmi-to-vpto first + materializes contiguous grouped predicate chunks and then applies predicate + pdintlv in the same tree shape as the data vdintlv. It still does not walk + from group_reduce_addf to the mask defining op to choose or reject the plan. masked_load: direct lowering is load + vsel. It does not inspect the mask producer to @@ -870,13 +873,13 @@ the case catalog. 
Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-39 CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-40 CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=39 FAIL=0 -summary: .tmp/vmi-runtime-batch-39/parallel-summary.tsv +PASS=40 FAIL=0 +summary: .tmp/vmi-runtime-batch-40/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-39.log + .tmp/vmi-runtime-batch-40.log result: no matches ``` @@ -1054,6 +1057,17 @@ runtime SIM: load must run through layout assignment before VPTO/LLVM emission. ``` +Current checked-in coverage for 3.44 masked_load grouped tail feeding S=32 +reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store +``` + Current checked-in runtime coverage for 3.12 control-flow join before S=32 `group_reduce`: @@ -1297,7 +1311,6 @@ Diagnostic-only cases: 3.25.1 full ptoas emission for private VMI callees that return VPTO vector values 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback -3.44 masked_load grouped tail with S=32 partial create_group_mask ``` Current checked-in diagnostic coverage for 3.9/3.13/3.14: @@ -1326,7 +1339,6 @@ lit: test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto - test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto ``` Known implementation gaps before all catalog cases can become runtime SIM @@ -1339,16 +1351,6 @@ dynamic grouped masks: yet. Do not replace grouped masks with prefix create_mask; that would change the semantics. 
-S=32 partial grouped masks: - 3.44 `masked_load` grouped tail with `active_elems_per_group < 32` is - diagnostic-only for the current S=32 block8 reduce path, and the diagnostic - is emitted by `vmi-layout-assignment` before a selected plan is written. A - runtime probe of the previously allowed lowering did not preserve the logical - 25-lane row sum. A second probe with `active_elems_per_group = 25` produced - row 0 `golden=-3.6290324` but `output=-3.6592741`, and the row-wise error - grew monotonically. This combination must stay unsupported until the - deinterleaved grouped-mask materialization is fixed and validated by SIM. - remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and vmi-to-vpto lit coverage, but full ptoas emission still fails after diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 710ab267a7..d0f71f0daf 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -552,9 +552,12 @@ diagnostic embellishment: Anything else is a layout-assignment responsibility. In particular, an unsupported producer/consumer combination must be rejected before assignment -writes a selected plan. Section 3.44 is the model: partial S=32 grouped masks -are diagnosed in `vmi-layout-assignment`, not by `vmi-to-vpto` walking from -`group_reduce_addf` to the mask producer. +writes a selected plan. Section 3.44 is the model for supported partial S=32 +grouped masks: assignment emits explicit contiguous and deinterleaved mask +values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through +contiguous grouped-mask materialization followed by predicate deinterleave. It +does not walk from `group_reduce_addf` to the mask producer to choose or reject +the plan. ## 9. 
Physical Value Ordering diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index b111397fc9..93a6c1dc57 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -196,7 +196,7 @@ the immediately following complete endpoints. 3.41 non-rematerializable value with incompatible users complete/materialization 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design -3.44 masked_load grouped tail feeding S=32 reduce complete/design +3.44 masked_load grouped tail feeding S=32 reduce complete ``` ### 3.1 `f16 -> f32 -> store` @@ -5167,25 +5167,7 @@ Assigned layouts: !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` -Current implementation result: - -```text -VMI-UNSUPPORTED: pto.vmi.group_reduce_addf s32 block8 lowering does not yet -support partial create_group_mask active_elems_per_group during layout -assignment -``` - -This must remain a layout-assignment diagnostic until the S=32 block8 -grouped-mask lowering is proven against runtime SIM. Assignment must not write -`vmi.selected_plan = "s32_reduce_block8_stride"` for this case and leave -`vmi-to-vpto` to discover the partial mask by walking the mask defining op. A -`masked_load` can be lowered contiguously and then materialized to -`deinterleaved = 4, block_elems = 8`, but the grouped reduce still needs a -physically correct `create_group_mask` for `active_elems_per_group = 25`. -Allowing the current S=32 block8 path to proceed would not preserve the logical -memory result below. 
- -Intended VPTO lowering shape after the grouped-mask issue is fixed: +Lowering: ```text %all_b32 = pto.pge_b32 "PAT_ALL" @@ -5209,15 +5191,16 @@ Intended VPTO lowering shape after the grouped-mask issue is fixed: %x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo %x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi -// Correct deinterleaved grouped mask for active columns 0..24: -// part 0 covers columns 0..7 for every row: all active -// part 1 covers columns 8..15 for every row: all active -// part 2 covers columns 16..23 for every row: all active -// part 3 covers columns 24..31 for every row: one active lane per row -%mask_p0 = pto.pset_b32 "PAT_ALL" -%mask_p1 = pto.pset_b32 "PAT_ALL" -%mask_p2 = pto.pset_b32 "PAT_ALL" -%mask_p3 = materialize one lane per 8-lane row block +// The reduce-side grouped mask is not built by guessing the final sparse +// predicate image. It is first materialized as the same contiguous grouped +// mask used by masked_load, then converted to the reduce layout with predicate +// deinterleave. This keeps predicate reordering identical to the data +// reordering above. +%rm0, %rm1, %rm2, %rm3 = materialize contiguous create_group_mask(c25, S=32) +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi %s0 = pto.vcgadd %x_p0, %mask_p0 : !pto.vreg<64xf32> %s1 = pto.vcgadd %x_p1, %mask_p1 : !pto.vreg<64xf32> @@ -5244,12 +5227,21 @@ Required assignment rule: ```text `masked_load` and `group_reduce` must share the same grouped mask layout. The passthrough value defines inactive loaded lanes, while the reduce mask defines -participation. Assignment may select a deinterleaved S=32 load plan only when -the rounded physical reads are memory-safe; otherwise it must diagnose or use a -future stable gather fallback. - -Current implementation additionally diagnoses the S=32 block8 partial grouped -mask itself. 
This is deliberate: the case is not implemented until the -deinterleaved grouped-mask materialization and `vcgadd` interpretation are -validated end to end by SIM. +participation. Assignment materializes two explicit mask values when needed: +one contiguous value for `masked_load`, and one deinterleaved value for +`group_reduce_addf`. `vmi-to-vpto` lowers the deinterleaved +`create_group_mask` by materializing the contiguous grouped predicate chunks +and then applying `pdintlv_b32` in the same tree shape as the data +`vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to +choose or reject the selected plan. + +Assignment may select a deinterleaved S=32 load plan only when the rounded +physical reads are memory-safe; otherwise it must diagnose or use a future +stable gather fallback. +``` + +Runtime coverage: + +```text +test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index c95e8772ec..9352ffce76 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -832,28 +832,6 @@ struct LayoutSolver { VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); } } - if (sourceLayout && sourceLayout.isDeinterleaved() && - sourceLayout.getFactor() == 4 && - sourceLayout.getBlockElems() == 8 && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - if (groupSize == 32) { - if (auto groupMask = - reduce.getMask().getDefiningOp()) { - std::optional activeElems = - getConstantIndexValue(groupMask.getActiveElemsPerGroup()); - if (activeElems && *activeElems >= 0 && - *activeElems < groupSize) { - reduce.emitError() - << kVMIDiagUnsupportedPrefix - << "pto.vmi.group_reduce_addf s32 block8 lowering does " - "not yet support partial create_group_mask " - "active_elems_per_group during layout assignment"; - return
WalkResult::interrupt(); - } - } - } - } requestDataUse(reduce.getSourceMutable(), sourceLayout); if (failed(requestMaskUse( reduce.getMaskMutable(), sourceLayout, diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 95141bada7..85dbec5f1e 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -2079,7 +2079,9 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { } FailureOr> -computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { +computeGroupMaskMaterializationForType(VMICreateGroupMaskOp op, + VMIMaskType resultVMIType, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr> { if (reason) @@ -2095,7 +2097,6 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { if (!activeAttr) return fail("active_elems_per_group must be an integer constant"); - auto resultVMIType = cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); if (!layout || !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) @@ -2153,6 +2154,12 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { return materializations; } +FailureOr> +computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { + return computeGroupMaskMaterializationForType( + op, cast(op.getResult().getType()), reason); +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -3781,6 +3788,54 @@ struct OneToNVMICreateGroupMaskOpPattern matchAndRewrite(VMICreateGroupMaskOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + 
resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 8) { + VMILayoutAttr contiguousLayout = + VMILayoutAttr::getContiguous(op.getContext()); + auto contiguousType = + VMIMaskType::get(op.getContext(), resultVMIType.getElementCount(), + resultVMIType.getGranularity(), contiguousLayout); + std::string contiguousReason; + FailureOr> + contiguousMaterializations = computeGroupMaskMaterializationForType( + op, contiguousType, &contiguousReason); + if (failed(contiguousMaterializations)) + return rewriter.notifyMatchFailure( + op, Twine("create_group_mask ") + contiguousReason); + + SmallVector contiguousParts; + contiguousParts.reserve(contiguousMaterializations->size()); + for (const ConstantMaskChunkMaterialization &materialization : + *contiguousMaterializations) { + if (contiguousParts.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many contiguous masks"); + auto maskType = dyn_cast(resultTypes[contiguousParts.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask contiguous chunk"); + contiguousParts.push_back(*mask); + } + + if (contiguousParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask contiguous physical result count mismatch"); + FailureOr> results = materializeMaskLayoutConversion( + op, contiguousParts, resultTypes, contiguousLayout, resultLayout, + rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + std::string reason; FailureOr> materializations = computeGroupMaskMaterialization(op, &reason); diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto 
b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index bad43bb869..33ee79cb57 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -6,7 +6,8 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_masked_load_group_tail_s32( @@ -34,6 +35,18 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED: pto.vmi.group_reduce_addf -// CHECK-SAME: s32 block8 lowering does not yet support partial create_group_mask active_elems_per_group during layout assignment -// CHECK-NOT: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.ensure_layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +ACTIVE = 25 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + active_base = np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.75) + + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_copy[:, ACTIVE:] = np.float32(0.0) + golden_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git 
a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto new file mode 100644 index 0000000000..37a10109ee --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_masked_load_group_tail_s32_reduce_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c25 = arith.constant 25 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask 
= pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp new file mode 100644 index 0000000000..5b39bc3962 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_masked_load_group_tail_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_masked_load_group_tail_s32_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp new file mode 100644 index 0000000000..f9f224885e --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, +
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From bd18dc49793682752fa62d8e112da2a842cec905 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:13:15 +0800 Subject: [PATCH 06/31] Support dynamic S32 grouped mask lowering --- .../vmi-layout-assignment-implementation.md | 19 +- .../vmi-layout-assignment-lowering-design.md | 4 +- docs/designs/vmi-layout-lowering-cases.md | 74 ++++++- lib/PTO/Transforms/VMIToVPTO.cpp | 189 +++++++++++++++--- ...signment_create_group_mask_s32_dynamic.pto | 61 ++++++ 5 files changed, 316 insertions(+), 31 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto diff --git 
a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 12fedece13..f0dc821444 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -763,6 +763,9 @@ Current audit result: materializes contiguous grouped predicate chunks and then applies predicate pdintlv in the same tree shape as the data vdintlv. It still does not walk from group_reduce_addf to the mask defining op to choose or reject the plan. + The dynamic active_elems_per_group form is also op-local: vmi-to-vpto lowers + contiguous chunks with vci/vshrs/vshls/vsub/vcmps, then uses the same + predicate pdintlv tree for S=32 deinterleaved masks. masked_load: direct lowering is load + vsel. It does not inspect the mask producer to @@ -904,7 +907,7 @@ layout/rematerialization: mask/tail: 3.11.1, 3.15.1, 3.15.2, 3.21, 3.24, 3.26, 3.29, - 3.30, 3.44 + 3.30, 3.44, 3.45 strided/group-slot memory: 3.27, 3.28, 3.37, 3.39 @@ -1063,6 +1066,7 @@ reduce: ```text lit: test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto + test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store @@ -1345,11 +1349,14 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -dynamic grouped masks: - pto.vmi.create_group_mask exists and supports constant - active_elems_per_group. Dynamic active_elems_per_group is not implemented - yet. Do not replace grouped masks with prefix create_mask; that would change - the semantics. +dynamic grouped mask runtime source: + vmi-to-vpto supports dynamic active_elems_per_group for contiguous b32 + grouped masks and S=32 deinterleaved=4/block_elems=8 masks. Full runtime SIM + coverage still needs a supported scalar source for active_elems_per_group in + vector kernels. 
Direct GM pto.ldg crashed the Bisheng vector backend in this + test shape, and UB pto.load_scalar reached an invalid scalar LSU address in + the SIM. Do not replace grouped masks with prefix create_mask; that would + change the semantics. remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index d0f71f0daf..c13944d348 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -557,7 +557,9 @@ grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the plan. +the plan. Dynamic `active_elems_per_group` follows the same rule: the +`create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps +for contiguous chunks before any predicate deinterleave. ## 9. Physical Value Ordering diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 93a6c1dc57..e44f32a97e 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -197,6 +197,7 @@ the immediately following complete endpoints. 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design 3.44 masked_load grouped tail feeding S=32 reduce complete +3.45 dynamic S=32 create_group_mask complete/lit ``` ### 3.1 `f16 -> f32 -> store` @@ -5224,7 +5225,6 @@ for r = 0..7: Required assignment rule: -```text `masked_load` and `group_reduce` must share the same grouped mask layout. 
The passthrough value defines inactive loaded lanes, while the reduce mask defines participation. Assignment materializes two explicit mask values when needed: @@ -5244,4 +5244,76 @@ Runtime coverage: ```text test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` + +### 3.45 Dynamic S=32 `create_group_mask` + +This is the dynamic-shape form of section 3.44. The active column count is an +SSA `index`, not a constant. The semantic mask is still grouped: + +```text +lane i active iff (i % 32) < active_cols +``` + +VMI input: + +```text +%mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%mask for S=32 group_reduce: + !pto.vmi.mask<256xb32, + #pto.vmi.layout> +``` + +Contiguous VPTO lowering for one b32 physical chunk: + +```text +%active_i32 = arith.index_cast %active_cols : index to i32 +%active_nonneg = arith.maxsi %active_i32, %c0_i32 : i32 +%active_clamped = arith.minui %active_nonneg, %c32_i32 : i32 + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row_base = pto.vshls %row, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row_base, %all + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask + -> !pto.vreg<64xi32> +%m = pto.vcmps %col, %active_clamped, %all, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask ``` + +For `deinterleaved = 4, block_elems = 8`, lowering first emits four contiguous +chunks with the sequence above, then applies the same predicate deinterleave +tree used by section 3.44: + +```text +%rm0, %rm1, %rm2, %rm3 = dynamic contiguous grouped masks +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = 
pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi +``` + +The current lit coverage validates the IR lowering: + +```text +test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +``` + +Runtime SIM coverage is intentionally not listed yet. A direct runtime case +needs a supported way to feed a dynamic scalar `active_cols` into a vector +kernel. Experiments with GM `pto.ldg` and UB `pto.load_scalar` either crashed +the Bisheng vector backend or produced an invalid scalar LSU address in the +SIM. That is an ABI/source-materialization gap, not a `create_group_mask` +layout-lowering gap. diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 85dbec5f1e..36ccc21f3f 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -424,6 +424,11 @@ Value createI32Constant(Location loc, int64_t value, return rewriter.create(loc, value, 32); } +Value createI16Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { + return rewriter.create(loc, value, 16); +} + FailureOr createPrefixMaskForActiveLanes(Location loc, MaskType maskType, int64_t activeLanes, PatternRewriter &rewriter) { @@ -477,6 +482,17 @@ Value createPartitionActiveLanes(Location loc, Value activeLanesI32, loc, biased, createI32Constant(loc, factor, rewriter)); } +std::optional getPowerOfTwoLog2(int64_t value) { + if (value <= 0 || (value & (value - 1)) != 0) + return std::nullopt; + int64_t log2 = 0; + while (value > 1) { + value >>= 1; + ++log2; + } + return log2; +} + std::optional getPrefixPattern(int64_t activeLanes, int64_t lanesPerPart) { if (activeLanes <= 0) @@ -2160,6 +2176,96 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { op, cast(op.getResult().getType()), reason); } +FailureOr> materializeDynamicContiguousGroupMask( + VMICreateGroupMaskOp op, Value activeElemsPerGroup, + VMIMaskType contiguousVMIType, TypeRange resultTypes, + PatternRewriter 
&rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + VMILayoutAttr layout = contiguousVMIType.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("dynamic create_group_mask requires contiguous seed layout"); + if (contiguousVMIType.getGranularity() != "b32") + return fail("dynamic create_group_mask currently requires b32 " + "granularity"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if (numGroups <= 0 || groupSize <= 0 || + contiguousVMIType.getElementCount() != numGroups * groupSize) + return fail("dynamic create_group_mask requires result lane count to " + "match num_groups * group_size"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(contiguousVMIType.getGranularity()); + FailureOr arity = getVMIPhysicalArity(contiguousVMIType); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("dynamic create_group_mask requires computable physical " + "mask chunks"); + if (static_cast(resultTypes.size()) != *arity) + return fail("dynamic create_group_mask physical result count mismatch"); + if (groupSize > *lanesPerPart || (*lanesPerPart % groupSize) != 0) + return fail("dynamic create_group_mask currently requires group_size to " + "divide one physical b32 predicate chunk"); + + std::optional shift = getPowerOfTwoLog2(groupSize); + if (!shift) + return fail("dynamic create_group_mask currently requires power-of-two " + "group_size"); + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type i32 = rewriter.getI32Type(); + auto indexVectorType = VRegType::get(ctx, *lanesPerPart, i32); + Value activeI32 = + clampDynamicActiveLanes(loc, activeElemsPerGroup, groupSize, rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto maskType = dyn_cast(resultType); + if (!maskType || 
!maskType.isB32()) + return fail("dynamic create_group_mask result must be b32 mask"); + + FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); + if (failed(allMask)) + return fail("failed to create dynamic create_group_mask all mask"); + + Value zero = createI32Constant(loc, 0, rewriter); + Value lane = + rewriter.create(loc, indexVectorType, zero, StringAttr{}) + .getResult(); + + Value col = lane; + if (groupSize != *lanesPerPart) { + Value shiftScalar = createI16Constant(loc, *shift, rewriter); + Value group = rewriter + .create(loc, indexVectorType, lane, + shiftScalar, *allMask) + .getResult(); + Value groupBase = rewriter + .create(loc, indexVectorType, group, + shiftScalar, *allMask) + .getResult(); + col = rewriter + .create(loc, indexVectorType, lane, groupBase, + *allMask) + .getResult(); + } + + results.push_back(rewriter + .create(loc, maskType, col, activeI32, + *allMask, + rewriter.getStringAttr("lt")) + .getResult()); + } + + return results; +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -3797,31 +3903,50 @@ struct OneToNVMICreateGroupMaskOpPattern auto contiguousType = VMIMaskType::get(op.getContext(), resultVMIType.getElementCount(), resultVMIType.getGranularity(), contiguousLayout); - std::string contiguousReason; - FailureOr> - contiguousMaterializations = computeGroupMaskMaterializationForType( - op, contiguousType, &contiguousReason); - if (failed(contiguousMaterializations)) - return rewriter.notifyMatchFailure( - op, Twine("create_group_mask ") + contiguousReason); - SmallVector contiguousParts; - contiguousParts.reserve(contiguousMaterializations->size()); - for (const ConstantMaskChunkMaterialization &materialization : - *contiguousMaterializations) { - if (contiguousParts.size() >= resultTypes.size()) - return rewriter.notifyMatchFailure( - op, "create_group_mask produced too many contiguous masks"); - auto maskType = 
dyn_cast(resultTypes[contiguousParts.size()]); - if (!maskType) + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (activeConstant) { + std::string contiguousReason; + FailureOr> + contiguousMaterializations = computeGroupMaskMaterializationForType( + op, contiguousType, &contiguousReason); + if (failed(contiguousMaterializations)) return rewriter.notifyMatchFailure( - op, "create_group_mask result must be mask"); - FailureOr mask = materializeConstantMaskChunk( - op.getLoc(), maskType, materialization.activeLanes, rewriter); - if (failed(mask)) - return rewriter.notifyMatchFailure( - op, "failed to materialize create_group_mask contiguous chunk"); - contiguousParts.push_back(*mask); + op, Twine("create_group_mask ") + contiguousReason); + + contiguousParts.reserve(contiguousMaterializations->size()); + for (const ConstantMaskChunkMaterialization &materialization : + *contiguousMaterializations) { + if (contiguousParts.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many contiguous masks"); + auto maskType = + dyn_cast(resultTypes[contiguousParts.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask contiguous chunk"); + contiguousParts.push_back(*mask); + } + } else { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one " + "value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> dynamicParts = + materializeDynamicContiguousGroupMask(op, *active, contiguousType, + resultTypes, rewriter); + if (failed(dynamicParts)) + return failure(); + contiguousParts = std::move(*dynamicParts); } if 
(contiguousParts.size() != resultTypes.size()) @@ -3836,6 +3961,24 @@ struct OneToNVMICreateGroupMaskOpPattern return success(); } + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant && resultLayout && resultLayout.isContiguous()) { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> results = + materializeDynamicContiguousGroupMask(op, *active, resultVMIType, + resultTypes, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + std::string reason; FailureOr> materializations = computeGroupMaskMaterialization(op, &reason); diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto new file mode 100644 index 0000000000..f68b4d5509 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %active_cols: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// ASSIGN-SAME: %[[ACTIVE:arg[0-9]+]]: index) +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// LOWER: arith.index_cast +// LOWER: arith.maxsi +// LOWER: arith.minui +// LOWER: pto.vci +// LOWER: pto.vshrs +// LOWER: pto.vshls +// LOWER: pto.vsub +// LOWER-COUNT-8: pto.vcmps +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast From fcf1096192d16e6ddcdf085da0b5261a3ca3b4cb Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:15:28 +0800 Subject: [PATCH 07/31] Clarify VMI layout case coverage gaps --- .../vmi-layout-assignment-implementation.md | 18 ++++-- .../vmi-layout-assignment-lowering-design.md | 57 ++++++++++++++++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index f0dc821444..a4fb146317 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -876,13 +876,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-40 CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-mask CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=40 FAIL=0 -summary: .tmp/vmi-runtime-batch-40/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-dynamic-mask/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-40.log + .tmp/vmi-runtime-batch-dynamic-mask.log result: no matches ``` @@ -1066,12 +1066,22 @@ reduce: ```text lit: test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto - test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` +Current checked-in lit coverage for 3.45 dynamic S=32 `create_group_mask`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto + +runtime SIM: + blocked by the current dynamic scalar source gap for vector kernels; see + known implementation gaps below +``` + Current checked-in runtime coverage for 3.12 control-flow join before S=32 `group_reduce`: diff --git 
a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index c13944d348..da58668057 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -72,6 +72,7 @@ control flow: mask and tail: prefix mask group-periodic mask + dynamic group-periodic mask masked_load tail with explicit passthrough instead of padding masked_load grouped tail feeding group_reduce masked select/store @@ -86,6 +87,59 @@ strided memory: group_store slots=1 with non-unit output stride ``` +### 1.1 Case-Set Sufficiency + +The current case set is sufficient to define the first implementation of layout +assignment and lowering. It covers every decision axis that has changed the +design so far: + +```text +physical dense layout: + contiguous, deinterleaved=2/4, block_elems=1/8 + +sparse result layout: + group_slots(G, slots=8) for packed VCG results + group_slots(G, slots=1) for row-local S=64 results + +producer-driven layout: + load, group_load, group_slot_load, broadcast, create_mask, + create_group_mask + +consumer-driven pressure: + dense store, group_reduce, group_store, group_broadcast, truncf, + elementwise/select, masked_load/masked_store + +conflict resolution: + cheap rematerialization, explicit ensure_layout, explicit diagnostics + +control-flow propagation: + scf.if, scf.for iter_args/results, internal/private function boundaries, + public ABI rejection + +memory legality: + full_tile_readable proof, grouped masks, predicate granularity, aligned + strided group memory, stable gather diagnostic +``` + +No extra layout kind should be added unless a new case proves that the existing +layouts and plans cannot express the logical behavior. 
The remaining open +items are not missing layout semantics: + +```text +dynamic active_elems_per_group runtime source: + create_group_mask layout lowering is defined and has lit coverage; runtime + SIM still needs a supported scalar source/ABI for vector kernels. + +private vector function runtime: + assignment/lowering semantics are defined; full ptoas runtime depends on + backend support or an inlining policy for physical VPTO vector callees. + +diagnostic-only cases: + compact S=12 gather fallback, packed slots=8 width-changing cast, public VMI + ABI, unsafe masked_load tail, and unaligned/dynamic group memory remain + explicit capability boundaries. +``` + ## 2. Layout Domain Layout is a property of a layout-assigned VMI value, not a property inferred by @@ -626,5 +680,6 @@ The design is complete only when: 3. every unsupported case has a precise capability diagnostic 4. every control-flow/function boundary either specializes layout or diagnoses 5. every mask has explicit data layout and predicate granularity -6. every case has an end-to-end test and simulator validation +6. every positive case has end-to-end lit coverage +7. 
every simulator-supported positive case has simulator validation ``` From 6f04810c493c8e8ad58357a684a16b96bbc7a66c Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:18:40 +0800 Subject: [PATCH 08/31] Record VMI layout coverage audit --- .../vmi-layout-assignment-implementation.md | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index a4fb146317..fbfef70804 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -937,6 +937,39 @@ Aggregate catalog headings are covered through their endpoint subcases: 3.25.2 public/external boundary diagnostics ``` +Current coverage audit result: + +```text +SIM-backed positive endpoints: + 3.1, 3.2, 3.3, 3.4, 3.5.1, 3.5.2, 3.5.3, + 3.6.1, 3.6.2, 3.6.3, 3.7.1, 3.7.2, 3.7.3, + 3.8, 3.10, 3.11.1, 3.12, 3.15.1, 3.15.2, + 3.16.1 positive, 3.16.2 positive, 3.17, 3.18, + 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, + 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, + 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, + 3.40, 3.41, 3.42, 3.44 + +lit-backed positive endpoints with runtime gap: + 3.25.1 private/internal function boundary + 3.43 internal function argument boundary materialization + 3.45 dynamic S=32 create_group_mask + +diagnostic endpoints: + 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, + 3.16.1 non-unit slots=8 source stride, + 3.16.2 dynamic/unaligned slots=1 source stride, + 3.19.2, 3.25.2, 3.27 unaligned source_group_stride, + 3.30 unsafe masked_load tail + +repository evidence: + all concrete lit/runtime paths listed below exist + all 40 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + golden.py, and compare.py + latest broad VMI runtime sweep passed: PASS=40 FAIL=0 + latest full VMI lit sweep passed: 312/312 +``` + Current checked-in coverage for 3.3 
dense f8->f32->compute->f8: ```text From 4b3d5be721297b5711a6af0ebad85f3d6c2b55ea Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:31:17 +0800 Subject: [PATCH 09/31] Add dynamic S32 group mask runtime coverage --- .../vmi-layout-assignment-implementation.md | 34 +++---- .../vmi-layout-assignment-lowering-design.md | 6 +- docs/designs/vmi-layout-lowering-cases.md | 15 ++- .../compare.py | 40 ++++++++ .../golden.py | 51 ++++++++++ .../kernel.pto | 64 ++++++++++++ .../launch.cpp | 35 +++++++ .../main.cpp | 99 +++++++++++++++++++ .../ptoas.flags | 1 + 9 files changed, 315 insertions(+), 30 deletions(-) create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index fbfef70804..b05d6e1552 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -876,13 +876,13 @@ the case catalog. 
Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-mask CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-scalar CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=40 FAIL=0 -summary: .tmp/vmi-runtime-batch-dynamic-mask/parallel-summary.tsv +PASS=41 FAIL=0 +summary: .tmp/vmi-runtime-batch-dynamic-scalar/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-dynamic-mask.log + .tmp/vmi-runtime-batch-dynamic-scalar.log result: no matches ``` @@ -948,12 +948,11 @@ SIM-backed positive endpoints: 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, - 3.40, 3.41, 3.42, 3.44 + 3.40, 3.41, 3.42, 3.44, 3.45 lit-backed positive endpoints with runtime gap: 3.25.1 private/internal function boundary 3.43 internal function argument boundary materialization - 3.45 dynamic S=32 create_group_mask diagnostic endpoints: 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, @@ -964,9 +963,9 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 40 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 41 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=40 FAIL=0 + latest broad VMI runtime sweep passed: PASS=41 FAIL=0 latest full VMI lit sweep passed: 312/312 ``` @@ -1104,15 +1103,19 @@ runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` -Current checked-in lit coverage for 3.45 dynamic S=32 `create_group_mask`: +Current checked-in coverage for 3.45 dynamic S=32 `create_group_mask`: ```text lit: test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: - blocked by the current dynamic scalar source gap for vector kernels; see - known implementation gaps below + 
test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store + +runtime scalar source: + active_cols is passed as a kernel i32 scalar argument and cast to index inside + vecscope before pto.vmi.create_group_mask. This is an explicit scalar ABI, + not a value recovered by vmi-to-vpto from producer/consumer context. ``` Current checked-in runtime coverage for 3.12 control-flow join before S=32 @@ -1392,15 +1395,6 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -dynamic grouped mask runtime source: - vmi-to-vpto supports dynamic active_elems_per_group for contiguous b32 - grouped masks and S=32 deinterleaved=4/block_elems=8 masks. Full runtime SIM - coverage still needs a supported scalar source for active_elems_per_group in - vector kernels. Direct GM pto.ldg crashed the Bisheng vector backend in this - test shape, and UB pto.load_scalar reached an invalid scalar LSU address in - the SIM. Do not replace grouped masks with prefix create_mask; that would - change the semantics. - remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and vmi-to-vpto lit coverage, but full ptoas emission still fails after diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index da58668057..99a1a34c6c 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -127,8 +127,10 @@ items are not missing layout semantics: ```text dynamic active_elems_per_group runtime source: - create_group_mask layout lowering is defined and has lit coverage; runtime - SIM still needs a supported scalar source/ABI for vector kernels. + create_group_mask layout lowering is defined and has both lit and SIM + coverage. 
The supported runtime source is a kernel scalar argument cast to + index inside vecscope; vmi-to-vpto does not recover this value from GM/UB + scalar loads or surrounding context. private vector function runtime: assignment/lowering semantics are defined; full ptoas runtime depends on diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e44f32a97e..160b25a398 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -197,7 +197,7 @@ the immediately following complete endpoints. 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design 3.44 masked_load grouped tail feeding S=32 reduce complete -3.45 dynamic S=32 create_group_mask complete/lit +3.45 dynamic S=32 create_group_mask complete ``` ### 3.1 `f16 -> f32 -> store` @@ -5305,15 +5305,14 @@ tree used by section 3.44: %mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi ``` -The current lit coverage validates the IR lowering: +Current coverage validates both IR lowering and runtime behavior: ```text test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store ``` -Runtime SIM coverage is intentionally not listed yet. A direct runtime case -needs a supported way to feed a dynamic scalar `active_cols` into a vector -kernel. Experiments with GM `pto.ldg` and UB `pto.load_scalar` either crashed -the Bisheng vector backend or produced an invalid scalar LSU address in the -SIM. That is an ABI/source-materialization gap, not a `create_group_mask` -layout-lowering gap. +The runtime case passes `active_cols` as a kernel scalar argument and casts it +to `index` inside `pto.vecscope`. This keeps scalar materialization outside +`vmi-to-vpto`; the lowering pass only consumes the current +`create_group_mask` operand. 
diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py 
b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +ACTIVE = 25 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + active_base = np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.75) + + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_copy[:, ACTIVE:] = np.float32(0.0) + golden_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + 
golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto new file mode 100644 index 0000000000..8e9ebed693 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_create_group_mask_s32_reduce_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr, %active_cols_i32: i32) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active_cols = arith.index_cast %active_cols_i32 : i32 to index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, 
i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp new file mode 100644 index 0000000000..5865140b26 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_create_group_mask_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum, int activeCols); + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream) { + vmi_dynamic_create_group_mask_s32_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum, + activeCols); +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp new file mode 100644 index 0000000000..7bd86defb1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr int kActiveCols = 25; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); +
ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel( + srcDevice, copyDevice, sumDevice, kActiveCols, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 604fd5059ee7bfddf22a325f4b23d43d566ed03a Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:35:03 +0800 Subject: [PATCH 10/31] Detail VMI layout assignment request rules --- .../vmi-layout-assignment-implementation.md | 177 ++++++++++++++++++ .../vmi-layout-assignment-lowering-design.md | 111 +++++++++++ 2 files changed, 288 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index b05d6e1552..3d0cab8215 100644 --- 
a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -421,6 +421,109 @@ compact S=12 logical S=16: diagnose if gather fallback is disabled/missing ``` +### 6.3.1 Request Builders + +Implement request generation as small per-op builders. The builders produce +candidate plans and use-site requests; they do not rewrite IR. + +```text +buildStoreRequests: + ordinary store -> dense contiguous request unless a layout-aware store plan is + selected + group_store -> group_slots(G,K) request plus stride/alignment capability + checks + +buildCastRequests: + extf f16->f32 -> source contiguous, result deinterleaved=2 + extf f8->f32 -> source contiguous, result deinterleaved=4 + truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous + truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous + group_slots slots=1 f32->f16 -> slot-preserving plan + group_slots slots=8 width-changing cast -> diagnostic unless a packed plan + exists + +buildGroupReduceRequests: + derive S = logical_lanes / num_groups + S=8 -> contiguous source, group_slots(G,8) result + S=16 -> deinterleaved=2/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=64 -> contiguous source, group_slots(G,1) result + other S -> diagnostic unless an explicit fallback plan is enabled + +buildGroupMemoryRequests: + group_load S=16/S=32 with aligned constant stride -> block_elems=8 plan + group_load row-local full chunks -> contiguous plan + group_slot_load unit stride -> group_slots(G,8) + group_slot_load aligned row-local stride -> group_slots(G,1) + unsupported dynamic/unaligned grouped memory -> diagnostic + +buildMaskRequests: + mask layout follows each consuming data layout + predicate granularity follows each consuming element type + create_mask/create_group_mask may be cloned for incompatible mask layout or + 
granularity requests + +buildControlFlowRequests: + region yields, branch operands, loop iter_args, call operands, and returns + create equality requests on the carried VMI layout variable +``` + +Request builders must record the requesting op. Diagnostics and inserted +helpers are use-site operations, so the user can see which consumer forced a +layout. + +### 6.3.2 Producer Classes + +The solver uses producer classes to decide whether a conflict can be solved by +cloning, equivalence propagation, or materialization. + +```text +cheap rematerializable producers: + load when address operands dominate the clone site, no intervening may-alias + write exists, and any full_read_elems proof is preserved + broadcast + create_mask + create_group_mask + group_broadcast + group_slot_load when the same address/no-alias/proof conditions as load hold + and the selected memory plan is legal at the clone site + +layout-transparent producers: + add/sub/mul/fma/min/max/neg/abs + select + bitcast + integer bitwise and shift ops + +fixed-layout producers: + extf/truncf physical conversion plans + group_load block-fragment plans + group_reduce result group_slots + masked_load when the physical memory-safety proof fixes a full-read plan +``` + +Conflict policy: + +```text +cheap producer: + clone for each incompatible request when cloning does not duplicate a + side-effect, cross an aliasing write, or duplicate an illegal memory read + +layout-transparent producer: + merge into the consumer-requested equivalence class; insert materialization + only at incompatible uses + +fixed-layout producer: + use registered materialization only; otherwise diagnose +``` + +This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to +`deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 +`group_reduce`. 
It also keeps case 3.19.2 diagnostic: a strided `group_load` +that selected `block_elems=8` is fixed unless a block8-to-parity +materialization or rematerialized memory plan is registered. + ### 6.4 Solving And Rewriting Algorithm: @@ -451,6 +554,80 @@ Every ensure_* helper has a registered materialization plan. Every function/call signature carrying VMI is specialized or diagnosed. ``` +### 6.5 Rewrite Artifacts + +Assignment rewrites the IR so that later lowering has no hidden choices. + +```text +type rewrite: + every VMI data/mask result and block argument receives a layout attr + +selected_plan rewrite: + context-sensitive ops receive vmi.selected_plan + examples: group_reduce_addf, group_load, group_slot_load, group_broadcast, + group_slot cast, full-read masked_load plans + +clone rewrite: + cheap producers are cloned before their divergent use sites + each clone receives its own layout and selected_plan + +ensure rewrite: + non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use + site, with source and target layouts visible in the types + +granularity rewrite: + one semantic mask used by f32 and f16 consumers gets + ensure_mask_granularity or cloned mask producers + +control-flow rewrite: + scf.if/scf.for yields and block arguments are rewritten to one agreed layout; + materialization is inserted before yield when branches differ + +function rewrite: + private VMI functions are specialized or get callee-entry ensure_layout + public/external VMI functions are diagnosed +``` + +Canonical assigned IR shape for a conflicting load: + +```text +%x = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +pto.vmi.store %x_dense, ... +``` + +Canonical assigned IR shape for a cloned cheap producer: + +```text +%x_s16 = pto.vmi.load ... 
{vmi.selected_plan = "load_dintlv2"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Canonical assigned IR shape for `group_broadcast` multi-use: + +```text +%b0 = pto.vmi.group_broadcast %slots + {vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%b1 = pto.vmi.group_broadcast %slots + {vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +If the assigned IR does not have one of these explicit shapes, `vmi-to-vpto` +must reject it instead of attempting to recover the missing decision. + ## 7. OneToN Type Conversion `vmi-to-vpto` should use OneToN conversion for VMI values. diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 99a1a34c6c..5e43f6d9ec 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -422,6 +422,117 @@ one mask used by f32 and f16 consumers: vmi-to-vpto consumes the assigned per-use mask materialization ``` +### 5.5 Case-Driven Request Matrix + +The first implementation should build requests from the following finite table. +This table is deliberately case-derived; adding a new request kind requires a +new catalog case or a proof that it is equivalent to one listed here. 
+ +```text +dense store: + requests dense contiguous source + if source is deinterleaved, assignment must insert ensure_layout or select a + store plan such as vstsx2 that consumes the assigned layout explicitly + +truncf f32 -> f16: + requests source deinterleaved=2, block_elems=1 + requests result contiguous f16 + +truncf f32 -> f8: + requests source deinterleaved=4, block_elems=1 + requests result contiguous f8 + +group_reduce S=8: + requests source contiguous + requests result group_slots(num_groups, slots=8) + +group_reduce S=16: + requests source deinterleaved=2, block_elems=1 or block_elems=8 + requests result group_slots(num_groups, slots=8) + +group_reduce S=32: + requests source deinterleaved=4, block_elems=1 or block_elems=8 + requests result group_slots(num_groups, slots=8) + +group_reduce S=64: + requests source contiguous + requests result group_slots(num_groups, slots=1) + +group_broadcast: + requests source group_slots(num_groups, slots=K) + produces one dense result layout per consumer request + is cloned per incompatible dense consumer + +group_store: + requests source group_slots(num_groups, slots=K) + selected plan also records output stride legality + +group_slot_load: + requests result group_slots(num_groups, slots=8) for packed unit-stride slots + requests result group_slots(num_groups, slots=1) for row-local aligned slots + +group_load: + requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block + fragment plans, or contiguous for row-local full-chunk plans + +masked_load: + requests result layout from its consumers + requests mask layout matching the result + requires explicit passthrough; padding is not synthesized + +create_mask/create_group_mask: + produces whichever mask layout each consumer requests + may be cloned per incompatible mask layout or granularity +``` + +Important negative requests: + +```text +ordinary dense add/mul/store/truncf cannot request group_slots +packed group_slots(slots=8) cannot request 
width-changing cast unless a packed +slot-preserving cast plan is registered +slots=1 group_store cannot request unit-stride row-major output until a pack or +unaligned-store plan exists +``` + +### 5.6 Conflict Resolution Matrix + +When one value receives incompatible requests, assignment resolves it using the +first legal row below. `vmi-to-vpto` never repeats this decision. + +```text +cheap producer with multiple requested layouts: + clone the producer and assign each clone independently + examples: load, broadcast, create_mask, create_group_mask, group_broadcast + memory-read producers require the same explicit no-alias and safe-read proof + at each clone site + +non-cheap value with registered materialization: + keep one chosen layout on the value and insert ensure_layout at the use site + examples: deinterleaved=4 -> contiguous before dense store + +layout-transparent chain: + assign the whole equivalence class to the non-contiguous consumer request when + that avoids materialization + examples: broadcast -> addf -> S=32 group_reduce + +control-flow join: + all incoming values must be materialized to one layout before yield/branch + examples: scf.if yielding group_slots, scf.for loop-carried group_slots + +private function boundary: + specialize or materialize at call/callee-entry before vmi-to-vpto + +no clone/materialization/specialization plan: + emit a diagnostic naming the requesting op and both layouts +``` + +The cost model may choose between legal rows only when the observable contract +is identical. For example, S=16 `block_elems=1` and `block_elems=8` are both +valid reduce inputs, but `block_elems=8` is selected only when a producer plan +such as strided `group_load` naturally creates 32B row fragments or when cost +proves it cheaper without breaking another consumer such as `truncf`. + ## 6. Layout Assignment Algorithm `vmi-layout-assignment` is module-level. 
It must see function/call/control-flow From e353ae059169095e0d07f0397d15dd88cff9b532 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:40:16 +0800 Subject: [PATCH 11/31] Complete VMI layout request builder coverage --- .../vmi-layout-assignment-implementation.md | 15 +++++++++++++ .../vmi-layout-assignment-lowering-design.md | 22 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 3d0cab8215..36ef7a453c 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -459,15 +459,30 @@ buildGroupMemoryRequests: group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic +buildElementwiseRequests: + dense add/mul/fma/min/max/select -> all dense operands/results share one + dense layout + group-slot add/mul/select -> all operands/results share one group_slots(G,K) + dense/group_slots mixing -> diagnostic unless an explicit group_broadcast or + group_store boundary exists + buildMaskRequests: mask layout follows each consuming data layout predicate granularity follows each consuming element type create_mask/create_group_mask may be cloned for incompatible mask layout or granularity requests + masked_store requests source layout, mask layout, and store predicate + granularity explicitly buildControlFlowRequests: region yields, branch operands, loop iter_args, call operands, and returns create equality requests on the carried VMI layout variable + +buildFunctionBoundaryRequests: + private/internal function argument/result layouts are specialized or + materialized with callee-entry/return-site helpers + public/external VMI arguments/results diagnose unless enablePublicVMIABI has + a real ABI plan ``` Request builders must record the requesting op. 
Diagnostics and inserted diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 5e43f6d9ec..c16fedcde3 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -467,6 +467,17 @@ group_store: requests source group_slots(num_groups, slots=K) selected plan also records output stride legality +dense elementwise add/mul/fma/min/max/select: + requests all dense data operands and results use one dense layout + mask operands request the same data layout and the consumer element + granularity + +group-slot elementwise add/mul/select: + requests all group-slot operands and results use the same + group_slots(num_groups, slots=K) + rejects mixing dense and group_slots without explicit group_broadcast or + group_store + group_slot_load: requests result group_slots(num_groups, slots=8) for packed unit-stride slots requests result group_slots(num_groups, slots=1) for row-local aligned slots @@ -480,9 +491,20 @@ masked_load: requests mask layout matching the result requires explicit passthrough; padding is not synthesized +masked_store: + requests dense source layout selected by the store plan + requests mask layout matching the source layout and store element granularity + does not choose memory safety for an earlier load + create_mask/create_group_mask: produces whichever mask layout each consumer requests may be cloned per incompatible mask layout or granularity + +scf.if/scf.for/call/return: + requests equality across carried VMI values, yielded values, call operands, + callee arguments, and function results + private/internal functions may specialize or materialize at boundaries + public/external VMI boundaries are diagnostics until an ABI is defined ``` Important negative requests: From cf9a04df66eda4fadbc006eb95d45fc46de66182 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 
Jun 2026 13:05:04 +0800 Subject: [PATCH 12/31] Inline private VMI physical helpers before VPTO emission --- .../vmi-layout-assignment-implementation.md | 77 ++++++------ .../vmi-layout-assignment-lowering-design.md | 12 +- docs/designs/vmi-layout-lowering-cases.md | 37 +++++- ...o => vmi_ptoas_call_boundary_vecscope.pto} | 15 ++- .../lit/vmi/vmi_ptoas_private_call_inline.pto | 42 +++++++ .../compare.py | 40 ++++++ .../golden.py | 46 +++++++ .../kernel.pto | 70 +++++++++++ .../launch.cpp | 36 ++++++ .../main.cpp | 99 +++++++++++++++ .../ptoas.flags | 1 + .../vmi/private-call-inline-store/compare.py | 40 ++++++ .../vmi/private-call-inline-store/golden.py | 46 +++++++ .../vmi/private-call-inline-store/kernel.pto | 67 ++++++++++ .../vmi/private-call-inline-store/launch.cpp | 33 +++++ .../vmi/private-call-inline-store/main.cpp | 97 +++++++++++++++ .../vmi/private-call-inline-store/ptoas.flags | 1 + tools/ptoas/ptoas.cpp | 117 ++++++++++++++++++ 18 files changed, 829 insertions(+), 47 deletions(-) rename test/lit/vmi/{vmi_ptoas_call_boundary_vecscope_invalid.pto => vmi_ptoas_call_boundary_vecscope.pto} (78%) create mode 100644 test/lit/vmi/vmi_ptoas_private_call_inline.pto create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/private-call-inline-store/compare.py create mode 100644 test/vpto/cases/vmi/private-call-inline-store/golden.py create mode 100644 test/vpto/cases/vmi/private-call-inline-store/kernel.pto create mode 100644 
test/vpto/cases/vmi/private-call-inline-store/launch.cpp create mode 100644 test/vpto/cases/vmi/private-call-inline-store/main.cpp create mode 100644 test/vpto/cases/vmi/private-call-inline-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 36ef7a453c..dc54a4af09 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -1068,13 +1068,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-scalar CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-private-calls CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=41 FAIL=0 -summary: .tmp/vmi-runtime-batch-dynamic-scalar/parallel-summary.tsv +PASS=43 FAIL=0 +summary: .tmp/vmi-runtime-batch-private-calls/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-dynamic-scalar.log + .tmp/vmi-runtime-batch-private-calls.log result: no matches ``` @@ -1125,7 +1125,7 @@ Aggregate catalog headings are covered through their endpoint subcases: 3.16.2 row-local slots=1 positive plus dynamic/unaligned diagnostics 3.25 function boundary layout specialization: - 3.25.1 private/internal boundary lit coverage, runtime backend gap + 3.25.1 private/internal boundary lit and runtime coverage 3.25.2 public/external boundary diagnostics ``` @@ -1137,14 +1137,10 @@ SIM-backed positive endpoints: 3.6.1, 3.6.2, 3.6.3, 3.7.1, 3.7.2, 3.7.3, 3.8, 3.10, 3.11.1, 3.12, 3.15.1, 3.15.2, 3.16.1 positive, 3.16.2 positive, 3.17, 3.18, - 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, + 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.25.1, 3.26, 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, - 3.40, 3.41, 3.42, 3.44, 3.45 - -lit-backed positive endpoints with runtime gap: - 3.25.1 private/internal function 
boundary - 3.43 internal function argument boundary materialization + 3.40, 3.41, 3.42, 3.43, 3.44, 3.45 diagnostic endpoints: 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, @@ -1155,10 +1151,10 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 41 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=41 FAIL=0 - latest full VMI lit sweep passed: 312/312 + latest broad VMI runtime sweep passed: PASS=43 FAIL=0 + latest full VMI lit sweep passed: 313/313 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1354,16 +1350,37 @@ runtime SIM: test/vpto/cases/vmi/group-slots-scf-for-store ``` -Current checked-in lit coverage for 3.43 internal function argument boundary +Current checked-in coverage for 3.25.1 private function result boundary: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +implementation note: + after vmi-to-vpto physicalizes the private helper, ptoas inlines private + single-block helpers whose signatures contain !pto.vreg or !pto.mask. This + happens before VPTO vecscope/backend emission, so physical vector values do + not escape through a function return. 
+``` + +Current checked-in coverage for 3.43 internal function argument boundary materialization: ```text lit: test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto runtime SIM: - blocked by the current private vector callee backend path; see known - implementation gaps below + test/vpto/cases/vmi/private-call-argument-boundary-store + +implementation note: + private physical helper inlining also covers void helper calls with physical + VMI arguments, so the backend no longer sees a physical VPTO vector function + ABI for this internal boundary. ``` Current checked-in coverage for packed group-slot RHS elementwise continuations @@ -1550,7 +1567,6 @@ Diagnostic-only cases: 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 f32 elements 3.19.2 block_elems=8 value consumed by truncf without materialization plan -3.25.1 full ptoas emission for private VMI callees that return VPTO vector values 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -1578,7 +1594,6 @@ lit: test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto - test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto ``` @@ -1587,25 +1602,11 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -remaining function runtime coverage: - 3.25.1 internal function boundary specialization has layout-assignment and - vmi-to-vpto lit coverage, but full ptoas emission still fails after - physicalization because today's inferred pto.vecscope is resultless and VPTO - vector-scope values cannot escape through a function return. 
Runtime coverage - requires either a resultful vecscope/VPTO vector ABI or an explicit inlining - policy before vecscope inference. - - 3.43 internal function argument boundary materialization has - layout-assignment and vmi-to-vpto lit coverage. Full ptoas emission for a - private void vector callee currently reaches the Bisheng device backend and - fails on the physicalized callee with: - - fatal error: error in backend: Do not know how to split the result of this operator! - - Runtime coverage requires either inlining private vector callees before the - device backend path or adding backend support for the physical VPTO vector - function ABI. This is a runtime/backend gap, not a license for `vmi-to-vpto` - to infer layouts from caller/callee context. +private physical function ABI: + 3.25.1 and 3.43 runtime coverage is closed for private/internal single-block + helpers by inlining private physical VMI helpers after vmi-to-vpto and before + VPTO vecscope/backend emission. Public/external VMI boundaries are still + rejected until a stable VMI ABI is defined. memory-proof runtime coverage: 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index c16fedcde3..0b5a658fbe 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -133,8 +133,10 @@ dynamic active_elems_per_group runtime source: scalar loads or surrounding context. private vector function runtime: - assignment/lowering semantics are defined; full ptoas runtime depends on - backend support or an inlining policy for physical VPTO vector callees. + private/internal single-block helpers are runtime-covered by ptoas inlining + private physical VMI helpers after vmi-to-vpto and before VPTO vecscope/backend + emission. 
This is a post-physicalization backend hygiene step; vmi-to-vpto + still lowers only from assigned layouts and helper ops. diagnostic-only cases: compact S=12 gather fallback, packed slots=8 width-changing cast, public VMI @@ -683,8 +685,10 @@ the initial value or previous iteration during lowering. Internal/private VMI function boundaries must make layout choices explicit in the assigned IR. The baseline implementation keeps function arguments in a contiguous VMI ABI and inserts callee-entry `ensure_layout` helpers when the -callee body needs another layout. A later private-function optimization may -specialize signatures directly: +callee body needs another layout. Private helpers are then physicalized by +`vmi-to-vpto` and inlined before VPTO vecscope/backend emission so physical +`!pto.vreg`/`!pto.mask` values do not become a backend function ABI. A later +private-function optimization may specialize signatures directly: ```text func @producer() -> !vmi.vreg<256xf32, deinterleaved=4> diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 160b25a398..d0ec9f70a5 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -177,7 +177,7 @@ the immediately following complete endpoints. 3.22 scf.for loop-carried layout complete 3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete -3.25 function boundary layout specialization complete/design +3.25 function boundary layout specialization complete 3.26 S=16 grouped tail through broadcast/reduce/store complete 3.27 S=32 group_load with stride greater than group size complete 3.28 group_slot_load slots=1 aligned non-unit stride complete @@ -195,7 +195,7 @@ the immediately following complete endpoints. 
3.40 scalar broadcast feeding dense and grouped users complete/materialization 3.41 non-rematerializable value with incompatible users complete/materialization 3.42 group_slots scf.for loop-carried accumulator complete -3.43 internal function argument boundary materialization complete/design +3.43 internal function argument boundary materialization complete 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete ``` @@ -3269,6 +3269,22 @@ for r = 0..7: out[off + r] = reduce(row_r[0..31]) ``` +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +ptoas pipeline: + vmi-layout-assignment makes the private result layout explicit + vmi-to-vpto physicalizes the private helper result into !pto.vreg values + ptoas then inlines private physical VMI helpers before VPTO vecscope/backend + emission, so physical vector values do not escape through a function return +``` + #### 3.25.2 Public Or External VMI Boundary VMI input: @@ -5125,6 +5141,23 @@ optimization must still be expressed in the assigned VMI function type before `vmi-to-vpto` runs. 
``` +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-argument-boundary-store + +ptoas pipeline: + vmi-layout-assignment inserts explicit callee-entry materialization + vmi-to-vpto physicalizes the call operands and callee body + ptoas then inlines the private physical helper before VPTO vecscope/backend + emission, so the backend never needs a physical VPTO vector function ABI +``` + ### 3.44 `masked_load` Grouped Tail Feeding S=32 Reduce This case connects the explicit `masked_load` tail model from section 3.30 with diff --git a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto similarity index 78% rename from test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto rename to test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto index 950215e5e4..771ae5904c 100644 --- a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto +++ b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s module attributes {pto.target_arch = "a5"} { module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { @@ -31,5 +31,14 @@ module attributes {pto.target_arch = "a5"} { } } -// CHECK: cannot infer resultless pto.vecscope because VPTO vector-scope data cannot have external users -// CHECK-SAME: escaping value type is '!pto.vreg<64xf32>' +// CHECK-NOT: func.func private @callee +// CHECK-LABEL: func.func @caller +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vadd +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: func.call @callee +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_private_call_inline.pto b/test/lit/vmi/vmi_ptoas_private_call_inline.pto new file mode 100644 index 0000000000..c5e1604bec --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_private_call_inline.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @producer(%scalar: f32) + -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_ptoas_private_call_inline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = call @producer(%scalar) + : (f32) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-NOT: func.func private @producer +// CHECK-LABEL: func.func @vmi_ptoas_private_call_inline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: call @producer +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py b/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto new file mode 100644 index 0000000000..eb8f7f5e6a --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + + func.func @vmi_private_call_argument_boundary_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + func.call @consume(%x, %mask, %ub_sum, %c0) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + 
pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp b/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp new file mode 100644 index 0000000000..ba6be566de --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_private_call_argument_boundary_store_kernel(__gm__ float *src, + __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_private_call_argument_boundary_store_kernel(float *src, + float *copy, + float *sum, + void *stream) { + vmi_private_call_argument_boundary_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp b/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp new file mode 100644 index 0000000000..5ce943feae --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_argument_boundary_store_kernel(float *src, + float *copy, + float *sum, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, +
ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_argument_boundary_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/private-call-inline-store/compare.py b/test/vpto/cases/vmi/private-call-inline-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-inline-store/golden.py b/test/vpto/cases/vmi/private-call-inline-store/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-inline-store/kernel.pto b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto new file mode 100644 index 0000000000..5f7beec943 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func private @producer(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> + } + + func.func @vmi_private_call_inline_store_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = func.call @producer(%ub_src, %c0) + : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store 
%sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/launch.cpp b/test/vpto/cases/vmi/private-call-inline-store/launch.cpp new file mode 100644 index 0000000000..b5015d7cda --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_private_call_inline_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_private_call_inline_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/main.cpp b/test/vpto/cases/vmi/private-call-inline-store/main.cpp new file mode 100644 index 0000000000..325ebc902e --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include <cstdio> +#include <cstdlib> + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&copyHost), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&copyDevice, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE));
+ ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_inline_store_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 60e88b4276..80f8c22469 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -18,6 +18,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser/Parser.h" @@ -1617,6 +1619,117 @@ static LogicalResult verifyNoPublicVMISignature(ModuleOp module) { return failure(result.wasInterrupted()); } +static bool containsVMIPhysicalType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), 
containsVMIPhysicalType) || + llvm::any_of(functionType.getResults(), containsVMIPhysicalType); + } + return false; +} + +static bool isPrivatePhysicalVMIHelper(func::FuncOp func) { + return !func.isPublic() && !func.isExternal() && + func.getBody().hasOneBlock() && + containsVMIPhysicalType(func.getFunctionType()); +} + +static LogicalResult inlinePrivatePhysicalVMIHelperCall(func::CallOp call, + func::FuncOp callee) { + if (callee.isExternal()) + return call.emitOpError("callee must have a body before inlining"); + if (!callee.getBody().hasOneBlock()) + return call.emitOpError("callee must be single-block before inlining"); + + Block &entry = callee.getBody().front(); + if (entry.getNumArguments() != call.getNumOperands()) + return call.emitOpError("callee argument count mismatch during inlining"); + + auto returnOp = dyn_cast(entry.getTerminator()); + if (!returnOp) + return call.emitOpError("callee must terminate with func.return"); + if (returnOp.getNumOperands() != call.getNumResults()) + return call.emitOpError("callee return/result arity mismatch during inlining"); + + OpBuilder builder(call); + IRMapping mapping; + for (auto [arg, operand] : llvm::zip(entry.getArguments(), call.getOperands())) + mapping.map(arg, operand); + + for (Operation &op : entry.without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + for (auto [oldResult, newResult] : + llvm::zip(op.getResults(), newOp->getResults())) + mapping.map(oldResult, newResult); + } + + for (auto [callResult, returnOperand] : + llvm::zip(call.getResults(), returnOp.getOperands())) + callResult.replaceAllUsesWith(mapping.lookup(returnOperand)); + + call.erase(); + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpersInModule(ModuleOp module) { + bool madeProgress = true; + while (madeProgress) { + madeProgress = false; + + SmallVector calls; + module.walk([&](func::CallOp call) { calls.push_back(call); }); + + for (func::CallOp call : calls) { + if (!call || 
!call->getBlock()) + continue; + + func::FuncOp caller = call->getParentOfType(); + auto calleeAttr = call.getCalleeAttr(); + if (!caller || !calleeAttr) + continue; + + func::FuncOp callee = + SymbolTable::lookupNearestSymbolFrom( + call, calleeAttr.getAttr()); + if (!callee || !isPrivatePhysicalVMIHelper(callee)) + continue; + if (callee == caller) + return call.emitOpError("recursive private VMI helper call cannot be " + "inlined before VPTO emission"); + + if (failed(inlinePrivatePhysicalVMIHelperCall(call, callee))) + return failure(); + madeProgress = true; + } + } + + SymbolTable symbolTable(module); + SmallVector deadFuncs; + for (func::FuncOp func : module.getOps()) { + if (!isPrivatePhysicalVMIHelper(func)) + continue; + auto uses = symbolTable.getSymbolUses(func, module); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpers(ModuleOp module) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(module))) + return failure(); + WalkResult result = module.walk([&](ModuleOp nestedModule) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(nestedModule))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { if (failed(verifyNoPublicVMISignature(module.get()))) return failure(); @@ -1634,6 +1747,10 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { llvm::errs() << "Error: VMI-to-VPTO pipeline failed.\n"; return failure(); } + if (failed(inlinePrivatePhysicalVMIHelpers(module.get()))) { + llvm::errs() << "Error: failed to inline private VMI physical helpers.\n"; + return failure(); + } return success(); } From e96ba6c7a4844af4f3bee2956271d580424ebc4b Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 
13:24:17 +0800 Subject: [PATCH 13/31] Validate required VMI selected plans --- .../vmi-layout-assignment-implementation.md | 58 +++++++- .../vmi-layout-assignment-lowering-design.md | 80 +++++++++++ docs/designs/vmi-layout-lowering-cases.md | 10 +- include/PTO/Transforms/Passes.td | 4 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 132 ++++++++++++++++++ ...out_gate_missing_selected_plan_invalid.pto | 23 +++ 6 files changed, 297 insertions(+), 10 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index dc54a4af09..102bcc628c 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -196,6 +196,55 @@ Ops that are uniquely determined by layout may omit this attr, but the rule should be conservative. If future maintainers could reasonably ask "why this lowering?", assignment should write a plan. 
+Required-plan table for the current implementation: + +```text +op required when +group_load result layout matches a registered group_load plan +group_slot_load explicit group_slots slots=8 or slots=1 result +group_reduce_addf source/result layouts match a registered reduce plan +group_broadcast explicit slots=8 or slots=1 source and dense result +truncf group_slots slots=1 f32->f16 slot-preserving cast +ensure_layout always carries source/result layouts instead of plan +ensure_mask_layout always carries source/result layouts instead of plan +ensure_mask_granularity always carries source/result granularities instead of plan +``` + +Layout/attr-only decisions today: + +```text +load result layout plus full_read_elems/full chunk proof +group_store source group_slots layout plus explicit output stride +masked_load explicit passthrough, mask layout, and memory proof +masked_store/select operand/result layouts plus mask granularity +dense extf/truncf source/result layouts and element widths +``` + +Implementation rule: + +```text +vmi-layout-assignment attaches the required plan before type conversion. +validate-assigned-vmi rejects a required-plan op that lacks vmi.selected_plan. +vmi-to-vpto verifies the plan against the already assigned layouts and emits +VMI-LAYOUT-CONTRACT instead of selecting a fallback from producer/user context. +If a layout/attr-only op later gains a second legal recipe, that recipe must be +promoted into the required-plan table before vmi-to-vpto can emit it. +Unsupported shapes that have no registered plan still diagnose through their +specific capability check rather than failing with a generic missing-plan error. +``` + +Examples of forbidden recovery in `vmi-to-vpto`: + +```text +group_reduce_addf cannot walk to a load/group_load producer to choose S=16 + parity versus block8. +group_store cannot inspect the group_reduce producer; it consumes only the + assigned source layout and explicit stride. 
+group_broadcast cannot inspect sibling users to decide whether to rematerialize. +masked_load cannot inspect the mask producer to prove memory safety. +func.call cannot inspect the callee body to decide physical function layout. +``` + ## 4. VMI Surface Ops Required By Cases Initial op set from the case catalog: @@ -1068,13 +1117,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-private-calls CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-selected-plan-gate CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-private-calls/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-selected-plan-gate/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-private-calls.log + .tmp/vmi-runtime-batch-selected-plan-gate.log result: no matches ``` @@ -1154,7 +1203,7 @@ repository evidence: all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 313/313 + latest full VMI lit sweep passed: 314/314 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1585,6 +1634,7 @@ entries: ```text lit: + test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 0b5a658fbe..9261a938dd 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -280,6 +280,86 @@ invariant is not illustrative: if a lowering decision is not 
uniquely implied by op + assigned operand/result layouts + explicit attrs, assignment must write a selected plan. +### 4.1 Selected Plan Contract + +`selected_plan` is not an optimization hint. It is the serialized answer to a +question that would otherwise require `vmi-to-vpto` to inspect producer, +consumer, control-flow, memory, or mask context. + +Required plans in the current implementation: + +```text +group_load: + required for registered result layouts. The plan fixes source_group_stride + handling and whether the result is contiguous chunks, S=16 block8, or S=32 + block8. Unsupported shapes diagnose through the capability check instead of + inventing a plan. + +group_slot_load: + required for explicit slots=8 or slots=1 layouts. The plan fixes packed + scalar load versus row-local lane-0 load. A single source op may be + rematerialized into two different planned ops. + +group_reduce_addf: + required for registered S=8/S=16/S=32/S=64 shapes. The plan fixes parity + versus block8, packed slots=8 versus row-local slots=1, and multi-chunk + arity. Unsupported group sizes diagnose as unsupported capability, not as + missing selected_plan. + +group_broadcast: + required for explicit slots=8 or slots=1 sources. The plan fixes source + interpretation and the vselr index recipe for the requested dense result + layout. Legacy bare group_slots are tolerated only as compatibility input and + must not be emitted by layout assignment. + +truncf: + required for group_slots slots=1 f32->f16, where the cast is a slot-preserving + group-slot cast rather than an ordinary dense VCVT path. +``` + +Layout-only or attr-only decisions in the current implementation: + +```text +load: + result layout plus explicit memory attrs decide the lowering. full_read_elems + is the memory-safety proof; vmi-to-vpto may not recover that proof from MTE or + caller context. 
+ +group_store: + source group_slots layout and explicit output stride decide packed slots=8 + versus row-local slots=1 store legality. If another legal store recipe is + introduced, assignment must attach a selected plan before vmi-to-vpto uses it. + +masked_load: + explicit passthrough, mask layout, full physical read, shaped safe-tail memref, + or an explicit diagnostic decide legality. A future stable gather fallback + must be selected by assignment before vmi-to-vpto lowers it. + +masked_store/select/elementwise: + operand/result layouts and explicit mask granularity decide the lowering. + They remain transfer ops unless a future case introduces competing recipes. + +extf/truncf: + dense width-changing paths are layout-determined today. Any future + commute-through-group-broadcast or alternative VCVT recipe must become a + selected plan first. +``` + +Forbidden plan recovery: + +```text +No pattern may synthesize one of the required plans by: + - walking from group_reduce to the load/group_load producer + - walking from store/broadcast/truncf to the group_reduce producer + - scanning sibling users of a group_slots value + - inspecting branch bodies or loop bodies from a control-flow boundary + - inspecting private callee bodies while lowering a call +``` + +If a required plan is missing, `vmi-to-vpto` emits +`VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, +assigned layouts, and the missing plan class. + ## 5. Plan Registry The compiler owns a target-aware plan registry. Layout assignment queries this diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index d0ec9f70a5..8e2d6bfceb 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -167,13 +167,13 @@ the immediately following complete endpoints. 
3.12 control-flow join before group_reduce complete 3.13 packed group-slot f32 -> f16 cast illegal diagnostic 3.14 unsupported group size illegal diagnostic -3.15 compact S=12 written as logical S=16 complete/design +3.15 compact S=12 written as logical S=16 complete/diagnostic 3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization 3.19 S=16 reduce block_elems plan selection complete/diagnostic 3.20 group_slots control-flow join complete -3.21 S=32 tail with full-tile-readable source complete/design +3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete 3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete @@ -187,9 +187,9 @@ the immediately following complete endpoints. 3.32 f32 feeding f8 store and S=32 reduce complete 3.33 one dense value feeding S=16 and S=32 reduces complete/materialization 3.34 S=64 group-slot result f32->f16 cast complete -3.35 group_slots fanout to group_store and broadcast complete/design -3.36 same scalar source materialized as slots=8/slots=1 complete/design -3.37 S=64 group_store with non-unit output stride complete/design +3.35 group_slots fanout to group_store and broadcast complete +3.36 same scalar source materialized as slots=8/slots=1 complete/materialization +3.37 S=64 group_store with non-unit output stride complete 3.38 multi-tile S=32 group_reduce complete 3.39 strided S=32 group_load through broadcast/reduce complete 3.40 scalar broadcast feeding dense and grouped users complete/materialization diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 25ec3324b9..103b6a6df9 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -652,7 +652,9 @@ def PTOValidateVMILayoutIR Checks the post-layout-assignment VMI stage: every VMI data value must 
have a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, - and VMI typed values must stay inside VMI semantic/helper or structural ops. + VMI typed values must stay inside VMI semantic/helper or structural ops, + and context-sensitive VMI ops must carry the selected_plan contract emitted + by layout assignment. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6ce3e8eecd..889a5ebe85 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -36,6 +36,8 @@ using namespace mlir::pto; namespace { +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + bool isVMIType(Type type) { return isa(type); } bool isPhysicalVPTOType(Type type) { @@ -159,6 +161,133 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagLayoutContractPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +std::optional getGroupSize(VMIVRegType type, int64_t numGroups) { + if (!type || numGroups <= 0 || type.getElementCount() % numGroups != 0) + return std::nullopt; + return type.getElementCount() / numGroups; +} + +bool hasRegisteredGroupReducePlan(VMIGroupReduceAddFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return false; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return false; + + std::optional groupSize = + getGroupSize(sourceType, op.getNumGroupsAttr().getInt()); + if (!groupSize) + return false; + + if 
(sourceLayout.isContiguous()) + return *groupSize == 8 || *groupSize == 64; + + if (!sourceLayout.isDeinterleaved()) + return false; + if (*groupSize == 16 && sourceLayout.getFactor() == 2) + return sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8; + if (*groupSize == 32 && sourceLayout.getFactor() == 4) + return sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8; + return false; +} + +bool hasRegisteredGroupLoadPlan(VMIGroupLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return false; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return false; + if (layout.isContiguous()) + return true; + if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) + return false; + + std::optional groupSize = + getGroupSize(resultType, op.getNumGroupsAttr().getInt()); + if (!groupSize) + return false; + return (*groupSize == 16 && layout.getFactor() == 2) || + (*groupSize == 32 && layout.getFactor() == 4); +} + +bool hasRegisteredGroupSlotLoadPlan(VMIGroupSlotLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return false; + VMILayoutAttr layout = resultType.getLayoutAttr(); + return layout && layout.isGroupSlots() && + layout.getNumGroups() == op.getNumGroupsAttr().getInt() && + (layout.getSlots() == 8 || layout.getSlots() == 1); +} + +bool hasRegisteredGroupBroadcastPlan(VMIGroupBroadcastOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return false; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + sourceLayout.getNumGroups() == op.getNumGroupsAttr().getInt() && + !resultLayout.isGroupSlots() && + (sourceLayout.getSlots() == 8 || sourceLayout.getSlots() == 1); +} + +bool 
hasRegisteredGroupSlotTruncFPlan(Operation *op) { + auto truncf = dyn_cast(op); + if (!truncf) + return false; + + auto sourceType = dyn_cast(truncf.getSource().getType()); + auto resultType = dyn_cast(truncf.getResult().getType()); + if (!sourceType || !resultType) + return false; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots() && sourceLayout.getSlots() == 1 && + resultLayout.getSlots() == 1 && sourceType.getElementType().isF32() && + resultType.getElementType().isF16(); +} + +bool requiresSelectedPlan(Operation *op) { + if (auto groupLoad = dyn_cast(op)) + return hasRegisteredGroupLoadPlan(groupLoad); + if (auto groupSlotLoad = dyn_cast(op)) + return hasRegisteredGroupSlotLoadPlan(groupSlotLoad); + if (auto reduce = dyn_cast(op)) + return hasRegisteredGroupReducePlan(reduce); + if (auto broadcast = dyn_cast(op)) + return hasRegisteredGroupBroadcastPlan(broadcast); + return hasRegisteredGroupSlotTruncFPlan(op); +} + +LogicalResult verifySelectedPlanContract(Operation *op, + llvm::raw_ostream *diagOS) { + if (!requiresSelectedPlan(op)) + return success(); + if (op->getAttrOfType(kVMISelectedPlanAttrName)) + return success(); + return emitLayoutContract( + op, diagOS, + Twine(op->getName().getStringRef()) + + " requires vmi.selected_plan selected by vmi-layout-assignment"); +} + LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -378,6 +507,9 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (!hasVMIOrPhysicalType(op)) return success(); + if (failed(verifySelectedPlanContract(op, diagOS))) + return failure(); + if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return success(); diff --git a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto 
b/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto new file mode 100644 index 0000000000..d06bd275ca --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_missing_selected_plan_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf requires vmi.selected_plan selected by vmi-layout-assignment From c1e74fb27f3ad8d9dfdec73941a2f85a47f9c948 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:34:57 +0800 Subject: [PATCH 14/31] Document VMI layout closure matrix --- .../vmi-layout-assignment-implementation.md | 163 ++++++++++++++++-- 1 file changed, 151 insertions(+), 12 deletions(-) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 102bcc628c..dfc8588a86 100644 --- 
a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -692,6 +692,138 @@ Canonical assigned IR shape for `group_broadcast` multi-use: If the assigned IR does not have one of these explicit shapes, `vmi-to-vpto` must reject it instead of attempting to recover the missing decision. +### 6.6 Case-To-Implementation Closure Matrix + +The current case catalog is sufficient for the first implementation. No new +layout kind is justified by the supported endpoints. The implementation work +should instead close the following finite matrix. Each row names the request +builder that owns the decision, the assignment artifact that must appear in IR, +and the `vmi-to-vpto` contract. + +```text +case family builder / owner assignment artifact +3.1, 3.2, 3.3 dense casts buildCastRequests dense layout on each cast result +3.29 mask width split buildMaskRequests per-use mask granularity helper +3.31, 3.32 dense fanout conflict resolver cloned load or ensure_layout + +vmi-to-vpto contract: + consume only the assigned dense layouts. It may emit VCVT and dense + materialization, but it must not choose deinterleaved=2/4 by inspecting a + later truncf, store, or group_reduce user. 
+``` + +```text +case family builder / owner assignment artifact +3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous plan +3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 plan +3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 plan +3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local plan +3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks +3.19.1 S=16 block_elems choice buildGroupReduceRequests selected block_elems reduce plan +3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks +3.26 grouped tail buildMaskRequests split grouped masks +3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values + +vmi-to-vpto contract: + lower each reduce from source layout, result group_slots layout, and + selected_plan. It must not walk to the load/group_load producer to decide + parity versus block8, row-local versus packed slots, or static versus dynamic + mask generation. +``` + +```text +case family builder / owner assignment artifact +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load plan +3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan +3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan +3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan +3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load plan +3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof +3.39 strided load fanout conflict resolver preserving layout or materialization + +vmi-to-vpto contract: + consume only explicit memory stride/alignment attrs, selected_plan, and + layouts. It must not infer safe read/write placement from neighboring + compute ops. 
Unsupported dynamic, unaligned, or compact-row gather shapes + stay diagnostics until a gather plan is registered. +``` + +```text +case family builder / owner assignment artifact +3.8 reduce->truncf->broadcast conflict resolver slot cast plus dense materialization +3.10 non-load S=32 producer buildElementwiseRequests transparent deinterleaved chain +3.17 broadcast deint consumer conflict resolver use-site group_broadcast layout +3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout +3.23 broadcast multi-user conflict resolver cloned group_broadcast +3.33 S=16 + S=32 users conflict resolver cloned load or materialization +3.34 S=64 slots=1 cast buildCastRequests group_slot_cast selected plan +3.35 slots fanout buildElementwiseRequests same group_slots layout on users +3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast +3.40 scalar dense + grouped conflict resolver cloned broadcast +3.41 incompatible fixed value conflict resolver diagnostic or ensure_layout + +vmi-to-vpto contract: + each op instance is already single-plan. The lowering pass never scans + sibling users to decide whether to clone, pack, broadcast, or materialize. 
+``` + +```text +case family builder / owner assignment artifact +3.21 S=32 safe full-read tail buildMaskRequests full_read_elems memory proof +3.24 mask/select/store buildMaskRequests explicit mask layout/granularity +3.12 scf.if before reduce buildControlFlowRequests common yielded layout +3.20 group_slots scf.if buildControlFlowRequests common group_slots layout +3.22 scf.for carried value buildControlFlowRequests fixed-point iter_arg layout +3.25 function boundary buildFunctionBoundary specialized/internal boundary +3.42 loop accumulator buildControlFlowRequests loop-carried group_slots layout +3.43 call argument materialize buildFunctionBoundary callee-entry/return helper + +vmi-to-vpto contract: + block argument, region result, call operand, and function result layouts are + visible in types or helper ops. It must not inspect branch bodies, loop + bodies, callers, or callees to discover a layout. +``` + +```text +diagnostic family builder / owner required failure +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store plan +3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast +3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather +3.13 slots=8 width cast buildCastRequests no packed slot cast plan +3.14 unsupported group size buildGroupReduceRequests no registered reduce plan +3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load plan +3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan +3.19.2 invalid block_elems use conflict resolver no preserving materialization +3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback plan +3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback + +vmi-to-vpto contract: + these cases must fail before or at the 
layout contract boundary with the + requesting op named. They must not be accepted by falling back to a generic + dense load, dense store, or producer/user inspection. +``` + +Additional cases are needed only when the scope changes: + +```text +stable gather fallback enabled: + add compact S=12 positive lowering and masked_load unsafe-tail positive + lowering before accepting either path. + +pack-to-slots=8 or unaligned row-local stores enabled: + add positive S=64 unit-stride group_store and reduce->pack->dense store cases. + +public VMI ABI enabled: + add public call/return ABI cases before removing the public-boundary + diagnostic. + +packed group-slot width cast enabled: + add slots=8 f32->f16 cast and downstream group_store/broadcast cases. +``` + ## 7. OneToN Type Conversion `vmi-to-vpto` should use OneToN conversion for VMI values. @@ -1648,8 +1780,7 @@ lit: test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto ``` -Known implementation gaps before all catalog cases can become runtime SIM -coverage: +Capability boundaries and runtime evidence notes: ```text private physical function ABI: @@ -1745,14 +1876,22 @@ public ABI diagnostic ## 13. Completion Checklist -The implementation is not complete until: - -```text -1. every case has a layout-assignment test -2. every positive case has a vmi-to-vpto test -3. every simulator-supported case has a sim validation -4. every unsupported case has a diagnostic test -5. vmi-to-vpto contains no producer/user context inference -6. missing selected_plan on context-sensitive ops is a hard failure -7. release docs are updated only after the design stabilizes +Current evidence for the case-catalog objective: + +```text +1. every catalog endpoint is mapped in section 6.6 to an assignment owner, + assignment artifact, and vmi-to-vpto contract +2. every SIM-backed positive endpoint is listed in section 11.3 and has a + checked-in runtime case directory +3. 
every runtime case directory contains kernel.pto, launch.cpp, main.cpp, + golden.py, and compare.py +4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 +5. the latest full VMI lit sweep passed: 314/314 +6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test +7. vmi-to-vpto context-sensitive decisions are represented by assigned layouts, + selected_plan, helper ops, rematerialization, or diagnostics +8. missing selected_plan on registered context-sensitive shapes is a hard + validation failure +9. release docs remain untouched; this is still a design/implementation plan + under docs/designs ``` From 067f699f0fef3edf591785d3adbad12eef69c7be Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:24:40 +0800 Subject: [PATCH 15/31] Add VMI dense reduce multi-consumer case --- ...ment_widen_dense_reduce_multi_consumer.pto | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto diff --git a/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto new file mode 100644 index 0000000000..95f5becf6b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( + %src: !pto.ptr, + %k1: !pto.vmi.vreg<128xf32>, + %init0: !pto.vmi.vreg<1xf32>, + %init1: !pto.vmi.vreg<1xf32>, + %out0: !pto.ptr, + %out1: !pto.ptr, + %off: index) { + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %a = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %t1 = pto.vmi.mulf %w, %k1 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %r0 = pto.vmi.reduce_addf %t1, %init0, %mask {reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r0, %out0[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + %r = pto.vmi.reduce_addf %w, %init1, %mask {reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r, %out1[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// ASSIGN-SAME: %arg1: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg2: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg3: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: %[[A:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[W:.*]] = pto.vmi.extf %[[A]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[T1:.*]] = pto.vmi.mulf %[[W]], %arg1 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = 
pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[T1_DENSE:.*]] = pto.vmi.ensure_layout %[[T1]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R0:.*]] = pto.vmi.reduce_addf %[[T1_DENSE]], %arg2, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R0]] +// ASSIGN: %[[W_DENSE:.*]] = pto.vmi.ensure_layout %[[W]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R:.*]] = pto.vmi.reduce_addf %[[W_DENSE]], %arg3, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vmul +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. 
+// LOWER-NOT: unrealized_conversion_cast From e550b80d62bc99df2feafa48dd21aa0e630d9846 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:34:33 +0800 Subject: [PATCH 16/31] Remove VMI selected plan attrs --- .../vmi-layout-assignment-implementation.md | 240 ++++++++---------- .../vmi-layout-assignment-lowering-design.md | 126 +++++---- docs/designs/vmi-layout-lowering-cases.md | 10 +- include/PTO/Transforms/Passes.td | 8 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 132 ---------- lib/PTO/Transforms/VMILayoutAssignment.cpp | 140 ---------- lib/PTO/Transforms/VMIToVPTO.cpp | 132 +--------- ...assignment_broadcast_dense_group_users.pto | 1 - ...yout_assignment_call_argument_boundary.pto | 1 - ...ayout_assignment_create_group_mask_s16.pto | 1 - ...signment_create_group_mask_s32_dynamic.pto | 1 - ...ment_dense_group_reduce_multi_consumer.pto | 1 - ..._layout_assignment_f32_f8_store_reduce.pto | 1 - ...ignment_group_broadcast_multi_consumer.pto | 4 - ...yout_assignment_group_broadcast_slots8.pto | 1 - .../vmi/vmi_layout_assignment_group_load.pto | 1 - ...assignment_group_load_s16_stride_store.pto | 2 - ...group_load_s32_stride_broadcast_reduce.pto | 4 - ...assignment_group_load_s32_stride_store.pto | 2 - ...yout_assignment_group_reduce_s16_store.pto | 1 - ...roup_reduce_s16_truncf_broadcast_store.pto | 2 - ...ment_group_reduce_s32_broadcast_reduce.pto | 3 - ...nment_group_reduce_s32_multitile_store.pto | 1 - ...yout_assignment_group_reduce_s32_store.pto | 1 - ...gnment_group_reduce_s32_tail_full_tile.pto | 2 - ...vmi_layout_assignment_group_reduce_s64.pto | 1 - ...ment_group_reduce_s64_broadcast_reduce.pto | 3 - ...assignment_group_reduce_s64_tail_store.pto | 1 - ...out_assignment_group_reduce_s64_truncf.pto | 2 - ..._layout_assignment_group_reduce_slots8.pto | 1 - ...t_assignment_group_reduce_slots8_store.pto | 1 - .../vmi_layout_assignment_group_slot_load.pto | 3 - 
...assignment_group_slot_load_dual_layout.pto | 4 - ...i_layout_assignment_group_slots_fanout.pto | 3 - ..._layout_assignment_group_slots_scf_for.pto | 3 - ...signment_masked_load_dense_group_users.pto | 1 - ..._assignment_masked_load_group_tail_s32.pto | 1 - ..._layout_assignment_non_load_s32_reduce.pto | 1 - ...yout_assignment_widen_f16_store_reduce.pto | 1 - ...d.pto => vmi_layout_gate_local_recipe.pto} | 7 +- .../vmi_to_vpto_group_broadcast_slots8.pto | 2 +- ...o_group_broadcast_slots8_local_recipe.pto} | 28 +- ...> vmi_to_vpto_group_load_local_recipe.pto} | 24 +- test/lit/vmi/vmi_to_vpto_group_ops.pto | 2 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 2 +- ...to_vpto_group_reduce_s64_local_recipe.pto} | 25 +- .../vmi/vmi_to_vpto_group_reduce_slots8.pto | 2 +- ...vpto_group_reduce_slots8_local_recipe.pto} | 20 +- test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 6 +- ..._to_vpto_group_slot_load_local_recipe.pto} | 18 +- ...group_slot_load_nonunit_slots8_invalid.pto | 2 +- .../vmi_to_vpto_group_slot_truncf_slots1.pto | 1 - ...group_slot_truncf_slots1_local_recipe.pto} | 24 +- 53 files changed, 284 insertions(+), 723 deletions(-) rename test/lit/vmi/{vmi_layout_gate_missing_selected_plan_invalid.pto => vmi_layout_gate_local_recipe.pto} (80%) rename test/lit/vmi/{vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto => vmi_to_vpto_group_broadcast_slots8_local_recipe.pto} (51%) rename test/lit/vmi/{vmi_to_vpto_group_load_missing_plan_invalid.pto => vmi_to_vpto_group_load_local_recipe.pto} (58%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto => vmi_to_vpto_group_reduce_s64_local_recipe.pto} (62%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto => vmi_to_vpto_group_reduce_slots8_local_recipe.pto} (73%) rename test/lit/vmi/{vmi_to_vpto_group_slot_load_missing_plan_invalid.pto => vmi_to_vpto_group_slot_load_local_recipe.pto} (69%) rename 
test/lit/vmi/{vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto => vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto} (58%) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index dfc8588a86..f4c8f8487f 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -35,7 +35,8 @@ vmi-layout-assignment: pto-validate-vmi-layout: verify every VMI data/mask value has layout - verify every context-sensitive op has selected_plan + verify every VMI value has an assigned layout and every non-local lowering + choice has been serialized explicitly verify helper ops have registered materialization plans vmi-to-vpto: @@ -160,51 +161,33 @@ Layout-assigned: Surface VMI types are legal before assignment. Layout-assigned VMI types are required after assignment. -### 3.3 Selected Plan Attribute +### 3.3 Explicit Recipe Carriers -Every context-sensitive op gets a selected plan attr after assignment. The -initial implementation may use a stable string attr: +Lowering decisions are carried by the current op and its types, not by a +separate recipe string. The allowed carriers are: ```text -vmi.selected_plan = "s16_reduce_parity" +op attrs and operands +operand/result VMI layouts +mask granularity and mask layouts +helper ops such as ensure_layout / ensure_mask_layout +cloned or rematerialized producers +diagnostics for unsupported shapes ``` -Once the plan registry syntax is stable, this can become a dedicated plan attr: +If assignment made a non-local choice by inspecting producers, users, sibling +users, control flow, callees, or memory context, it must rewrite the IR so that +the final choice is visible through those carriers before `vmi-to-vpto`. 
-```text -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -``` - -Ops that are uniquely determined by layout may omit this attr, but the rule -should be conservative. If future maintainers could reasonably ask "why this -lowering?", assignment should write a plan. - -Required-plan table for the current implementation: +Local-decision table for the current implementation: ```text -op required when -group_load result layout matches a registered group_load plan -group_slot_load explicit group_slots slots=8 or slots=1 result -group_reduce_addf source/result layouts match a registered reduce plan -group_broadcast explicit slots=8 or slots=1 source and dense result -truncf group_slots slots=1 f32->f16 slot-preserving cast +op local decision inputs +group_load result layout, num_groups, row_stride, source type +group_slot_load result group_slots layout and source_group_stride +group_reduce_addf source/mask/result layouts, num_groups, reassoc +group_broadcast source/result layouts and num_groups +truncf source/result layouts and element widths ensure_layout always carries source/result layouts instead of plan ensure_mask_layout always carries source/result layouts instead of plan ensure_mask_granularity always carries source/result granularities instead of plan @@ -223,12 +206,12 @@ dense extf/truncf source/result layouts and element widths Implementation rule: ```text -vmi-layout-assignment attaches 
the required plan before type conversion. -validate-assigned-vmi rejects a required-plan op that lacks vmi.selected_plan. -vmi-to-vpto verifies the plan against the already assigned layouts and emits -VMI-LAYOUT-CONTRACT instead of selecting a fallback from producer/user context. -If a layout/attr-only op later gains a second legal recipe, that recipe must be -promoted into the required-plan table before vmi-to-vpto can emit it. +validate-assigned-vmi validates assigned layouts, mask granularity, boundaries, +and helper placement. +vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. +If a layout/attr-only op later gains a second legal recipe that cannot be +distinguished from current-op information, that recipe must be represented by a +new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. Unsupported shapes that have no registered plan still diagnose through their specific capability check rather than failing with a generic missing-plan error. ``` @@ -316,7 +299,6 @@ struct VMILayoutPlan { SmallVector operandLayouts; SmallVector resultLayouts; int64_t cost; - bool requiresSelectedPlanAttr; bool requiresFullTileReadable; bool mayReadInactivePhysicalLanes; DiagnosticBuilder (*explainFailure)(...); @@ -605,15 +587,15 @@ Algorithm: - otherwise insert ensure_layout at use - otherwise diagnose 6. Rewrite VMI result/block/function types with chosen layouts. -7. Attach selected_plan attrs where required. -8. Insert helper ops with source/result layout attrs. +7. Insert helper ops with source/result layout attrs. ``` Rewrite invariants: ```text No VMI data/mask value after assignment has a null layout. -No context-sensitive VMI op after assignment lacks selected_plan. +Any non-local choice is represented by op attrs, operand/result layouts, a +helper op, a clone, or an explicit diagnostic. Every ensure_* helper has a registered materialization plan. Every function/call signature carrying VMI is specialized or diagnosed. 
``` @@ -626,14 +608,9 @@ Assignment rewrites the IR so that later lowering has no hidden choices. type rewrite: every VMI data/mask result and block argument receives a layout attr -selected_plan rewrite: - context-sensitive ops receive vmi.selected_plan - examples: group_reduce_addf, group_load, group_slot_load, group_broadcast, - group_slot cast, full-read masked_load plans - clone rewrite: cheap producers are cloned before their divergent use sites - each clone receives its own layout and selected_plan + each clone receives its own layout and attrs ensure rewrite: non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use @@ -655,7 +632,7 @@ function rewrite: Canonical assigned IR shape for a conflicting load: ```text -%x = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} +%x = pto.vmi.load ... : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> %x_dense = pto.vmi.ensure_layout %x @@ -668,10 +645,10 @@ pto.vmi.store %x_dense, ... Canonical assigned IR shape for a cloned cheap producer: ```text -%x_s16 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv2"} +%x_s16 = pto.vmi.load ... : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -%x_s32 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} +%x_s32 = pto.vmi.load ... : ... 
-> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -679,12 +656,10 @@ Canonical assigned IR shape for `group_broadcast` multi-use: ```text %b0 = pto.vmi.group_broadcast %slots - {vmi.selected_plan = "group_broadcast_slots8_vselr"} : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> %b1 = pto.vmi.group_broadcast %slots - {vmi.selected_plan = "group_broadcast_slots8_vselr"} : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -725,10 +700,10 @@ case family builder / owner assignment artifact 3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values vmi-to-vpto contract: - lower each reduce from source layout, result group_slots layout, and - selected_plan. It must not walk to the load/group_load producer to decide - parity versus block8, row-local versus packed slots, or static versus dynamic - mask generation. + lower each reduce from the current op's attrs, source/mask layout, and result + group_slots layout. It must not walk to the load/group_load producer to + decide parity versus block8, row-local versus packed slots, or static versus + dynamic mask generation. ``` ```text @@ -743,10 +718,10 @@ case family builder / owner assignment artifact 3.39 strided load fanout conflict resolver preserving layout or materialization vmi-to-vpto contract: - consume only explicit memory stride/alignment attrs, selected_plan, and - layouts. It must not infer safe read/write placement from neighboring + consume only explicit memory stride/alignment attrs, current op operands, + and layouts. It must not infer safe read/write placement from neighboring compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes - stay diagnostics until a gather plan is registered. + stay diagnostics until a gather recipe is explicit in the current op. 
``` ```text @@ -757,7 +732,7 @@ case family builder / owner assignment artifact 3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout 3.23 broadcast multi-user conflict resolver cloned group_broadcast 3.33 S=16 + S=32 users conflict resolver cloned load or materialization -3.34 S=64 slots=1 cast buildCastRequests group_slot_cast selected plan +3.34 S=64 slots=1 cast buildCastRequests group_slot_cast layout 3.35 slots fanout buildElementwiseRequests same group_slots layout on users 3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast 3.40 scalar dense + grouped conflict resolver cloned broadcast @@ -862,112 +837,111 @@ Each pattern uses: ```text op +op attrs and operand values operand/result layouts -selected_plan adaptor physical values ``` Each pattern rejects: ```text -missing selected_plan for context-sensitive op -layout not matching selected_plan +missing current-op proof for an otherwise unsafe memory recipe missing target capability unexpected group_slots dense consumer ``` -Target selected-plan matrix: +Target local recipe matrix: ```text -load, selected_plan=dense_load_norm: +load, recipe=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths covers dense store users and S=64 row-local reduce input -load, selected_plan=load_dintlv2: +load, recipe=load_dintlv2: result layout deinterleaved=2, block_elems=1 emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization covers f32->f16, S=16 parity reduce, f16->f32 widened values -load, selected_plan=load_dintlv4: +load, recipe=load_dintlv4: result layout deinterleaved=4, block_elems=1 emits two vldsx2 DINTLV_B32 plus vdintlv covers f32->f8, S=32 dintlv4 reduce -group_load, selected_plan=s16_group_load_block8_unit_stride: +group_load, recipe=s16_group_load_block8_unit_stride: result layout deinterleaved=2, block_elems=8 emits vldsx2/BDINTLV for 8 rows of 16xf32 covers compact logical S=16 when source_group_stride == 16 -group_load, 
selected_plan=s16_group_load_block8_stride: +group_load, recipe=s16_group_load_block8_stride: result layout deinterleaved=2, block_elems=8 emits two vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, selected_plan=s32_group_load_block8_stride: +group_load, recipe=s32_group_load_block8_stride: result layout deinterleaved=4, block_elems=8 emits four vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, selected_plan=group_load_contiguous_chunks: +group_load, recipe=group_load_contiguous_chunks: result layout contiguous emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_addf, selected_plan=s8_reduce_contiguous: +group_reduce_addf, recipe=s8_reduce_contiguous: consumes contiguous f32 with group size 8 produces group_slots(G, slots=8) emits one vcgadd -group_reduce_addf, selected_plan=s16_reduce_parity: +group_reduce_addf, recipe=s16_reduce_parity: consumes deinterleaved=2, block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, selected_plan=s16_reduce_block8: +group_reduce_addf, recipe=s16_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, selected_plan=s32_reduce_dintlv4: +group_reduce_addf, recipe=s32_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, selected_plan=s32_reduce_block8_stride: +group_reduce_addf, recipe=s32_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, selected_plan=s64_reduce_row_local: +group_reduce_addf, recipe=s64_reduce_row_local: consumes contiguous f32 with group size 64 produces group_slots(G, slots=1) target lowering 
emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract -group_slot_load, selected_plan=group_slot_load_slots8_unit_stride: +group_slot_load, recipe=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 emits one packed vsldb load -group_slot_load, selected_plan=group_slot_load_slots1_row_local: +group_slot_load, recipe=group_slot_load_slots1_row_local: result group_slots(G, slots=1) supports aligned non-unit source_group_stride requires constant positive source_group_stride divisible by 256 / elementBits emits one lane-0 vsldb per group -group_broadcast, selected_plan=group_broadcast_slots8_vselr: +group_broadcast, recipe=group_broadcast_slots8_vselr: source group_slots(G, slots=8) result dense layout selected per use emits vselr using assigned result layout -group_broadcast, selected_plan=group_broadcast_slots1_vselr: +group_broadcast, recipe=group_broadcast_slots1_vselr: source group_slots(G, slots=1) result dense layout selected per use emits vdup/vselr row-local materialization -truncf, selected_plan=group_slot_cast_slots1_f32_to_f16: +truncf, recipe=group_slot_cast_slots1_f32_to_f16: source/result group_slots(G, slots=1) emits one lane-0 vcvt per group slot block rejects packed slots=8 unless another plan is registered @@ -980,36 +954,29 @@ Current staged implementation status: ```text group_slot_load: - vmi-to-vpto requires vmi.selected_plan and checks it against - #pto.vmi.layout. + vmi-to-vpto lowers from #pto.vmi.layout + and source_group_stride. group_reduce_addf: - explicit slots=8 VCGADD lowering requires - vmi.selected_plan = "s8_reduce_contiguous". Legacy bare num_groups and - generic VCADD lowering still need the plan-registry migration. + explicit slots=8 VCGADD lowering is selected from contiguous source/mask + layout, slots=8 result layout, num_groups, and reassoc. 
S=16 block8 assignment emits source/mask #pto.vmi.layout, result - #pto.vmi.layout, and - vmi.selected_plan = "s16_reduce_block8"; vmi-to-vpto checks that plan and - lowers through two VCGADDs plus a PAT_VL8 VADD per packed result block. + #pto.vmi.layout; vmi-to-vpto lowers through two + VCGADDs plus a PAT_VL8 VADD per packed result block. S=32 block8 assignment emits source/mask #pto.vmi.layout, result - #pto.vmi.layout, and - vmi.selected_plan = "s32_reduce_block8_stride"; vmi-to-vpto checks that - plan and lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed - result block. - S=64 row-local assignment now emits - vmi.selected_plan = "s64_reduce_row_local" and has focused - layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic - VCADD row-local path also requires and checks that selected_plan. Other - legacy bare num_groups generic VCADD paths still need the plan-registry - migration. + #pto.vmi.layout; vmi-to-vpto lowers through four + VCGADDs plus a PAT_VL8 VADD tree per packed result block. + S=64 row-local assignment uses #pto.vmi.layout + and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit + slots=1 generic VCADD row-local path is selected locally. group_broadcast: - explicit slots=8/1 source layouts require - vmi.selected_plan = "group_broadcast_slots8_vselr" or - "group_broadcast_slots1_vselr". Deinterleaved block-fragment results use - the result layout block_elems as the local vselr selection group, so + explicit slots=8/1 source layouts select + packed or row-local VSELR recipes locally. Deinterleaved block-fragment + results use the result layout block_elems as the local vselr selection group, + so `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each 32B row fragment. VSELR index vectors are materialized per physical result chunk. 
For small-group results, layout assignment has already fixed the @@ -1018,27 +985,23 @@ group_broadcast: `sourceChunk = firstGroup / slots`, and `baseGroupSlot = firstGroup % slots`. The generated index vector selects `baseGroupSlot .. baseGroupSlot + groupsPerResultChunk - 1`; it must not be - reused across result chunks. Legacy bare num_groups still needs the - plan-registry migration. + reused across result chunks. group_load: - contiguous full-chunk path emits and checks - vmi.selected_plan = "group_load_contiguous_chunks". S=16/S=32 - block-aligned strided loads emit and check - vmi.selected_plan = "s16_group_load_block8_stride" or - "s32_group_load_block8_stride", assign + contiguous full-chunk path is selected from a contiguous result layout. + S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride - vldsx2/BDINTLV plan remains a design target. S=16/S=32 group_load with a - non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by - vmi-layout-assignment because the stable gather fallback is not implemented. + vldsx2/BDINTLV recipe remains a local peephole target. + S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned + row_stride is rejected by vmi-layout-assignment because the stable gather + fallback is not implemented. truncf group-slot cast: - layout assignment and vmi-to-vpto support and check - vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" for - group_slots(G, slots=1) f32 -> f16. The reduce->truncf->group_store - slots=1 flow has focused lit coverage and no longer relies on vmi-to-vpto - inspecting the truncf producer. + layout assignment and vmi-to-vpto support group_slots(G, slots=1) + f32 -> f16 from source/result layouts and element widths. The reduce->truncf + -> group_store slots=1 flow has focused lit coverage and no longer relies on + vmi-to-vpto inspecting the truncf producer. 
group_store: row-local group_slots(G, slots=1) lowering is implemented as one lane-0 @@ -1058,15 +1021,15 @@ group_store: Examples: ```text -group_reduce_addf, selected_plan=s16_reduce_parity: +group_reduce_addf, recipe=s16_reduce_parity: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_addf, selected_plan=s16_reduce_block8: +group_reduce_addf, recipe=s16_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_addf, selected_plan=s32_reduce_dintlv4: +group_reduce_addf, recipe=s32_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree @@ -1101,8 +1064,7 @@ After assignment: ```text Every VMI value has layout. Every VMI mask has layout and granularity plan. -Every context-sensitive op has selected_plan. -Every selected_plan matches operand/result layouts. +Every lowering choice is locally deterministic or explicit in attrs/layouts. Every ensure_* helper has a materialization plan. Every control-flow edge has matching VMI layouts. 
``` @@ -1121,8 +1083,8 @@ allowed: diagnostic not allowed: - walking from a consumer to a producer to decide a selected_plan - walking from a consumer to a mask producer to decide whether a plan is legal + walking from a consumer to a producer to decide a recipe + walking from a consumer to a mask producer to decide whether a recipe is legal inspecting users to choose a result layout or materialization recovering full_tile_readable from surrounding MTE/caller context ``` @@ -1222,7 +1184,8 @@ Each positive layout-assignment test must check: ```text assigned data layouts assigned mask layouts -selected_plan attrs +assigned op attrs +direct vmi-to-vpto local lowering inserted ensure_layout/rematerialized producers control-flow/function signature specialization ``` @@ -1766,7 +1729,6 @@ entries: ```text lit: - test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -1805,7 +1767,6 @@ memory-proof runtime coverage: layout attrs vmi.vreg/vmi.mask types surface op definitions -selected_plan attr surface/layout validators ``` @@ -1888,10 +1849,9 @@ Current evidence for the case-catalog objective: 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 5. the latest full VMI lit sweep passed: 314/314 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test -7. vmi-to-vpto context-sensitive decisions are represented by assigned layouts, - selected_plan, helper ops, rematerialization, or diagnostics -8. missing selected_plan on registered context-sensitive shapes is a hard - validation failure +7. vmi-to-vpto decisions are represented by current-op attrs/operands, + assigned layouts, helper ops, rematerialization, or diagnostics +8. no separate recipe string attr is emitted or consumed 9. 
release docs remain untouched; this is still a design/implementation plan under docs/designs ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 9261a938dd..b30c0c3472 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -217,7 +217,7 @@ S=64 row-local result -> slots=1 ```text 1. op name and explicit op attrs 2. converted operand/result types with layout -3. selected plan attrs written by layout assignment +3. helper/materialization ops written by layout assignment 4. inserted helper ops 5. target capability registry ``` @@ -251,73 +251,62 @@ or explicit helper: pto.vmi.ensure_mask_granularity ``` -Every context-sensitive op must also have a selected plan if layout alone does -not uniquely identify the lowering: +`vmi-to-vpto` is allowed to choose a deterministic recipe from local +information on the current op: ```text -vmi.selected_plan = "dense_load_norm" -vmi.selected_plan = "load_dintlv2" -vmi.selected_plan = "load_dintlv4" -vmi.selected_plan = "group_load_contiguous_chunks" -vmi.selected_plan = "s16_group_load_block8_unit_stride" -vmi.selected_plan = "s16_group_load_block8_stride" -vmi.selected_plan = "s32_group_load_block8_stride" -vmi.selected_plan = "s8_reduce_contiguous" -vmi.selected_plan = "s16_reduce_parity" -vmi.selected_plan = "s16_reduce_block8" -vmi.selected_plan = "s32_reduce_dintlv4" -vmi.selected_plan = "s32_reduce_block8_stride" -vmi.selected_plan = "s64_reduce_row_local" -vmi.selected_plan = "group_slot_load_slots8_unit_stride" -vmi.selected_plan = "group_slot_load_slots1_row_local" -vmi.selected_plan = "group_broadcast_slots8_vselr" -vmi.selected_plan = "group_broadcast_slots1_vselr" -vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +current op name +current op attrs +operand/result types and layouts +current op operand values such as stride and offset +target capability and pass 
options ``` -The spelling above is illustrative; implementation may use an enum attr. The -invariant is not illustrative: if a lowering decision is not uniquely implied -by op + assigned operand/result layouts + explicit attrs, assignment must write -a selected plan. +This is not context inference. What remains forbidden is walking to producers, +users, sibling users, branch/loop bodies, callees/callers, or nearby memory/MTE +ops to recover a lowering decision or a memory-safety proof. -### 4.1 Selected Plan Contract +If a decision cannot be made from that local information, layout assignment +must rewrite the IR until the decision is explicit in attrs, operand/result +layouts, helper ops, cloned producers, or diagnostics. `vmi-to-vpto` must not +consume a separate string recipe attr. -`selected_plan` is not an optimization hint. It is the serialized answer to a -question that would otherwise require `vmi-to-vpto` to inspect producer, -consumer, control-flow, memory, or mask context. +### 4.1 Local Recipe Contract -Required plans in the current implementation: +The lowering recipe is derived from op + assigned operand/result layouts + +explicit attrs/operands. If two legal recipes cannot be distinguished from +that local information, the IR is missing a semantic carrier and must be +extended before the recipe is implemented. + +Locally deterministic decisions in the current implementation: ```text group_load: - required for registered result layouts. The plan fixes source_group_stride - handling and whether the result is contiguous chunks, S=16 block8, or S=32 - block8. Unsupported shapes diagnose through the capability check instead of - inventing a plan. + result layout, num_groups, row_stride, source type, and target capability + decide contiguous chunks versus S=16/S=32 block8 vsldb lowering. Unit-stride + vldsx2/BDINTLV can be a local peephole for the same block8 layout. group_slot_load: - required for explicit slots=8 or slots=1 layouts. 
The plan fixes packed - scalar load versus row-local lane-0 load. A single source op may be - rematerialized into two different planned ops. + result group_slots layout and source_group_stride decide packed slots=8 + versus row-local slots=1 vsldb lowering. A single source op may still be + rematerialized into two ops when different users require different result + layouts; each clone is then locally deterministic. group_reduce_addf: - required for registered S=8/S=16/S=32/S=64 shapes. The plan fixes parity - versus block8, packed slots=8 versus row-local slots=1, and multi-chunk - arity. Unsupported group sizes diagnose as unsupported capability, not as - missing selected_plan. + source/mask layout, result group_slots layout, num_groups, element type, and + reassoc decide S=8 contiguous vcgadd, S=16/S=32 deinterleaved vcgadd trees, + and S=64 row-local vcadd/vsel lowering. group_broadcast: - required for explicit slots=8 or slots=1 sources. The plan fixes source - interpretation and the vselr index recipe for the requested dense result - layout. Legacy bare group_slots are tolerated only as compatibility input and - must not be emitted by layout assignment. + source group_slots layout, result dense layout, num_groups, and element type + decide vdup/vselr materialization. truncf: - required for group_slots slots=1 f32->f16, where the cast is a slot-preserving - group-slot cast rather than an ordinary dense VCVT path. + source/result group_slots layouts and element widths decide the slots=1 + f32->f16 slot-preserving vcvt path. ``` -Layout-only or attr-only decisions in the current implementation: +Other layout-only or attr-only decisions in the current implementation: ```text load: @@ -327,8 +316,9 @@ load: group_store: source group_slots layout and explicit output stride decide packed slots=8 - versus row-local slots=1 store legality. If another legal store recipe is - introduced, assignment must attach a selected plan before vmi-to-vpto uses it. 
+ versus row-local slots=1 store legality. If another legal store recipe + needs more information, assignment must make that information explicit in the + op or helper IR before vmi-to-vpto uses it. masked_load: explicit passthrough, mask layout, full physical read, shaped safe-tail memref, @@ -341,14 +331,14 @@ masked_store/select/elementwise: extf/truncf: dense width-changing paths are layout-determined today. Any future - commute-through-group-broadcast or alternative VCVT recipe must become a - selected plan first. + commute-through-group-broadcast or alternative VCVT recipe must have an + explicit IR carrier first. ``` -Forbidden plan recovery: +Forbidden non-local recipe recovery: ```text -No pattern may synthesize one of the required plans by: +No pattern may synthesize a recipe or memory proof by: - walking from group_reduce to the load/group_load producer - walking from store/broadcast/truncf to the group_reduce producer - scanning sibling users of a group_slots value @@ -356,9 +346,9 @@ No pattern may synthesize one of the required plans by: - inspecting private callee bodies while lowering a call ``` -If a required plan is missing, `vmi-to-vpto` emits +If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, -assigned layouts, and the missing plan class. +assigned layouts, and the missing decision class. ## 5. Plan Registry @@ -547,7 +537,7 @@ group_broadcast: group_store: requests source group_slots(num_groups, slots=K) - selected plan also records output stride legality + explicit output stride attrs/operands decide store legality dense elementwise add/mul/fma/min/max/select: requests all dense data operands and results use one dense layout @@ -720,7 +710,7 @@ Recommended solving order: 7. Rematerialize cheap producers instead of materializing when cheaper. 8. Specialize internal function signatures. 9. Emit diagnostics for unsatisfied hard constraints. -10. 
Rewrite VMI types and selected plan attrs. +10. Rewrite VMI types and insert explicit helper/rematerialized ops. ``` Tie-breaking must be deterministic. Suggested priority: @@ -787,10 +777,10 @@ For each op, the pattern: ```text 1. reads operand/result layouts -2. reads selected_plan if required +2. reads current op attrs and operand values 3. asks TypeConverter for ordered physical values -4. emits the registered VPTO recipe -5. fails if the selected plan is missing or target capability is absent +4. emits the locally implied VPTO recipe +5. fails if target capability or required local proof is absent ``` The pattern must not: @@ -825,12 +815,12 @@ diagnostic embellishment: Anything else is a layout-assignment responsibility. In particular, an unsupported producer/consumer combination must be rejected before assignment -writes a selected plan. Section 3.44 is the model for supported partial S=32 +emits layout-assigned IR. Section 3.44 is the model for supported partial S=32 grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the plan. Dynamic `active_elems_per_group` follows the same rule: the +the recipe. Dynamic `active_elems_per_group` follows the same rule: the `create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps for contiguous chunks before any predicate deinterleave. @@ -852,8 +842,8 @@ group_slots(G,K): slot_block0, slot_block1, ... 
``` -Two physical bundle entries may alias the same VPTO SSA value when the selected -plan proves they have the same contents, such as group_broadcast feeding both +Two physical bundle entries may alias the same VPTO SSA value when the local +recipe proves they have the same contents, such as group_broadcast feeding both parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; aliasing is not a different layout. @@ -866,7 +856,7 @@ Diagnostics are part of the design. They must name: 2. source logical type 3. assigned source layout 4. requested layout -5. missing plan or disabled fallback +5. missing local proof or disabled fallback 6. suggested rewrite when available ``` @@ -894,8 +884,8 @@ public VMI function boundary: The design is complete only when: ```text -1. every case in vmi-layout-lowering-cases.md maps to registered plans -2. every selected plan can be emitted without looking at producer/user context +1. every case in vmi-layout-lowering-cases.md maps to a local recipe +2. every local recipe can be emitted without looking at producer/user context 3. every unsupported case has a precise capability diagnostic 4. every control-flow/function boundary either specializes layout or diagnoses 5. every mask has explicit data layout and predicate granularity diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 8e2d6bfceb..262299b3a3 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -1522,8 +1522,8 @@ layout transition explicit: `group_broadcast` first produces a dense contiguous f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 view required by dense `f32 -> f16` truncation. A future direct `group_broadcast -> deinterleaved=2` lowering may remove that materialization, -but it must be implemented as a `group_broadcast` selected plan rather than -hidden inside `truncf` lowering. 
+but the `group_broadcast` result layout must make that recipe explicit rather +than hiding it inside `truncf` lowering. VPTO lowering result for one full 8-row tile: @@ -3045,9 +3045,9 @@ layout. It is that each use has an explicit layout boundary: %b_for_cast_split = pto.vmi.ensure_layout %b_for_cast ``` -If a future `group_broadcast -> deinterleaved` selected plan is added, layout +If a future direct `group_broadcast -> deinterleaved` recipe is added, layout assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but -the choice must still be visible in the assigned IR and selected plan. +the choice must still be visible in the assigned IR. VPTO lowering result: @@ -5266,7 +5266,7 @@ one contiguous value for `masked_load`, and one deinterleaved value for `create_group_mask` by materializing the contiguous grouped predicate chunks and then applying `pdintlv_b32` in the same tree shape as the data `vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to -choose or reject the selected plan. +choose or reject the recipe. Assignment may select a deinterleaved S=32 load plan only when the rounded physical reads are memory-safe; otherwise it must diagnose or use a future diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 103b6a6df9..3047197d57 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -652,9 +652,11 @@ def PTOValidateVMILayoutIR Checks the post-layout-assignment VMI stage: every VMI data value must have a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, - VMI typed values must stay inside VMI semantic/helper or structural ops, - and context-sensitive VMI ops must carry the selected_plan contract emitted - by layout assignment. + and VMI typed values must stay inside VMI semantic/helper or structural ops. 
+ vmi-to-vpto chooses deterministic local recipes from the current op's attrs, + operand/result types, layouts, and operand values; non-local choices must + be represented as explicit attrs, helper ops, cloned producers, or + diagnostics before this stage. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 889a5ebe85..6ce3e8eecd 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -36,8 +36,6 @@ using namespace mlir::pto; namespace { -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - bool isVMIType(Type type) { return isa(type); } bool isPhysicalVPTOType(Type type) { @@ -161,133 +159,6 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } -LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, - Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagLayoutContractPrefix << message; - (void)diag; - mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); - return failure(); -} - -std::optional getGroupSize(VMIVRegType type, int64_t numGroups) { - if (!type || numGroups <= 0 || type.getElementCount() % numGroups != 0) - return std::nullopt; - return type.getElementCount() / numGroups; -} - -bool hasRegisteredGroupReducePlan(VMIGroupReduceAddFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - if (!sourceType) - return false; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - if (!sourceLayout) - return false; - - std::optional groupSize = - getGroupSize(sourceType, op.getNumGroupsAttr().getInt()); - if (!groupSize) - return false; - - if (sourceLayout.isContiguous()) - return *groupSize == 8 || *groupSize == 64; - - if (!sourceLayout.isDeinterleaved()) - return false; - if (*groupSize == 16 && 
sourceLayout.getFactor() == 2) - return sourceLayout.getBlockElems() == 1 || - sourceLayout.getBlockElems() == 8; - if (*groupSize == 32 && sourceLayout.getFactor() == 4) - return sourceLayout.getBlockElems() == 1 || - sourceLayout.getBlockElems() == 8; - return false; -} - -bool hasRegisteredGroupLoadPlan(VMIGroupLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return false; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout) - return false; - if (layout.isContiguous()) - return true; - if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) - return false; - - std::optional groupSize = - getGroupSize(resultType, op.getNumGroupsAttr().getInt()); - if (!groupSize) - return false; - return (*groupSize == 16 && layout.getFactor() == 2) || - (*groupSize == 32 && layout.getFactor() == 4); -} - -bool hasRegisteredGroupSlotLoadPlan(VMIGroupSlotLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return false; - VMILayoutAttr layout = resultType.getLayoutAttr(); - return layout && layout.isGroupSlots() && - layout.getNumGroups() == op.getNumGroupsAttr().getInt() && - (layout.getSlots() == 8 || layout.getSlots() == 1); -} - -bool hasRegisteredGroupBroadcastPlan(VMIGroupBroadcastOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return false; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && - sourceLayout.getNumGroups() == op.getNumGroupsAttr().getInt() && - !resultLayout.isGroupSlots() && - (sourceLayout.getSlots() == 8 || sourceLayout.getSlots() == 1); -} - -bool hasRegisteredGroupSlotTruncFPlan(Operation *op) { - auto truncf = dyn_cast(op); - if (!truncf) - return false; - - auto sourceType = 
dyn_cast(truncf.getSource().getType()); - auto resultType = dyn_cast(truncf.getResult().getType()); - if (!sourceType || !resultType) - return false; - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && - resultLayout.isGroupSlots() && sourceLayout.getSlots() == 1 && - resultLayout.getSlots() == 1 && sourceType.getElementType().isF32() && - resultType.getElementType().isF16(); -} - -bool requiresSelectedPlan(Operation *op) { - if (auto groupLoad = dyn_cast(op)) - return hasRegisteredGroupLoadPlan(groupLoad); - if (auto groupSlotLoad = dyn_cast(op)) - return hasRegisteredGroupSlotLoadPlan(groupSlotLoad); - if (auto reduce = dyn_cast(op)) - return hasRegisteredGroupReducePlan(reduce); - if (auto broadcast = dyn_cast(op)) - return hasRegisteredGroupBroadcastPlan(broadcast); - return hasRegisteredGroupSlotTruncFPlan(op); -} - -LogicalResult verifySelectedPlanContract(Operation *op, - llvm::raw_ostream *diagOS) { - if (!requiresSelectedPlan(op)) - return success(); - if (op->getAttrOfType(kVMISelectedPlanAttrName)) - return success(); - return emitLayoutContract( - op, diagOS, - Twine(op->getName().getStringRef()) + - " requires vmi.selected_plan selected by vmi-layout-assignment"); -} - LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -507,9 +378,6 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (!hasVMIOrPhysicalType(op)) return success(); - if (failed(verifySelectedPlanContract(op, diagOS))) - return failure(); - if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return success(); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 9352ffce76..85a57e4ac1 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -65,8 +65,6 @@ struct 
MaskUseRequest { std::string granularity; }; -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - static unsigned getElementBitWidth(Type type) { if (isa(type)) return 64; @@ -1572,143 +1570,6 @@ struct LayoutSolver { return success(); } - std::optional getGroupReduceSelectedPlan(VMIGroupReduceAddFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - if (!sourceType) - return std::nullopt; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - if (!sourceLayout) - return std::nullopt; - - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (numGroups <= 0 || sourceType.getElementCount() % numGroups != 0) - return std::nullopt; - int64_t groupSize = sourceType.getElementCount() / numGroups; - - if (sourceLayout.isContiguous()) { - if (groupSize == 8) - return StringRef("s8_reduce_contiguous"); - if (groupSize == 64) - return StringRef("s64_reduce_row_local"); - return std::nullopt; - } - - if (!sourceLayout.isDeinterleaved()) - return std::nullopt; - - if (groupSize == 16 && sourceLayout.getFactor() == 2) { - if (sourceLayout.getBlockElems() == 1) - return StringRef("s16_reduce_parity"); - if (sourceLayout.getBlockElems() == 8) - return StringRef("s16_reduce_block8"); - } - - if (groupSize == 32 && sourceLayout.getFactor() == 4) { - if (sourceLayout.getBlockElems() == 1) - return StringRef("s32_reduce_dintlv4"); - if (sourceLayout.getBlockElems() == 8) - return StringRef("s32_reduce_block8_stride"); - } - - return std::nullopt; - } - - std::optional getGroupSlotLoadSelectedPlan(VMIGroupSlotLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return std::nullopt; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != op.getNumGroupsAttr().getInt()) - return std::nullopt; - if (layout.getSlots() == 8) - return StringRef("group_slot_load_slots8_unit_stride"); - if (layout.getSlots() == 1) - return 
StringRef("group_slot_load_slots1_row_local"); - return std::nullopt; - } - - std::optional getGroupLoadSelectedPlan(VMIGroupLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return std::nullopt; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout) - return std::nullopt; - if (layout.isContiguous()) - return StringRef("group_load_contiguous_chunks"); - if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) - return std::nullopt; - - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (numGroups <= 0 || resultType.getElementCount() % numGroups != 0) - return std::nullopt; - int64_t groupSize = resultType.getElementCount() / numGroups; - if (groupSize == 16 && layout.getFactor() == 2) - return StringRef("s16_group_load_block8_stride"); - if (groupSize == 32 && layout.getFactor() == 4) - return StringRef("s32_group_load_block8_stride"); - return std::nullopt; - } - - std::optional - getGroupBroadcastSelectedPlan(VMIGroupBroadcastOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return std::nullopt; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || !sourceLayout.isGroupSlots() || - sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || - resultLayout.isGroupSlots()) - return std::nullopt; - if (sourceLayout.getSlots() == 8) - return StringRef("group_broadcast_slots8_vselr"); - if (sourceLayout.getSlots() == 1) - return StringRef("group_broadcast_slots1_vselr"); - return std::nullopt; - } - - std::optional getTruncFSelectedPlan(VMITruncFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return std::nullopt; - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - 
VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || sourceLayout != resultLayout || - !sourceLayout.isGroupSlots() || sourceLayout.getSlots() != 1) - return std::nullopt; - - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 32 && resultBits == 16) - return StringRef("group_slot_cast_slots1_f32_to_f16"); - return std::nullopt; - } - - void attachSelectedPlanAttrs() { - Builder builder(ctx); - module.walk([&](Operation *op) { - std::optional plan; - if (auto reduce = dyn_cast(op)) - plan = getGroupReduceSelectedPlan(reduce); - else if (auto load = dyn_cast(op)) - plan = getGroupLoadSelectedPlan(load); - else if (auto load = dyn_cast(op)) - plan = getGroupSlotLoadSelectedPlan(load); - else if (auto broadcast = dyn_cast(op)) - plan = getGroupBroadcastSelectedPlan(broadcast); - else if (auto truncf = dyn_cast(op)) - plan = getTruncFSelectedPlan(truncf); - - if (plan) - op->setAttr(kVMISelectedPlanAttrName, builder.getStringAttr(*plan)); - }); - } - void rewriteFunctionType() { module.walk([&](func::FuncOp func) { if (func.empty()) @@ -1755,7 +1616,6 @@ struct LayoutSolver { rewriteDataTypes(); if (failed(insertDataUseMaterializations())) return failure(); - attachSelectedPlanAttrs(); if (failed(inferMaskRequests())) return failure(); rewriteMaskTypes(); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 36ccc21f3f..5b050d640a 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -50,8 +50,6 @@ using namespace mlir::pto; namespace { -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - bool isVMIType(Type type) { return isa(type); } bool containsVMIType(Type type) { @@ -1187,21 +1185,12 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, VMILayoutAttr resultLayout = 
resultType.getLayoutAttr(); if (!resultLayout) return fail("requires assigned result layout"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); FailureOr groupSize = getGroupSizeFromNumGroups( resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (resultLayout.isContiguous()) { - StringRef expectedPlan = "group_load_contiguous_chunks"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), op.getSource().getType(), std::nullopt, std::nullopt, reason))) @@ -1211,18 +1200,10 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - StringRef expectedPlan; - if (*groupSize == 16 && resultLayout.getFactor() == 2) - expectedPlan = "s16_group_load_block8_stride"; - else if (*groupSize == 32 && resultLayout.getFactor() == 4) - expectedPlan = "s32_group_load_block8_stride"; - else + if ((*groupSize != 16 || resultLayout.getFactor() != 2) && + (*groupSize != 32 || resultLayout.getFactor() != 4)) return fail("block8 strided group_load requires S=16/factor=2 or " "S=32/factor=4"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); if (!isa(op.getSource().getType())) return fail("block8 strided group_load requires !pto.ptr source"); if (op.getNumGroupsAttr().getInt() % 8 != 0) @@ -1260,24 +1241,9 @@ LogicalResult checkSupportedGroupSlotLoadShape( return fail("requires explicit group_slots result layout matching " "num_groups"); - 
auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - - StringRef expectedPlan; - if (layout.getSlots() == 8) - expectedPlan = "group_slot_load_slots8_unit_stride"; - else if (layout.getSlots() == 1) - expectedPlan = "group_slot_load_slots1_row_local"; - else + if (layout.getSlots() != 8 && layout.getSlots() != 1) return fail("supports only slots=8 or slots=1 group_slot_load layouts"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); - if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") .isSupported()) return fail("requires supported direct memory source"); @@ -2646,18 +2612,6 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, return fail("s16 block8 group_reduce_addf requires two source/mask " "parts per result part"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = sourceLayout.getBlockElems() == 1 - ? "s16_reduce_parity" - : "s16_reduce_block8"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); - return success(); } @@ -2711,17 +2665,6 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, return fail("s32 block8 group_reduce_addf requires four source/mask " "parts per result part"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = sourceLayout.getBlockElems() == 1 - ? 
"s32_reduce_dintlv4" - : "s32_reduce_block8_stride"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); return success(); } @@ -6974,15 +6917,6 @@ LogicalResult checkSupportedTruncFShape(VMITruncFOp op, "group_slots(num_groups=G, slots=1) source/result layouts, " "f32 source, f16 result, and matching physical arity"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = "group_slot_cast_slots1_f32_to_f16"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); return success(); } @@ -7411,41 +7345,17 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (*sourceArity != *resultArity || *sourceArity != *maskArity) return fail("requires source/result/mask physical arity to match"); if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, - *groupSize, nullptr))) { - if (resultLayout.getSlots() > 0) { - auto selectedPlan = - op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = "s8_reduce_contiguous"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + - expectedPlan + "'"); - } + *groupSize, nullptr))) return success(); - } if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) return failure(); if (resultLayout.getSlots() <= 0) return success(); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires 
vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan; - if (sourceLayout.isContiguous() && *groupSize == 64 && - resultLayout.getSlots() == 1) - expectedPlan = "s64_reduce_row_local"; - else - return fail("explicit group_slots group_reduce_addf chunk path has no " - "registered selected_plan for the assigned layouts"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); + if (!sourceLayout.isContiguous() || *groupSize != 64 || + resultLayout.getSlots() != 1) + return fail("explicit group_slots group_reduce_addf chunk path requires " + "contiguous group size 64 source and slots=1 result layout"); return success(); } @@ -7477,26 +7387,10 @@ LogicalResult checkSupportedGroupBroadcastShape( if (resultLayout.isGroupSlots()) return fail("requires dense result layout"); - if (sourceLayout.getSlots() > 0) { - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - - StringRef expectedPlan; - if (sourceLayout.getSlots() == 8) - expectedPlan = "group_broadcast_slots8_vselr"; - else if (sourceLayout.getSlots() == 1) - expectedPlan = "group_broadcast_slots1_vselr"; - else - return fail("supports only slots=8 or slots=1 group_broadcast source " - "layouts"); - - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source layout; expected '" + expectedPlan + - "'"); - } + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); std::string fullChunkReason; if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) @@ -8174,7 +8068,7 @@ verifySupportedVMIToVPTOOps(ModuleOp 
module, "to one contiguous f16 result chunk or f32 deinterleaved=4 " "source parts to one contiguous fp8-like result chunk, or f32 " "group_slots(num_groups=G, slots=1) to f16 " - "group_slots(num_groups=G, slots=1) with selected_plan (" + "group_slots(num_groups=G, slots=1) (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto index dce36f1b5d..51cd09053f 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -60,7 +60,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto index 49f2c5e2a8..00879170b1 100644 --- a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -48,7 +48,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-LABEL: func.func @caller( diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto index f4790b5432..2bc648261f 100644 --- 
a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -36,7 +36,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto index f68b4d5509..cb0e15864e 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -42,7 +42,6 @@ module { // ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( // LOWER: arith.index_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto index a93ae52c17..8e8a86450d 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -38,7 +38,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, 
#pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index e43d2e5591..27e304ae27 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -41,7 +41,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_dintlv4" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto index 7df6946741..20c2754e60 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -46,18 +46,14 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN: pto.vmi.group_store %[[YSUM]] // ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto index 7c1e569bf3..2c0f4f8ca7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -22,6 +22,5 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_slots8( // CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_broadcast -// CHECK-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load.pto b/test/lit/vmi/vmi_layout_assignment_group_load.pto index 2a90d02d08..864683cb04 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load.pto @@ -22,6 +22,5 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_load( // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_load -// CHECK-SAME: vmi.selected_plan = "group_load_contiguous_chunks" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto index 67215442e5..a3f045e503 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -31,12 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" // 
ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto index c97a35855b..df03683335 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -40,22 +40,18 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]], %[[MASK2]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store 
%[[YSUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto index 0f506a3a1f..abe3301b90 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -31,12 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto index c4652169d4..fb25c2bd91 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto index 
e9a3e7c9e9..6339aa15bc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -35,10 +35,8 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto index 9fb03c80b2..7a72876ff9 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -38,15 +38,12 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( // ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] -// ASSIGN-SAME: 
vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( // LOWER-DAG: %[[C2:.*]] = arith.constant 2 : i32 diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto index 1d61b4196e..b0d5a12676 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto index b51dd875b5..7fe8c425bf 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index 0a7550d004..d5fa902c56 100644 --- 
a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -51,7 +51,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] @@ -76,7 +75,6 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( // LOWER-COUNT-4: pto.vlds diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto index 2e4c9dd02f..2901a43f7e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -24,6 +24,5 @@ module { // CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: vmi.selected_plan = "s64_reduce_row_local" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto index 6fffb7c636..982d1d8a28 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -37,13 +37,10 @@ module { // ASSIGN-LABEL: func.func 
@vmi_layout_assignment_group_reduce_s64_broadcast_reduce( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots1_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto index ec8816fbeb..6cbedb442b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -29,7 +29,6 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto index bf38aee552..cf46aa5870 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -31,10 +31,8 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, 
#pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] -// ASSIGN-SAME: vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" // ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM16]] // ASSIGN-SAME: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.ptr> // CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: vmi.selected_plan = "s8_reduce_contiguous" // CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto index 1329965530..0042c64a15 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -30,7 +30,6 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: vmi.selected_plan = "s8_reduce_contiguous" // ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 9f4349d40e..9f629f55f2 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -39,20 +39,17 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // 
CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots1( // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8_store( // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: pto.vmi.group_store %[[OUT]] // CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto index a96b847256..b5533d9abc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -50,18 +50,14 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( // ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[RHS64:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf 
%[[SUM64]], %[[RHS64]] // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto index d0ac525849..16905f1210 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -40,17 +40,14 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto index e4b48121bc..c30502a252 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -45,20 +45,17 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( // ASSIGN: %[[ACC0:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[ACC:.*]] = scf.for // ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) // ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) 
// ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto index 4004ff6fcc..6c0b2d2ece 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -51,7 +51,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index 33ee79cb57..968e8d03c2 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -43,7 +43,6 @@ module { // ASSIGN: pto.vmi.create_group_mask // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER: pto.pdintlv_b32 // LOWER: pto.pdintlv_b32 // LOWER: 
pto.pdintlv_b32 diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto index a2d4cab4d9..46f7ff71f2 100644 --- a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -45,7 +45,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto index 01e8e55caf..63fc33cfe6 100644 --- a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -41,7 +41,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_parity" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] diff --git a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto b/test/lit/vmi/vmi_layout_gate_local_recipe.pto similarity index 80% rename from test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto rename to test/lit/vmi/vmi_layout_gate_local_recipe.pto index d06bd275ca..7644fae1c6 100644 --- a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_local_recipe.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR 
FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s module { - func.func @vmi_layout_gate_missing_selected_plan_invalid( + func.func @vmi_layout_gate_local_recipe( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} @@ -20,4 +20,5 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_layout_gate_local_recipe( +// CHECK: pto.vmi.group_reduce_addf diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto index 3a96e94d67..01e40aaae7 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -16,7 +16,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source - {num_groups = 128, vmi.selected_plan = "group_broadcast_slots8_vselr"} + {num_groups = 128} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto similarity index 51% rename from test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto index a03cdfd9df..dc1b938924 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto +++ 
b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto @@ -6,24 +6,36 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) { + func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, 
!pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_broadcast requires full source chunks -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto similarity index 58% rename from test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto index 563f939f77..a1c5959f98 100644 --- a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto @@ -6,24 +6,32 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_load_missing_plan_invalid( + func.func @vmi_to_vpto_group_load_local_recipe( %source: !pto.ptr, - %row_stride: index) { + %row_stride: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %c0 = arith.constant 0 : index %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_load requires contiguous full result chunks -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_load_local_recipe( +// CHECK-COUNT-8: pto.vlds +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index e757c583f6..380a090a71 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -16,7 +16,7 @@ module { %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index %v = pto.vmi.group_load %src[%c0], %row_stride - {num_groups = 2, vmi.selected_plan = "group_load_contiguous_chunks"} + {num_groups = 2} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto index ee12b742e8..55ae7fd255 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -16,7 +16,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_reduce_addf %source, %mask - {num_groups = 8, reassoc, vmi.selected_plan = "s64_reduce_row_local"} + {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto similarity index 62% rename from test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto index 96d975ab7d..4b706dc08d 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto @@ -6,25 +6,34 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
// See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_s64_missing_plan_invalid( + func.func @vmi_to_vpto_group_reduce_s64_local_recipe( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_local_recipe( +// CHECK-COUNT-8: pto.vcadd +// CHECK: pto.vsel +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto index 305c488dd5..2343869ceb 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -14,7 +14,7 @@ module { %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) -> !pto.vreg<64xf32> { %out = pto.vmi.group_reduce_addf %source, %mask - {num_groups = 8, reassoc, vmi.selected_plan = "s8_reduce_contiguous"} + {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto similarity index 73% rename from test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto index b67cb34f2d..a6737eae1f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto @@ -6,23 +6,27 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_slots8_missing_plan_invalid( + func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %part = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> - return + return %part : !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( +// CHECK: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto index 5927f63069..cf6591f36c 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -13,7 +13,7 @@ module { %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) @@ -28,7 +28,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %c8 = arith.constant 8 : index %out = pto.vmi.group_slot_load %src[%off], %c8 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots1_row_local"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) @@ -44,7 +44,7 @@ module { %src: !pto.ptr, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto similarity index 69% rename from test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto index f442e2fbbe..3a9aa117b5 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto @@ -6,22 +6,24 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
// See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_load_missing_plan_invalid( - %src: !pto.ptr, %off: index) { + func.func @vmi_to_vpto_group_slot_load_local_recipe( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %part = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> - return + return %part : !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_local_recipe( +// CHECK: pto.vsldb +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto index 10d9a2d3fa..8e58305a01 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -13,7 +13,7 @@ module { %src: !pto.ptr, %off: index, %stride: index) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { %out = pto.vmi.group_slot_load %src[%off], %stride - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto index d24f504e67..3f03f4669a 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -15,7 +15,6 @@ module { !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source - {vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16"} : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto similarity index 58% rename from test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto index f265dc0912..eec3c06d2a 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto @@ -6,23 +6,31 @@ // INCLUDING BUT NOT LIMITED TO 
NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid( - %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) { + func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> - "pto.vmi.unpack"(%narrow) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.truncf supports only -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( +// CHECK-COUNT-8: pto.vcvt +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast From bb88c2cc212712f7fd7e6c7830fc682b585deb61 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:10:43 +0800 Subject: [PATCH 17/31] Implement VMI layout optimization pipeline --- docs/designs/vmi-dialect-design.md | 9 +- docs/designs/vmi-implementation-manual.md | 112 +- .../vmi-layout-assignment-implementation.md | 398 +++++-- .../vmi-layout-assignment-lowering-design.md | 152 ++- docs/designs/vmi-layout-lowering-cases.md | 130 ++- include/PTO/Transforms/Passes.h | 4 + include/PTO/Transforms/Passes.td | 69 ++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 198 ++++ lib/PTO/Transforms/CMakeLists.txt | 5 + lib/PTO/Transforms/PTOValidateVMIIR.cpp | 245 +++- lib/PTO/Transforms/VMILayoutFoldConsumers.cpp | 134 +++ lib/PTO/Transforms/VMILayoutRematerialize.cpp | 172 +++ .../VMILayoutSinkMaterialization.cpp | 363 ++++++ lib/PTO/Transforms/VMILegalizeArithSelect.cpp | 88 ++ lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 1006 +++++++++++++++++ lib/PTO/Transforms/VMIToVPTO.cpp | 310 ++--- ...gnment_dense_store_group_slots_invalid.pto | 9 +- ...nment_group_load_block8_truncf_invalid.pto | 9 +- ...ut_assignment_group_reduce_s12_invalid.pto | 5 +- ...p_reduce_s32_tail_no_full_tile_invalid.pto | 5 +- .../vmi_layout_assignment_group_slot_load.pto | 5 +- ...lot_load_slots1_dynamic_stride_invalid.pto | 2 +- ...t_load_slots1_unaligned_stride_invalid.pto | 2 +- ...group_store_slots1_unit_stride_invalid.pto | 2 +- ...ment_packed_group_slots_truncf_invalid.pto | 9 +- .../vmi/vmi_layout_fold_consumers_deint4.pto | 90 ++ ...vmi_layout_fold_consumers_masked_store.pto | 57 + .../vmi/vmi_layout_fold_consumers_store.pto | 92 ++ ...ayout_gate_bitcast_group_slots_invalid.pto | 22 + ...vmi_layout_gate_bitcast_recipe_invalid.pto | 22 + .../vmi_layout_gate_extf_recipe_invalid.pto | 22 + ...ut_gate_group_broadcast_recipe_invalid.pto | 22 + ..._layout_gate_group_load_recipe_invalid.pto | 23 + 
...ayout_gate_group_reduce_recipe_invalid.pto | 25 + ...ate_group_reduce_slots1_recipe_invalid.pto | 25 + ...ut_gate_group_slot_load_recipe_invalid.pto | 23 + ..._group_slots_unsupported_slots_invalid.pto | 40 + ...layout_gate_group_store_recipe_invalid.pto | 24 + ...e_helper_materialization_shape_invalid.pto | 35 + .../vmi_layout_gate_helper_recipe_invalid.pto | 22 + .../vmi_layout_gate_store_recipe_invalid.pto | 37 + .../vmi_layout_gate_truncf_recipe_invalid.pto | 22 + .../lit/vmi/vmi_layout_rematerialize_data.pto | 49 + .../lit/vmi/vmi_layout_rematerialize_mask.pto | 55 + ...vmi_layout_sink_materialization_binary.pto | 202 ++++ .../vmi_layout_sink_materialization_mask.pto | 86 ++ test/lit/vmi/vmi_legalize_arith_select.pto | 47 + test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 23 + .../vmi/vmi_to_vpto_bitcast_deint_tail.pto | 33 + .../vmi_to_vpto_bitcast_footprint_invalid.pto | 23 + ...mi_to_vpto_bitcast_group_slots_invalid.pto | 23 + ...vpto_truncf_fp8_128_contiguous_invalid.pto | 4 +- tools/ptoas/ptoas.cpp | 14 + 53 files changed, 4175 insertions(+), 430 deletions(-) create mode 100644 include/PTO/Transforms/VMILocalRecipeRegistry.h create mode 100644 lib/PTO/Transforms/VMILayoutFoldConsumers.cpp create mode 100644 lib/PTO/Transforms/VMILayoutRematerialize.cpp create mode 100644 lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp create mode 100644 lib/PTO/Transforms/VMILegalizeArithSelect.cpp create mode 100644 lib/PTO/Transforms/VMILocalRecipeRegistry.cpp create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_deint4.pto create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_store.pto create mode 100644 test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto create mode 100644 
test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_rematerialize_data.pto create mode 100644 test/lit/vmi/vmi_layout_rematerialize_mask.pto create mode 100644 test/lit/vmi/vmi_layout_sink_materialization_binary.pto create mode 100644 test/lit/vmi/vmi_layout_sink_materialization_mask.pto create mode 100644 test/lit/vmi/vmi_legalize_arith_select.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md index 5578ca93d1..7569b787a0 100644 --- a/docs/designs/vmi-dialect-design.md +++ b/docs/designs/vmi-dialect-design.md @@ -1626,10 +1626,11 @@ lowering 不能因为 VPTO 有更快指令就加强或放松这些属性。比 不能拆成 `mulf + addf`,也不能把 `mulf + addf` 合成 `fma`;带 `nsw/nuw` 的 integer op 可以利用 flag 做优化,不带 flag 的 op 必须保持 wraparound/defined overflow 语义。 -`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 layout 下 bit grouping -physically adjacent、且每个对应 physical chunk 的 logical bit footprint 相同时才能 direct; -padding bits 
只能流向 result padding bits。否则需要 layout conversion、scratch materialization -或 target capability diagnostic。 +`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 contiguous/deinterleaved +layout 下 bit grouping physically adjacent、且每个对应 physical chunk 的 logical bit +footprint 相同时才能 direct;padding bits 只能流向 result padding bits。group_slots bitcast +暂不复用这个规则,必须等 slot-wise bitcast contract 定义清楚后再支持。否则需要 layout +conversion、scratch materialization 或 target capability diagnostic。 当前 VPTO direct lowering 对逐元素算术、逻辑、比较和 select 还有一条共同硬约束:物理 element width 必须能对应到 `pto.mask`。因此 VMI 语义层可以承载 `index` 或 `f64` diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index cd674db32a..04da993699 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -125,8 +125,17 @@ pipeline: ```text pto-validate-vmi-ir vmi-layout-assignment +canonicalize/cse +vmi-layout-fold-consumers +canonicalize/cse +vmi-layout-rematerialize +canonicalize/cse +vmi-layout-sink-materialization +canonicalize/cse +vmi-legalize-arith-select pto-validate-vmi-layout-ir vmi-to-vpto +canonicalize/cse ``` `--enable-vmi` requires `--pto-backend=vpto` or `pto.backend = "vpto"` because the pipeline produces physical VPTO @@ -145,6 +154,8 @@ vmi_ptoas_cli_pipeline.pto: --pto-backend=vpto + --enable-vmi lowers the VMI pipeline pto.backend = "vpto" also selects the VPTO-compatible path explicit --pto-backend=emitc with --enable-vmi is rejected + f16->f32 store lowers through the fold-consumers path, proving the driver + uses the optimized pipeline rather than only the hard skeleton vmi_ptoas_backend_required_invalid.pto: default emitc backend with --enable-vmi and no pto.backend = "vpto" is rejected @@ -155,8 +166,9 @@ vmi_ptoas_public_abi_invalid.pto / vmi_ptoas_public_result_abi_invalid.pto: ## MLIR Framework Usage -三个核心 pass 不应该用同一种 MLIR 机制硬套。这里先定义实现框架选择,避免后续把 layout -求解、结构化控制流改写和 1:N physicalization 混在一个 pattern pass 里。 +三个 correctness 
stage 和若干 layout optimization pass 不应该用同一种 MLIR 机制硬套。 +这里先定义实现框架选择,避免后续把 layout 求解、优化重写、结构化控制流改写和 1:N +physicalization 混在一个 pattern pass 里。 当前实现框架按下面的职责切开: @@ -168,6 +180,14 @@ vmi-layout-assignment: module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 +vmi-layout-fold-consumers / vmi-layout-rematerialize / vmi-layout-sink-materialization: + legal-to-legal VMI optimization passes。它们只消费 layout-assigned VMI IR,并继续产出 + layout-assigned VMI IR;所有新选择必须体现在 current op、type 或 helper IR 中。 + +vmi-legalize-arith-select: + canonicalize 之后的 hygiene pass。它把 scalar-condition arith.select with VMI result + 恢复成 VMI pipeline 可控的结构化控制流形态。 + vmi-to-vpto: MLIR OneToNTypeConversion。每个 layout-assigned VMI value 按统一 physical ordering 展开成多个 VPTO value,并依靠 OneToN structural patterns 重写函数、return、region result、block argument 和 @@ -178,7 +198,7 @@ vmi-to-vpto: 写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 `unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 -源码级实现应该进一步拆成五个独立层次: +源码级实现应该进一步拆成六个独立层次: ```text IR layer: @@ -201,6 +221,15 @@ Layout solving layer: 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, 然后把结果写回 type 或 ensure_* helper。 +Layout optimization layer: + lib/PTO/Transforms/VMILayoutFoldConsumers.cpp + lib/PTO/Transforms/VMILayoutRematerialize.cpp + lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + lib/PTO/Transforms/VMILegalizeArithSelect.cpp + + 负责在 layout-assigned VMI IR 内做 legal-to-legal 改写。它可以让公共 canonicalize/cse + 协助清理和合并 IR,但不能把决策藏到 side table 里。 + Physicalization layer: lib/PTO/Transforms/VMIToVPTO.cpp @@ -265,6 +294,8 @@ pass input output --------------------------- ---------------------------- ---------------------------- pto-validate-vmi-ir surface VMI IR same IR, or hard failure vmi-layout-assignment surface/layout-partial VMI layout-assigned VMI IR +layout optimization passes 
layout-assigned VMI IR layout-assigned VMI IR +vmi-legalize-arith-select layout-assigned VMI IR layout-assigned VMI IR pto-validate-vmi-layout-ir layout-assigned VMI IR same IR, or hard failure vmi-to-vpto layout-assigned VMI IR physical VPTO IR final residual verifier physical VPTO candidate no pto.vmi.*, no !pto.vmi.* @@ -314,6 +345,26 @@ lib/PTO/Transforms/VMILayoutAssignment.cpp hide chosen layout in a pass-private side table infer external VMI ABI +lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +lib/PTO/Transforms/VMILayoutRematerialize.cpp +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp +lib/PTO/Transforms/VMILegalizeArithSelect.cpp + pass: + VMILayoutFoldConsumersPass + VMILayoutRematerializePass + VMILayoutSinkMaterializationPass + VMILegalizeArithSelectPass + role: + legal-to-legal layout-assigned VMI optimization and hygiene + MLIR API: + Operation::walk for local discovery + OpBuilder/RewriterBase for explicit IR rewrites + canonicalize/cse between passes for cleanup and deduplication + must not: + introduce physical VPTO register types + require vmi-to-vpto to inspect producers, users, or CFG + preserve optimization decisions outside IR + lib/PTO/Transforms/VMIToVPTO.cpp pass: VMIToVPTOPass @@ -369,6 +420,15 @@ source file pass primary lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter +lib/PTO/Transforms/VMILayoutFoldConsumers.cpp + vmi-layout-fold-consumers Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutRematerialize.cpp + vmi-layout-rematerialize Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + vmi-layout-sink-materialization + Pattern-free local IR rewrite +lib/PTO/Transforms/VMILegalizeArithSelect.cpp + vmi-legalize-arith-select 
Operation::walk + OpBuilder rewrite lib/PTO/Transforms/VMIToVPTO.cpp vmi-to-vpto OneToNTypeConverter + OneToNOpConversionPattern ``` @@ -1108,14 +1168,24 @@ vmi-to-vpto: raw VMI producer -> pto-validate-vmi-ir -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold-consumers + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select -> pto-validate-vmi-layout-ir -> vmi-to-vpto + -> canonicalize/cse -> final residual verifier ``` -The `ptoas --enable-vmi` driver entry uses exactly this sequence before the existing VPTO backend pipeline. The -test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is wired -through the user-facing compiler driver. +The `ptoas --enable-vmi` driver entry uses this sequence before the existing VPTO backend pipeline. +The test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is +wired through the user-facing compiler driver. The optimization passes are legal-to-legal VMI rewrites; removing one +may affect quality or yield fewer optimized forms, but it must not make `vmi-to-vpto` recover hidden context. 各阶段之间只通过 IR 传递状态,不通过 pass-private side table 传递语义。也就是说: @@ -2415,11 +2485,14 @@ truncf f32 -> fp8-like: bitcast: source and result layouts must match source/result total logical bits must match - current implementation supports identical physical arity when every source/result - physical chunk carries the same number of logical bits. This covers full chunks - and partial/tail chunks such as 65xf32 -> 130xi16, where the second physical - chunk carries 32 logical bits on both sides. Partial/tail bitcast remains - unsupported if source padding bits would become result logical bits.
+ current implementation supports contiguous/deinterleaved layouts with identical + physical arity when every source/result physical chunk carries the same number + of logical bits. This covers full chunks and partial/tail chunks such as + 65xf32 -> 130xi16, where the second physical chunk carries 32 logical bits on + both sides, and uneven deinterleaved tails such as 129xf32 -> 129xi32. + Partial/tail bitcast remains unsupported if source padding bits would become + result logical bits. group_slots bitcast is unsupported until a slot-wise + bitcast contract is defined. load/tile_read: result layout chosen by consumers unless memory plan has a cheaper registered sink/source @@ -3141,10 +3214,12 @@ pto.vmi.truncf, direct path: pto.vmi.bitcast: for each physical part: emit pto.vbitcast(source_part) -> result_part_type - source/result layouts must match, physical arity must match, and every - corresponding physical chunk must carry the same number of logical bits. - Padding bits may map only to result padding bits; any shape where source - padding would become result logical data remains unsupported. + source/result layouts must match and must be contiguous/deinterleaved, + physical arity must match, and every corresponding physical chunk must carry + the same number of logical bits. Padding bits may map only to result padding + bits; any shape where source padding would become result logical data remains + unsupported. group_slots bitcast is rejected before vmi-to-vpto until it has + a slot-wise contract. pto.vmi.channel_split / pto.vmi.channel_merge: support 2-way and 4-way channel transforms for contiguous per-channel values @@ -3474,8 +3549,8 @@ Unsupported diagnostics: or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk unsupported pto.vmi.bitcast shape: - VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical arity and matching - per-chunk logical bit footprints (...) 
+ VMI-UNSUPPORTED: pto.vmi.bitcast requires matching non-group_slots source/result layouts with identical physical + arity and matching per-chunk logical bit footprints (...) unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: VMI-UNSUPPORTED: pto.vmi.channel_split supports only 2 or 4 channels @@ -4343,7 +4418,8 @@ use VMI-UNSUPPORTED in preflight: partial/tail memory access pred-only constant mask without concrete b8/b16/b32 granularity shuffle that requires vselr index-vector materialization - bitcast across partial physical chunks + bitcast with mismatched per-chunk logical bit footprints or group_slots + bitcast without a slot-wise contract use VMI-RESIDUAL-OP: conversion framework finished but VMI op/type/helper/cast remains. diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index f4c8f8487f..03f22ffd42 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -11,40 +11,149 @@ Recommended pass pipeline: ```text -pto-validate-vmi-surface - -> vmi-layout-assignment - -> pto-validate-vmi-layout +pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-fold-consumers // optional optimization + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir -> vmi-to-vpto -> canonicalize/cse -> existing VPTO lowering/codegen ``` +Only `vmi-layout-assignment` is required for the first legal implementation. +The optimization passes may be introduced one by one. Their contract is that +they consume legal layout-assigned VMI IR and produce legal layout-assigned VMI +IR; they never move a hidden decision into `vmi-to-vpto`. 
+ Pass responsibilities: ```text -pto-validate-vmi-surface: +pto-validate-vmi-ir: verify surface VMI has no physical VPTO layout dependency reject public/external VMI ABI unless explicitly enabled vmi-layout-assignment: - solve value layouts - choose selected lowering plans + solve hard value layout constraints + choose explicit layouts and local recipe carriers visible in IR insert ensure/rematerialization helpers make internal function boundary layouts explicit rewrite VMI types with layout attrs -pto-validate-vmi-layout: +canonicalize/cse: + remove dead helpers and merge identical cloned producers where MLIR legality + permits + +vmi-layout-fold-consumers: + fold use-site materialization into consumers that can directly consume the + source layout while preserving the same logical effect + example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become + a store of deinterleaved=2 when the store has a local vstsx2 INTLV recipe + current implementation: pto.vmi.store, pto.vmi.tile_write, and the value + operand of pto.vmi.masked_store when the existing mask arity matches, fed by + ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. factor=2 + uses the store's vstsx2 INTLV recipe; factor=4 is still store-local, but it + materializes through physical interleave before vsts. 
+ +vmi-layout-rematerialize: + replace explicit ensure_* helpers with cloned cheap layout-polymorphic + producers when the clone directly creates the requested result type + current implementation: splat pto.vmi.constant, pto.vmi.broadcast, + pto.vmi.iota, pto.vmi.create_mask, pto.vmi.create_group_mask, and + pto.vmi.constant_mask + not included in the first implementation: load, group_load, masked_load, + group_slot_load, and group_broadcast; those require separate memory, + execution-count, or source-layout proof before they can be rematerialized + +vmi-layout-sink-materialization: + move ensure_layout across pure layout-transparent elementwise chains when the + rewritten IR reduces materialization cost and keeps every op locally legal + current implementation: sink two identical operand ensure_layout helpers + across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, or one + source ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, + producing one result ensure_layout. It also sinks matching + ensure_mask_layout or ensure_mask_granularity helpers across + mask_and/mask_or/mask_xor/mask_not, producing one result mask helper. It + does not sink through select, fma, cast, load, store, reduce, + group_broadcast, or control-flow ops + +vmi-legalize-arith-select: + restore scalar-condition arith.select with VMI result type back to scf.if + after canonicalize; canonicalize may fold simple scf.if into arith.select, + but VMI values must not cross non-VMI semantic ops before vmi-to-vpto + +pto-validate-vmi-layout-ir: verify every VMI data/mask value has layout verify every VMI value has an assigned layout and every non-local lowering choice has been serialized explicitly - verify helper ops have registered materialization plans + verify helper ops have registered materialization recipes. 
Current + implementation checks `ensure_layout`, `ensure_mask_layout`, and + `ensure_mask_granularity` at the layout gate, so unsupported helper recipes + fail before `vmi-to-vpto`. It also checks the first semantic local-recipe + families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 + `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots + `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_addf`, + explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, + `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. vmi-to-vpto: use OneToN type conversion - lower only from explicit layout/plan information + lower only from current-op attrs/operands, operand/result layouts, and helper + ops emit VPTO or precise unsupported diagnostic ``` +### 1.1 Hard Constraints Versus Optimizations + +Hard legalization answers "can this program be lowered correctly?" It is +allowed to be conservative: + +```text +%w = pto.vmi.extf %a // natural layout deinterleaved=2 +%t1 = pto.vmi.mulf %w, %k1 // layout-transparent, stays deinterleaved=2 +%t1_c = pto.vmi.ensure_layout %t1 // hard store contract wants contiguous +pto.vmi.store %t1_c, %OUT1 +%w_c = pto.vmi.ensure_layout %w +pto.vmi.store %w_c, %OUT2 +``` + +This is a correct legal shape. The contiguous action is explicit at each store +use, and `vmi-to-vpto` lowers the helper with register materialization such as +`vintlv` before ordinary `vsts`. + +Optimization answers "can the same external effect be cheaper?" A fold pass +may rewrite the two store uses to consume the deinterleaved values directly: + +```text +pto.vmi.store %t1, %OUT1 // value type still says deinterleaved=2 +pto.vmi.store %w, %OUT2 +``` + +This optimized shape is legal only because `pto.vmi.store` has enough local +information to lower a `deinterleaved=2` f32 value to row-major memory, for +example with `vstsx2 INTLV_B32`. The optimization does not require +`vmi-to-vpto` to inspect `%w`'s producer or the sibling store. 
+ +The split gives later passes room to improve layout choices: + +```text +hard pass: + guarantee legality with explicit ensure_* helpers + +optimization passes: + remove, fold, clone, or sink helpers when the optimized IR is still locally + deterministic + +vmi-to-vpto: + physicalize exactly the IR it sees, with no global planning +``` + ## 2. Files To Add Or Update Expected implementation files: @@ -56,10 +165,10 @@ include/PTO/IR/VMIAttrs.td lib/PTO/IR/VMI.cpp include/PTO/Transforms/Passes.td -lib/PTO/Transforms/ValidateVMI.cpp +lib/PTO/Transforms/PTOValidateVMIIR.cpp lib/PTO/Transforms/VMILayoutAssignment.cpp lib/PTO/Transforms/VMIToVPTO.cpp -lib/PTO/Transforms/VMILayoutPlanRegistry.cpp +lib/PTO/Transforms/VMILocalRecipeRegistry.cpp test/lit/vmi/vmi_layout_assignment_*.pto test/lit/vmi/vmi_to_vpto_*.pto @@ -115,7 +224,7 @@ contiguous: deinterleaved: F > 1 B > 0 - direct full-chunk plans require N % (F * B) == 0 + direct full-chunk recipes require N % (F * B) == 0 group_slots: G > 0 @@ -188,9 +297,9 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_addf source/mask/result layouts, num_groups, reassoc group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths -ensure_layout always carries source/result layouts instead of plan -ensure_mask_layout always carries source/result layouts instead of plan -ensure_mask_granularity always carries source/result granularities instead of plan +ensure_layout always carries source/result layouts instead of recipe +ensure_mask_layout always carries source/result layouts instead of recipe +ensure_mask_granularity always carries source/result granularities instead of recipe ``` Layout/attr-only decisions today: @@ -212,8 +321,9 @@ vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. 
If a layout/attr-only op later gains a second legal recipe that cannot be distinguished from current-op information, that recipe must be represented by a new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. -Unsupported shapes that have no registered plan still diagnose through their -specific capability check rather than failing with a generic missing-plan error. +Unsupported shapes that have no registered recipe still diagnose through their +specific capability check rather than failing with a generic missing-recipe +error. ``` Examples of forbidden recovery in `vmi-to-vpto`: @@ -273,29 +383,31 @@ group_slot_load: loads one scalar per group and produces group_slots ``` -## 5. Plan Registry +## 5. Local Recipe Registry -Create one registry object shared by assignment and lowering. +Create one target-aware local recipe registry shared by assignment and lowering. +It is not serialized as a separate recipe-selection attribute. It answers local legality +questions from op kind, explicit attrs/operands, layouts, and target capability. 
```c++ -class VMILayoutPlanRegistry { +class VMILocalRecipeRegistry { public: - SmallVector getProducerPlans(Operation *op); - SmallVector getConsumerPlans(OpOperand &use); - SmallVector getTransferPlans(Operation *op); - FailureOr getMaterializationPlan(Type valueType, - VMILayoutKey from, - VMILayoutKey to); + SmallVector getProducerRecipes(Operation *op); + SmallVector getConsumerRecipes(OpOperand &use); + SmallVector getTransferRecipes(Operation *op); + FailureOr + getMaterializationRecipe(Type valueType, VMILayoutKey from, + VMILayoutKey to); bool isCheaplyRematerializable(Operation *op); - bool hasTargetCapability(PlanID plan) const; + bool hasTargetCapability(RecipeID recipe) const; }; ``` -Plan record: +Recipe record: ```c++ -struct VMILayoutPlan { - PlanID id; +struct VMILayoutRecipe { + RecipeID id; SmallVector operandLayouts; SmallVector resultLayouts; int64_t cost; @@ -315,6 +427,69 @@ enablePublicVMIABI diagnosticVerbosity ``` +Assignment and optimization passes may query the registry to decide which IR +shape to produce. `vmi-to-vpto` may query the same registry to verify the +current op is locally lowerable. If the same op, attrs, operands, and +operand/result layouts could map to two different physical recipes with +different observable preconditions, the IR is under-specified; add an explicit +attr, operand, helper op, or distinct VMI semantic op before implementing that +recipe. 
+ +Current implementation status: `VMILocalRecipeRegistry` exists and currently +owns nine local recipe families: + +```text +contiguous store/tile_write consumer recipes: + contiguous vsts + deinterleaved=2 vstsx2 INTLV + deinterleaved=4 materialize-then-vsts + +helper materialization recipes: + data/mask layout identity + data/mask contiguous <-> deinterleaved=2/4 when source/result physical + arity matches and the physical part shape can be materialized + mask granularity identity or b8/b16/b32 predicate cast + +group_slot_load semantic recipes: + slots=8 unit-stride vsldb + slots=1 aligned lane-0 vsldb per group + +block8 group_load semantic recipes: + S=16 deinterleaved=2, block_elems=8 vsldb per row fragment + S=32 deinterleaved=4, block_elems=8 vsldb per row fragment + +group_slots group_store semantic recipes: + slots=8 unit-stride vsts + slots=1 aligned lane-0 vsts per group + +group_slots group_reduce_addf semantic recipes: + S=8 vcgadd + S=16 deinterleaved=2 vcgadd+vadd + S=32 deinterleaved=4 vcgadd+vadd tree + S=64 contiguous slots=1 vcadd/vadd/vsel row-local reduction + +explicit-slots group_broadcast semantic recipes: + slots=8/slots=1 vselr materialization to contiguous or supported + deinterleaved result layouts + +extf/truncf semantic recipes: + contiguous f16/bf16 -> deinterleaved=2 f32 + contiguous f8-like -> deinterleaved=4 f32 + deinterleaved=2 f32 -> contiguous f16 + deinterleaved=4 f32 -> contiguous f8-like + group_slots(G, slots=1) f32 -> f16 + +bitcast semantic recipes: + per-part vbitcast for contiguous/deinterleaved layouts when source/result + layouts match, physical arity matches, and every physical chunk carries the + same logical bit footprint; this does not require each deinterleaved part to + contain the same number of chunks. group_slots bitcast is unsupported until a + slot-wise bitcast contract is defined. 
+``` + +`vmi-layout-fold-consumers`, `pto-validate-vmi-layout-ir`, and `vmi-to-vpto` +query this registry for the decisions implemented above. + ## 6. Layout Assignment Data Model ### 6.1 Solver State @@ -331,14 +506,14 @@ struct ValueLayoutState { struct UseRequest { OpOperand *operand; VMILayoutKey requestedLayout; - PlanID requestingPlan; + RecipeID requestingRecipe; bool hard; }; -struct OpPlanState { +struct OpRecipeState { Operation *op; - SmallVector candidates; - std::optional chosen; + SmallVector candidates; + std::optional chosen; }; ``` @@ -350,7 +525,7 @@ Walk the module and collect: 1. every VMI value 2. every VMI block argument 3. every VMI function argument/result -4. every VMI op with candidate plans +4. every VMI op with candidate local recipes 5. every branch/yield/call/return edge carrying VMI ``` @@ -455,11 +630,11 @@ compact S=12 logical S=16: ### 6.3.1 Request Builders Implement request generation as small per-op builders. The builders produce -candidate plans and use-site requests; they do not rewrite IR. +candidate recipes and use-site requests; they do not rewrite IR. 
```text buildStoreRequests: - ordinary store -> dense contiguous request unless a layout-aware store plan is + ordinary store -> dense contiguous request unless a layout-aware store recipe is selected group_store -> group_slots(G,K) request plus stride/alignment capability checks @@ -469,8 +644,8 @@ buildCastRequests: extf f8->f32 -> source contiguous, result deinterleaved=4 truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous - group_slots slots=1 f32->f16 -> slot-preserving plan - group_slots slots=8 width-changing cast -> diagnostic unless a packed plan + group_slots slots=1 f32->f16 -> slot-preserving recipe + group_slots slots=8 width-changing cast -> diagnostic unless a packed recipe exists buildGroupReduceRequests: @@ -481,11 +656,11 @@ buildGroupReduceRequests: S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, group_slots(G,8) result S=64 -> contiguous source, group_slots(G,1) result - other S -> diagnostic unless an explicit fallback plan is enabled + other S -> diagnostic unless an explicit fallback recipe is enabled buildGroupMemoryRequests: - group_load S=16/S=32 with aligned constant stride -> block_elems=8 plan - group_load row-local full chunks -> contiguous plan + group_load S=16/S=32 with aligned constant stride -> block_elems=8 recipe + group_load row-local full chunks -> contiguous recipe group_slot_load unit stride -> group_slots(G,8) group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic @@ -513,7 +688,7 @@ buildFunctionBoundaryRequests: private/internal function argument/result layouts are specialized or materialized with callee-entry/return-site helpers public/external VMI arguments/results diagnose unless enablePublicVMIABI has - a real ABI plan + a real ABI recipe ``` Request builders must record the requesting op. 
Diagnostics and inserted @@ -534,7 +709,7 @@ cheap rematerializable producers: create_group_mask group_broadcast group_slot_load when the same address/no-alias/proof conditions as load hold - and the selected memory plan is legal at the clone site + and the memory recipe is legal at the clone site layout-transparent producers: add/sub/mul/fma/min/max/neg/abs @@ -543,10 +718,10 @@ layout-transparent producers: integer bitwise and shift ops fixed-layout producers: - extf/truncf physical conversion plans - group_load block-fragment plans + extf/truncf physical conversion recipes + group_load block-fragment recipes group_reduce result group_slots - masked_load when the physical memory-safety proof fixes a full-read plan + masked_load when the physical memory-safety proof fixes a full-read recipe ``` Conflict policy: @@ -568,17 +743,17 @@ This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to `deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 `group_reduce`. It also keeps case 3.19.2 diagnostic: a strided `group_load` that selected `block_elems=8` is fixed unless a block8-to-parity -materialization or rematerialized memory plan is registered. +materialization or rematerialized memory recipe is registered. ### 6.4 Solving And Rewriting Algorithm: ```text -1. Pick candidate plan sets for every op. +1. Pick candidate recipe sets for every op. 2. Propagate hard constraints through SCCs. 3. Resolve transfer-equivalent dense values. -4. Choose multi-plan ops by cost: +4. Choose multi-recipe ops by cost: - S=16 parity vs block8 - load memory-fused vs load+materialize - group_slot_load slots=8 vs slots=1 @@ -596,7 +771,7 @@ Rewrite invariants: No VMI data/mask value after assignment has a null layout. Any non-local choice is represented by op attrs, operand/result layouts, a helper op, a clone, or an explicit diagnostic. -Every ensure_* helper has a registered materialization plan. 
+Every ensure_* helper has a registered materialization recipe. Every function/call signature carrying VMI is specialized or diagnosed. ``` @@ -689,12 +864,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous plan -3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 plan -3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 plan -3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local plan +3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous recipe +3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 recipe +3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 recipe +3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local recipe 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks -3.19.1 S=16 block_elems choice buildGroupReduceRequests selected block_elems reduce plan +3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks 3.26 grouped tail buildMaskRequests split grouped masks 3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values @@ -708,12 +883,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load plan +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load recipe 3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan 3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan 3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan 3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan -3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load plan +3.28 slots=1 non-unit load 
buildGroupMemoryRequests row-local group_slot_load recipe 3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof 3.39 strided load fanout conflict resolver preserving layout or materialization @@ -762,17 +937,17 @@ vmi-to-vpto contract: ```text diagnostic family builder / owner required failure -3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store plan +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store recipe 3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast 3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather -3.13 slots=8 width cast buildCastRequests no packed slot cast plan -3.14 unsupported group size buildGroupReduceRequests no registered reduce plan +3.13 slots=8 width cast buildCastRequests no packed slot cast recipe +3.14 unsupported group size buildGroupReduceRequests no registered reduce recipe 3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan -3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load plan +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load recipe 3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan 3.19.2 invalid block_elems use conflict resolver no preserving materialization 3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI -3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback plan +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback recipe 3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback vmi-to-vpto contract: @@ -970,7 +1145,7 @@ group_reduce_addf: VCGADDs plus a PAT_VL8 VADD tree per packed result block. S=64 row-local assignment uses #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit - slots=1 generic VCADD row-local path is selected locally. 
+ slots=1 generic VCADD row-local path is registered and selected locally. group_broadcast: explicit slots=8/1 source layouts select @@ -991,8 +1166,10 @@ group_load: contiguous full-chunk path is selected from a contiguous result layout. S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one - vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride - vldsx2/BDINTLV recipe remains a local peephole target. + vsldb per 32B row fragment and physical chunk. The explicit block8 recipe + is registered and checked by pto-validate-vmi-layout-ir before vmi-to-vpto. + The dedicated S=16 unit-stride vldsx2/BDINTLV recipe remains a local + peephole target. S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by vmi-layout-assignment because the stable gather fallback is not implemented. @@ -1010,12 +1187,12 @@ group_store: multiple of the 32B store alignment in destination elements: 8 for f32, 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only the first row-local store is 32B-aligned; later `group_off + r` stores are - 4B apart. A future pack-to-slots=8 or unaligned-store plan is required before + 4B apart. A future pack-to-slots=8 or unaligned-store recipe is required before contiguous `%c1` slots=1 group_store can be accepted. Packed group_slots(G, slots=8) group_store is implemented only when num_groups is a multiple of 8 and row_stride is constant 1; it emits one PAT_VL8 store per packed slot block. Non-unit packed group stores remain a - design target unless a strided packed-lane store plan is selected explicitly. + design target unless a strided packed-lane store recipe is made explicit. ``` Examples: @@ -1065,7 +1242,7 @@ After assignment: Every VMI value has layout. Every VMI mask has layout and granularity plan. Every lowering choice is locally deterministic or explicit in attrs/layouts. -Every ensure_* helper has a materialization plan. 
+Every ensure_* helper has a materialization recipe. Every control-flow edge has matching VMI layouts. ``` @@ -1143,8 +1320,8 @@ VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is fixed to #pto.vmi.layout by the selected - strided group_load plan. Register a rematerialization or preserving - materialization plan, or avoid consuming this block-loaded value with truncf. + strided group_load recipe. Register a rematerialization or preserving + materialization recipe, or avoid consuming this block-loaded value with truncf. ``` ## 11. Test And Simulator Acceptance @@ -1212,13 +1389,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-selected-plan-gate CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-layout-gate CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-selected-plan-gate/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-layout-gate/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-selected-plan-gate.log + .tmp/vmi-runtime-batch-layout-gate.log result: no matches ``` @@ -1298,7 +1475,7 @@ repository evidence: all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 314/314 + latest full VMI lit sweep passed: 340/340 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1656,6 +1833,65 @@ runtime SIM: test/vpto/cases/vmi/widen-f16-to-f32-store-reduce ``` +Current checked-in lit coverage for the first `vmi-layout-fold-consumers` +optimization is: + +```text +test/lit/vmi/vmi_layout_fold_consumers_store.pto +test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto +test/lit/vmi/vmi_layout_fold_consumers_deint4.pto +``` + +Current checked-in lit coverage for the first 
`vmi-layout-rematerialize` +optimization is: + +```text +test/lit/vmi/vmi_layout_rematerialize_data.pto +test/lit/vmi/vmi_layout_rematerialize_mask.pto +``` + +Current checked-in lit coverage for the first +`vmi-layout-sink-materialization` optimization is: + +```text +test/lit/vmi/vmi_layout_sink_materialization_binary.pto +test/lit/vmi/vmi_layout_sink_materialization_mask.pto +``` + +Current checked-in lit coverage for canonicalized VMI control-flow restoration is: + +```text +test/lit/vmi/vmi_legalize_arith_select.pto +test/lit/vmi/vmi_ptoas_cli_control_flow.pto +``` + +Current checked-in lit coverage for the first semantic local-recipe layout gate +is: + +```text +test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +``` + +Current checked-in direct `vmi-to-vpto` preflight coverage for bitcast local +recipes is: + +```text +test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto +test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto +``` + Current checked-in coverage for 3.32 f32 feeding f8 store and S=32 reduce: ```text @@ -1710,7 +1946,7 @@ Diagnostic-only cases: 3.16.1 group_slot_load slots=8 non-unit stride 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 
f32 elements -3.19.2 block_elems=8 value consumed by truncf without materialization plan +3.19.2 block_elems=8 value consumed by truncf without materialization recipe 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -1729,11 +1965,17 @@ entries: ```text lit: + test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto + test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto test/lit/vmi/vmi_ptoas_public_abi_invalid.pto test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto @@ -1795,7 +2037,7 @@ group_store ```text 3.8 cast commute through group_broadcast 3.18 dense/group-reduce multi-consumer -3.19 block_elems plan selection +3.19 block_elems recipe selection 3.23 group_broadcast multi-consumer 3.32 f32 feeding f8 store and S=32 reduce 3.33 S=16/S=32 reduce multi-consumer rematerialization @@ -1847,7 +2089,7 @@ Current evidence for the case-catalog objective: 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 314/314 +5. the latest full VMI lit sweep passed: 340/340 6. 
every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index b30c0c3472..4c13b07ef8 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -8,8 +8,20 @@ ```text VMI surface IR - -> vmi-layout-assignment - -> layout-assigned VMI IR + -> pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-fold-consumers // optional optimization + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> optional later layout optimization passes + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> layout-assigned and optimized VMI IR -> vmi-to-vpto -> VPTO IR ``` @@ -20,15 +32,59 @@ VMI surface IR vmi-to-vpto 不允许通过上下文猜 lowering。 任何需要 producer/consumer/control-flow/memory/mask 上下文才能决定的事, -必须在 vmi-layout-assignment 阶段变成显式 IR 信息: +必须在 vmi-layout-assignment 或后续 VMI layout optimization 阶段变成显式 IR: 1. vmi.vreg/vmi.mask 的 layout -2. op 的 selected lowering plan -3. use-site ensure_layout / ensure_mask_layout -4. rematerialized producer +2. current-op attrs/operands that make the local recipe deterministic +3. use-site ensure_layout / ensure_mask_layout / ensure_mask_granularity +4. rematerialized or cloned producer 5. target capability diagnostic ``` +## 0. Hard Legalization And Optimization Boundary + +Layout assignment is a stage, not necessarily one monolithic pass. 
The design +separates correctness from optimization: + +```text +hard legalization: + produces legal layout-assigned VMI IR for all supported semantics + inserts conservative ensure_* helpers at incompatible uses + may choose a simple canonical layout even when a fused consumer recipe exists + must diagnose unsupported semantics before vmi-to-vpto has to guess + +layout optimization: + rewrites already legal VMI IR into cheaper but equivalent VMI IR + may fold ensure_layout into a layout-aware consumer + may clone/rematerialize cheap producers for different use-site layouts + may sink or hoist layout materialization through pure elementwise chains + may specialize private VMI function signatures +``` + +The driver currently runs MLIR's normal `canonicalize` and `cse` between these +VMI-specific passes. They are allowed to clean up trivially unused helpers, +merge identical rematerialized producers, and expose simpler use-def shapes. +They are not a source of hidden lowering information; after every optimization, +the IR must still carry enough local information for `vmi-to-vpto`. + +The baseline hard pass may emit: + +```text +%x_c = pto.vmi.ensure_layout %x : deinterleaved=2 -> contiguous +pto.vmi.store %x_c +``` + +A later optimization may replace that use with: + +```text +pto.vmi.store %x : deinterleaved=2 +``` + +only if the store op itself has a local deterministic recipe for preserving the +same row-major memory effect, such as a layout-aware `vstsx2 INTLV` lowering. +Both forms are semantically complete. The second form is an optimization, not +a hard requirement for correctness. + ## 1. 
Source Case Coverage 设计必须覆盖 case catalog 中的端到端场景: @@ -55,7 +111,7 @@ layout conflict: one scalar broadcast rematerialized for dense and grouped users one non-rematerializable value materialized with use-site ensure_layout one scalar group-slot source rematerialized as slots=8 and slots=1 - S=16 block_elems=1/8 plan selection + S=16 block_elems=1/8 recipe selection dense consumer of group_slots diagnostic packed group-slot width-changing cast diagnostic S=64 slots=1 group-slot width-changing cast @@ -122,7 +178,7 @@ memory legality: ``` No extra layout kind should be added unless a new case proves that the existing -layouts and plans cannot express the logical behavior. The remaining open +layouts and recipes cannot express the logical behavior. The remaining open items are not missing layout semantics: ```text @@ -203,14 +259,14 @@ slot_lane(g) = g % K All non-slot lanes are undefined and may only be read by group-aware operations. Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. -`K` is selected by the lowering plan: +`K` is selected by the producer/consumer local recipe: ```text S=8/16/32 packed VCG result -> slots=8 S=64 row-local result -> slots=1 ``` -## 3. Lowering Context Must Become Assignment Output +## 3. Lowering Context Must Become Explicit IR Output `vmi-to-vpto` may inspect only: @@ -233,7 +289,7 @@ It must not: 6. specialize function signatures during vmi-to-vpto ``` -Any of those decisions belongs to `vmi-layout-assignment`. +Any of those decisions belongs to the layout stage before `vmi-to-vpto`. ## 4. Explicit Assignment Products @@ -323,7 +379,7 @@ group_store: masked_load: explicit passthrough, mask layout, full physical read, shaped safe-tail memref, or an explicit diagnostic decide legality. A future stable gather fallback - must be selected by assignment before vmi-to-vpto lowers it. + must be made explicit by assignment before vmi-to-vpto lowers it. 
masked_store/select/elementwise: operand/result layouts and explicit mask granularity decide the lowering. @@ -350,40 +406,48 @@ If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, assigned layouts, and the missing decision class. -## 5. Plan Registry +## 5. Local Recipe Registry + +The compiler owns a target-aware local recipe registry. Layout assignment and +layout optimization query this registry to decide which explicit IR shape to +produce. `vmi-to-vpto` queries the same registry only to verify and lower the +current op from local information. -The compiler owns a target-aware plan registry. Layout assignment queries this -registry; vmi-to-vpto verifies and consumes the chosen plan. +The registry is not serialized as a separate recipe-selection attribute. If +two legal physical recipes cannot be distinguished by the current op's name, +attrs, operands, operand/result layouts, helper ops, and target options, the +VMI IR is missing a carrier. Add an explicit attr, operand, helper op, or +semantic op before implementing that recipe. 
-### 5.1 Plan Kinds +### 5.1 Recipe Kinds ```text -ProducerPlan: +ProducerRecipe: op can produce result layout L example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv -ConsumerPlan: +ConsumerRecipe: op can consume operand layout L example: group_reduce S=32 consumes deinterleaved=4 -TransferPlan: +TransferRecipe: op ties operand/result layouts example: addf requires same dense layout for operands/result -MaterializationPlan: +MaterializationRecipe: layout A -> layout B without changing logical value example: deinterleaved=4 -> contiguous by vintlv tree -RematerializationPlan: +RematerializationRecipe: cheap producer can be cloned for a use-site layout example: broadcast/create_mask/group_broadcast -DiagnosticPlan: +DiagnosticRecipe: known unsupported semantic/capability boundary example: compact S=12 requires gather materialization ``` -### 5.2 Dense Plans From Cases +### 5.2 Dense Recipes From Cases ```text f16 -> f32: @@ -413,7 +477,7 @@ load: layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 ``` -### 5.3 Group Plans From Cases +### 5.3 Group Recipes From Cases ```text group_reduce f32 S=8: @@ -449,11 +513,11 @@ group_store: group_slot_cast f32 -> f16: slots=1 row-local source/result is legal with group_slot_cast_slots1_f32_to_f16 - slots=8 packed source is illegal unless a packed slot-preserving plan is + slots=8 packed source is illegal unless a packed slot-preserving recipe is registered ``` -### 5.4 Tail And Memory Safety Plans +### 5.4 Tail And Memory Safety Recipes Mask semantics and memory legality are separate: @@ -504,7 +568,7 @@ new catalog case or a proof that it is equivalent to one listed here. 
dense store: requests dense contiguous source if source is deinterleaved, assignment must insert ensure_layout or select a - store plan such as vstsx2 that consumes the assigned layout explicitly + store recipe such as vstsx2 that consumes the assigned layout explicitly truncf f32 -> f16: requests source deinterleaved=2, block_elems=1 @@ -556,7 +620,7 @@ group_slot_load: group_load: requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block - fragment plans, or contiguous for row-local full-chunk plans + fragment recipes, or contiguous for row-local full-chunk recipes masked_load: requests result layout from its consumers @@ -564,7 +628,7 @@ masked_load: requires explicit passthrough; padding is not synthesized masked_store: - requests dense source layout selected by the store plan + requests dense source layout selected by the store recipe requests mask layout matching the source layout and store element granularity does not choose memory safety for an earlier load @@ -584,9 +648,9 @@ Important negative requests: ```text ordinary dense add/mul/store/truncf cannot request group_slots packed group_slots(slots=8) cannot request width-changing cast unless a packed -slot-preserving cast plan is registered +slot-preserving cast recipe is registered slots=1 group_store cannot request unit-stride row-major output until a pack or -unaligned-store plan exists +unaligned-store recipe exists ``` ### 5.6 Conflict Resolution Matrix @@ -617,13 +681,13 @@ control-flow join: private function boundary: specialize or materialize at call/callee-entry before vmi-to-vpto -no clone/materialization/specialization plan: +no clone/materialization/specialization recipe: emit a diagnostic naming the requesting op and both layouts ``` The cost model may choose between legal rows only when the observable contract is identical. 
For example, S=16 `block_elems=1` and `block_elems=8` are both -valid reduce inputs, but `block_elems=8` is selected only when a producer plan +valid reduce inputs, but `block_elems=8` is selected only when a producer recipe such as strided `group_load` naturally creates 32B row fragments or when cost proves it cheaper without breaking another consumer such as `truncf`. @@ -648,7 +712,7 @@ Create a use-site request for: ```text 1. every operand use that requires a specific layout 2. every control-flow yield/branch/call/return edge -3. every memory operation that requires a memory legality plan +3. every memory operation that requires a memory legality recipe ``` ### 6.2 Constraints @@ -657,14 +721,14 @@ Hard constraints: ```text group_slots cannot feed ordinary dense consumers -direct group-slot width-changing cast requires a slot-preserving plan +direct group-slot width-changing cast requires a slot-preserving recipe public/external VMI function boundary requires a stable ABI or diagnostic S=32 fast tail load requires full_tile_readable or gather fallback ``` -`slots = 1` row-local cast may satisfy the slot-preserving plan requirement. +`slots = 1` row-local cast may satisfy the slot-preserving recipe requirement. Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast -or unpack/materialization plan is registered. +or unpack/materialization recipe is registered. 
Equivalence constraints: @@ -686,11 +750,11 @@ S=16 group_reduce: one dense value feeding S=16 and S=32 group_reduce: rematerialize a cheap producer per consumer layout, or insert an explicit - materialization plan; the final lowering pass must not pick one layout after + materialization recipe; the final lowering pass must not pick one layout after seeing both users load/group_load: - choose memory plan and result layout together + choose memory recipe and result layout together group_broadcast: rematerialize per dense consumer layout @@ -702,10 +766,10 @@ Recommended solving order: ```text 1. Build function/control-flow SCCs. -2. Collect candidate plans for every op. +2. Collect candidate recipes for every op. 3. Propagate hard required layouts from consumers. 4. Propagate producer natural layouts where they are unique. -5. Resolve multi-plan ops by cost. +5. Resolve multi-recipe ops by cost. 6. Insert use-site materialization where a value has multiple incompatible uses. 7. Rematerialize cheap producers instead of materializing when cheaper. 8. Specialize internal function signatures. @@ -716,10 +780,10 @@ Recommended solving order: Tie-breaking must be deterministic. Suggested priority: ```text -1. Avoid unsupported plans. +1. Avoid unsupported recipes. 2. Prefer rematerializing cheap producers over register materialization. 3. Prefer layouts accepted by all consumers without conversion. -4. Prefer memory-fused layout plans over load + register rearrange. +4. Prefer memory-fused layout recipes over load + register rearrange. 5. Prefer fewer VPTO instructions. 6. Prefer contiguous only when cost ties and no consumer requests a special layout. 
``` @@ -806,7 +870,7 @@ current VMI op body/attrs: helper materialization chain: allowed only to strip ensure_mask_layout / ensure_mask_granularity for - static predicate analysis that does not choose a different layout or plan + static predicate analysis that does not choose a different layout or recipe diagnostic embellishment: allowed only to improve an already-failed capability message, such as naming diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 262299b3a3..e084ad58c0 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -83,7 +83,7 @@ G % K == 0 K must fit in the physical vreg element count ``` -`K` is selected by the producer/consumer plan. It is not always 8. For +`K` is selected by the producer/consumer local recipe. It is not always 8. For `VCGADD`-packed results, `K = 8` matches the eight 32B block results written to the low lanes of one destination vreg. For row-local reductions where each logical group already occupies one full 256B vreg, `K = 1` keeps each group's @@ -99,9 +99,9 @@ physical slot block slot_block(g), lane slot_lane(g) All other lanes are undefined for ordinary VMI consumers. They may only be read by group-aware ops that define how to interpret group slots. -## 2. Plan Selection Rules +## 2. Recipe Selection Rules -VMI cast ops must not hard-code one physical `vcvt` plan as their semantic +VMI cast ops must not hard-code one physical `vcvt` recipe as their semantic layout rule. ```text @@ -112,7 +112,7 @@ dense cast: group-slot cast: source/result are both group_slots(G,K). lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are - legal only when a slot-preserving VPTO plan is registered, or when the cast + legal only when a slot-preserving VPTO recipe is registered, or when the cast can be commuted through a later group-aware consumer such as group_broadcast. 
``` @@ -171,7 +171,7 @@ the immediately following complete endpoints. 3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization -3.19 S=16 reduce block_elems plan selection complete/diagnostic +3.19 S=16 reduce block_elems recipe selection complete/diagnostic 3.20 group_slots control-flow join complete 3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete @@ -198,6 +198,7 @@ the immediately following complete endpoints. 3.43 internal function argument boundary materialization complete 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete +3.46 extf value and derived elemwise value both stored complete/optimization ``` ### 3.1 `f16 -> f32 -> store` @@ -2561,7 +2562,7 @@ VMI-LAYOUT-CONTRACT: use site. ``` -### 3.19 S=16 Reduce `block_elems` Plan Selection +### 3.19 S=16 Reduce `block_elems` Recipe Selection S=16 f32 group reduction has two legal dense input layouts: @@ -5349,3 +5350,120 @@ The runtime case passes `active_cols` as a kernel scalar argument and casts it to `index` inside `pto.vecscope`. This keeps scalar materialization outside `vmi-to-vpto`; the lowering pass only consumes the current `create_group_mask` operand. + +### 3.46 `extf` Value And Derived Elementwise Value Both Stored + +This case fixes where contiguous materialization belongs when one widened value +is used directly by a store and also by a layout-transparent elementwise chain +that is stored. 
+ +VMI input: + +```text +%a = pto.vmi.load %in[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%k = pto.vmi.broadcast %k1 + : f32 -> !pto.vmi.vreg<128xf32> + +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%t1 = pto.vmi.mulf %w, %k + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Hard-legalized assigned layouts: + +```text +%a: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%w, %k, %t1: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%t1_c = pto.vmi.ensure_layout %t1: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %t1_c, %out1[%off] + +%w_c = pto.vmi.ensure_layout %w: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %w_c, %out2[%off] +``` + +Baseline VPTO lowering result: + +```text +%a0 = pto.vlds %in[%off] {dist = "NORM"} + : !pto.ptr, index -> !pto.vreg<128xf16> + +%w_p0 = pto.vcvt %a0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%w_p1 = pto.vcvt %a0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%k_p0 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> +%k_p1 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%t1_p0 = pto.vmul %w_p0, %k_p0, PAT_ALL_B32 : !pto.vreg<64xf32> +%t1_p1 = pto.vmul %w_p1, %k_p1, PAT_ALL_B32 : !pto.vreg<64xf32> + +// ensure_layout for the first store. +%t1_0, %t1_1 = pto.vintlv %t1_p0, %t1_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %t1_0, %out1[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %t1_1, %out1[%off_plus_64], %all_b32 {dist = "NORM_B32"} + +// ensure_layout for the second store. 
+%w_0, %w_1 = pto.vintlv %w_p0, %w_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %w_0, %out2[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %w_1, %out2[%off_plus_64], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for i = 0..127: + out1[off + i] = f32(in[off + i]) * k1 + out2[off + i] = f32(in[off + i]) +``` + +Optimization pass result: + +```text +// vmi-layout-fold-consumers may remove both ensure_layout ops if the target +// supports a store recipe that consumes deinterleaved=2 and writes contiguous +// row-major memory. +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Optimized VPTO lowering result: + +```text +pto.vstsx2 %t1_p0, %t1_p1, %out1[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + +pto.vstsx2 %w_p0, %w_p1, %out2[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Required assignment and optimization rule: + +```text +Hard legalization may always preserve `%w` and `%t1` in deinterleaved=2 and +insert use-site ensure_layout before ordinary stores. This is correct because +the layout change is explicit at the store use. + +Consumer folding is optional. It may remove the ensure_layout only when the +store itself can locally prove the same contiguous memory effect from the +source layout. vmi-to-vpto must not scan the `%w` producer or both store users +to decide this. 
+``` diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 15a247594b..902d8a2499 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -107,6 +107,10 @@ LogicalResult validateVMILayoutAssignedIR(ModuleOp module, std::unique_ptr<Pass> createPTOValidateVMIIRPass(); std::unique_ptr<Pass> createPTOValidateVMILayoutIRPass(); std::unique_ptr<Pass> createVMILayoutAssignmentPass(); +std::unique_ptr<Pass> createVMILayoutFoldConsumersPass(); +std::unique_ptr<Pass> createVMILayoutRematerializePass(); +std::unique_ptr<Pass> createVMILayoutSinkMaterializationPass(); +std::unique_ptr<Pass> createVMILegalizeArithSelectPass(); std::unique_ptr<Pass> createVMIToVPTOPass(); std::unique_ptr<Pass> createExpandTileOpPass(); std::unique_ptr<Pass> createExpandTileOpPass(const ExpandTileOpOptions &options); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 3047197d57..91dc8bfc83 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -681,6 +681,75 @@ def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { "mlir::scf::SCFDialect"]; } +def VMILayoutFoldConsumers : Pass<"vmi-layout-fold-consumers", "ModuleOp"> { + let summary = "Fold VMI layout materialization into layout-aware consumers"; + let description = [{ + Optimizes legal layout-assigned VMI IR by replacing selected use-site + ensure_layout consumers with consumers that can directly lower from the + source layout while preserving the same logical effect. The pass does not + choose layouts by inspecting producer/user context for vmi-to-vpto; it only + rewrites explicit helper IR into an equivalent local-consumer form.
+ }]; + let constructor = "mlir::pto::createVMILayoutFoldConsumersPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutRematerialize : Pass<"vmi-layout-rematerialize", "ModuleOp"> { + let summary = "Rematerialize cheap VMI producers at layout helpers"; + let description = [{ + Optimizes legal layout-assigned VMI IR by replacing selected ensure_layout, + ensure_mask_layout, and ensure_mask_granularity helpers with cloned cheap + producers that directly create the requested result type. The pass is + deliberately limited to pure construction ops, so memory, control-flow, and + mask-tail proofs remain explicit in the IR. + }]; + let constructor = "mlir::pto::createVMILayoutRematerializePass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutSinkMaterialization + : Pass<"vmi-layout-sink-materialization", "ModuleOp"> { + let summary = "Sink VMI layout materialization through transfer ops"; + let description = [{ + Optimizes legal layout-assigned VMI IR by moving matching operand + ensure_layout helpers across pure layout-transparent elementwise operations. + The rewritten IR keeps the layout conversion explicit as a result + ensure_layout, so vmi-to-vpto still lowers from local op information only. 
+ }]; + let constructor = "mlir::pto::createVMILayoutSinkMaterializationPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILegalizeArithSelect : Pass<"vmi-legalize-arith-select", "ModuleOp"> { + let summary = "Legalize canonical arith.select over VMI values"; + let description = [{ + Rewrites scalar-condition arith.select operations that produce VMI values + back to scf.if. MLIR canonicalization may fold simple scf.if regions into + arith.select, but VMI values must not cross non-VMI semantic ops before + vmi-to-vpto. This pass restores an explicit structural control-flow form + that the VMI converter already handles. + }]; + let constructor = "mlir::pto::createVMILegalizeArithSelectPass()"; + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + def VMIToVPTO : Pass<"vmi-to-vpto", "ModuleOp"> { let summary = "Convert layout-assigned VMI IR to physical VPTO IR"; let description = [{ diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h new file mode 100644 index 0000000000..7356be9e92 --- /dev/null +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -0,0 +1,198 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILocalRecipeRegistry.h - VMI local recipe queries ------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H +#define PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir::pto { + +class VMITargetCapabilityRegistry; + +enum class VMIContiguousStoreRecipeKind { + ContiguousVsts, + Deinterleaved2Vstsx2, + DeinterleavedMaterializeThenVsts, +}; + +struct VMIContiguousStoreRecipe { + VMIContiguousStoreRecipeKind kind = + VMIContiguousStoreRecipeKind::ContiguousVsts; +}; + +enum class VMILayoutMaterializationRecipeKind { + Identity, + ContiguousToDeinterleaved, + DeinterleavedToContiguous, +}; + +struct VMILayoutMaterializationRecipe { + VMILayoutMaterializationRecipeKind kind = + VMILayoutMaterializationRecipeKind::Identity; +}; + +enum class VMIMaskGranularityMaterializationRecipeKind { + Identity, + PredicateCast, +}; + +struct VMIMaskGranularityMaterializationRecipe { + VMIMaskGranularityMaterializationRecipeKind kind = + VMIMaskGranularityMaterializationRecipeKind::Identity; +}; + +enum class VMIGroupSlotLoadRecipeKind { + Slots8UnitStrideVsldb, + Slots1AlignedLane0Vsldb, +}; + +struct VMIGroupSlotLoadRecipe { + VMIGroupSlotLoadRecipeKind kind = + VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb; +}; + +enum class VMIGroupLoadRecipeKind { + S16Block8Vsldb, + S32Block8Vsldb, +}; + +struct VMIGroupLoadRecipe { + VMIGroupLoadRecipeKind kind = VMIGroupLoadRecipeKind::S16Block8Vsldb; +}; + +enum class VMIGroupSlotsStoreRecipeKind { + Slots8UnitStrideVsts, + Slots1AlignedLane0Vsts, +}; + 
+struct VMIGroupSlotsStoreRecipe { + VMIGroupSlotsStoreRecipeKind kind = + VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts; +}; + +enum class VMIGroupReduceAddFRecipeKind { + S8Vcgadd, + S16Deinterleaved2VcgaddVadd, + S32Deinterleaved4VcgaddTree, + S64ContiguousVcaddRows, +}; + +struct VMIGroupReduceAddFRecipe { + VMIGroupReduceAddFRecipeKind kind = VMIGroupReduceAddFRecipeKind::S8Vcgadd; +}; + +enum class VMIGroupBroadcastRecipeKind { + GroupSlotsVselr, +}; + +struct VMIGroupBroadcastRecipe { + VMIGroupBroadcastRecipeKind kind = + VMIGroupBroadcastRecipeKind::GroupSlotsVselr; +}; + +enum class VMITruncFRecipeKind { + Deinterleaved2F32ToContiguousF16, + Deinterleaved4F32ToContiguousF8, + GroupSlots1F32ToF16, +}; + +struct VMITruncFRecipe { + VMITruncFRecipeKind kind = + VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16; +}; + +enum class VMIExtFRecipeKind { + ContiguousF16ToDeinterleaved2F32, + ContiguousF8ToDeinterleaved4F32, +}; + +struct VMIExtFRecipe { + VMIExtFRecipeKind kind = + VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; +}; + +enum class VMIBitcastRecipeKind { + PerPartVbitcast, +}; + +struct VMIBitcastRecipe { + VMIBitcastRecipeKind kind = VMIBitcastRecipeKind::PerPartVbitcast; +}; + +class VMILocalRecipeRegistry { +public: + FailureOr + getContiguousStoreRecipe(VMIVRegType valueType, + std::string *reason = nullptr) const; + + LogicalResult canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getDataLayoutMaterializationRecipe(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskLayoutMaterializationRecipe(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskGranularityMaterializationRecipe(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotLoadRecipe(const 
VMITargetCapabilityRegistry &capabilities, + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupLoadRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotsStoreRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddFRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, + std::string *reason = nullptr) const; + + FailureOr + getTruncFRecipe(VMITruncFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; + + FailureOr + getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index fef96ec2c8..9dbad686a4 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -35,7 +35,12 @@ add_mlir_dialect_library(PTOTransforms VPTOBufferMaterialization.cpp PTOValidateVPTOIR.cpp PTOValidateVMIIR.cpp + VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp + VMILayoutFoldConsumers.cpp + VMILocalRecipeRegistry.cpp + VMILayoutRematerialize.cpp + VMILayoutSinkMaterialization.cpp VMIToVPTO.cpp PTOInferVPTOVecScope.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6ce3e8eecd..7234084c47 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -12,6 +12,8 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include 
"PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -159,6 +161,49 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagLayoutContractPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +LogicalResult emitHelperMaterializationContract(Operation *helper, + Type sourceType, + Type resultType, + StringRef helperName, + StringRef reason, + llvm::raw_ostream *diagOS) { + auto emitFallback = [&]() { + return emitLayoutContract( + helper, diagOS, + Twine(helperName) + " has no registered materialization recipe: " + + reason); + }; + + if (helper->getNumResults() != 1 || !helper->getResult(0).hasOneUse()) + return emitFallback(); + + OpOperand &use = *helper->getResult(0).use_begin(); + Operation *requester = use.getOwner(); + std::string message; + llvm::raw_string_ostream os(message); + os << requester->getName() << " operand #" << use.getOperandNumber() + << " has type " << sourceType << " but requires " << resultType << "; " + << helperName << " has no registered materialization recipe: " << reason; + os.flush(); + + InFlightDiagnostic diag = + requester->emitError() << kVMIDiagLayoutContractPrefix << message; + diag.attachNote(helper->getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason << ")"; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -350,6 +395,12 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, return success(); } +LogicalResult verifyLayoutHelperRecipe(Operation 
*op, + llvm::raw_ostream *diagOS); + +LogicalResult verifyLayoutSemanticRecipe(Operation *op, + llvm::raw_ostream *diagOS); + LogicalResult verifyOperationBoundary(Operation *op, llvm::raw_ostream *diagOS) { if (failed(verifyOperationTypes(op, diagOS))) @@ -380,19 +431,209 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) - return success(); + return verifyLayoutHelperRecipe(op, diagOS); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); } - if (isVMISemanticOp(op) || isStructuralOp(op)) + if (isVMISemanticOp(op)) + return verifyLayoutSemanticRecipe(op, diagOS); + if (isStructuralOp(op)) return success(); return emitInvariant(op, diagOS, "VMI typed value is used by a non-VMI semantic op"); } +LogicalResult verifyLayoutHelperRecipe(Operation *op, + llvm::raw_ostream *diagOS) { + VMILocalRecipeRegistry recipes; + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getDataLayoutMaterializationRecipe(sourceType, + resultType, + &reason))) + return emitHelperMaterializationContract( + op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getMaskLayoutMaterializationRecipe(sourceType, + resultType, + &reason))) + return emitHelperMaterializationContract( + op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, + diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getMaskGranularityMaterializationRecipe( + 
sourceType, resultType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.ensure_mask_granularity has no registered " + "materialization recipe: ") + + reason); + return success(); + } + + return success(); +} + +LogicalResult verifyLayoutSemanticRecipe(Operation *op, + llvm::raw_ostream *diagOS) { + VMILocalRecipeRegistry recipes; + VMITargetCapabilityRegistry capabilities; + + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || layout.isContiguous()) + return success(); + + std::string reason; + if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.store has no registered contiguous-memory local " + "recipe: ") + + reason); + return success(); + } + + if (auto tileWrite = dyn_cast(op)) { + auto valueType = cast(tileWrite.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || layout.isContiguous()) + return success(); + + std::string reason; + if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.tile_write has no registered contiguous-memory local " + "recipe: ") + + reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return success(); + + std::string reason; + if (failed(recipes.getGroupLoadRecipe(capabilities, load, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_load has no registered block8 local recipe: ") + + reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getGroupSlotLoadRecipe(capabilities, load, &reason))) + return 
emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_slot_load has no registered local recipe: ") + + reason); + return success(); + } + + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, store, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_store has no registered group_slots local " + "recipe: ") + + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed(recipes.getGroupReduceAddFRecipe(capabilities, reduce, + &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_reduce_addf has no registered group_slots " + "local recipe: ") + + reason); + return success(); + } + + if (auto broadcast = dyn_cast(op)) { + auto sourceType = cast(broadcast.getSource().getType()); + VMILayoutAttr layout = sourceType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return success(); + + std::string reason; + if (failed(recipes.getGroupBroadcastRecipe(capabilities, broadcast, + &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_broadcast has no registered local recipe: ") + + reason); + return success(); + } + + if (auto truncf = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getTruncFRecipe(truncf, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.truncf has no registered local recipe: ") + reason); + return success(); + } + + if (auto extf = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getExtFRecipe(extf, &reason))) + return emitLayoutContract( + op, diagOS, + 
Twine("pto.vmi.extf has no registered local recipe: ") + reason); + return success(); + } + + if (auto bitcast = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getBitcastRecipe(bitcast, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.bitcast has no registered local recipe: ") + reason); + return success(); + } + + return success(); +} + struct PTOValidateVMIIRPass : public mlir::pto::impl::PTOValidateVMIIRBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMIIRPass) diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp new file mode 100644 index 0000000000..26536f196d --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutFoldConsumers.cpp - Fold VMI layout consumers ------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTFOLDCONSUMERS +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !resultType) + return false; + + VMILocalRecipeRegistry recipes; + return succeeded( + recipes.canFoldContiguousStoreMaterialization(sourceType, resultType)); +} + +static void tryFoldEnsureLayoutIntoOperand( + OpOperand &operand, SmallVectorImpl &maybeDeadEnsures) { + auto ensure = operand.get().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + + operand.set(ensure.getSource()); + maybeDeadEnsures.push_back(ensure); +} + +static void tryFoldEnsureLayoutIntoMaskedStore( + VMIMaskedStoreOp store, + SmallVectorImpl &maybeDeadEnsures, + SmallVectorImpl &maybeDeadMaskEnsures) { + auto ensure = store.getValue().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + auto maskEnsure = store.getMask().getDefiningOp(); + if (!maskEnsure) + return; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = 
dyn_cast(maskEnsure.getResult().getType()); + if (!sourceType || !maskSourceType || !maskResultType) + return; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskSourceLayout = maskSourceType.getLayoutAttr(); + VMILayoutAttr maskResultLayout = maskResultType.getLayoutAttr(); + if (!sourceLayout || !maskSourceLayout || !maskResultLayout) + return; + if (sourceLayout != maskSourceLayout || !maskResultLayout.isContiguous()) + return; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskSourceType); + if (failed(sourceArity) || failed(maskArity) || *sourceArity != *maskArity) + return; + + store.getValueMutable().set(ensure.getSource()); + store.getMaskMutable().set(maskEnsure.getSource()); + maybeDeadEnsures.push_back(ensure); + maybeDeadMaskEnsures.push_back(maskEnsure); +} + +struct VMILayoutFoldConsumersPass + : public mlir::pto::impl::VMILayoutFoldConsumersBase< + VMILayoutFoldConsumersPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutFoldConsumersPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector maybeDeadEnsures; + SmallVector maybeDeadMaskEnsures; + + module.walk([&](Operation *op) { + if (auto store = dyn_cast(op)) + tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), + maybeDeadEnsures); + if (auto tileWrite = dyn_cast(op)) + tryFoldEnsureLayoutIntoOperand(tileWrite.getValueMutable(), + maybeDeadEnsures); + if (auto maskedStore = dyn_cast(op)) + tryFoldEnsureLayoutIntoMaskedStore(maskedStore, maybeDeadEnsures, + maybeDeadMaskEnsures); + }); + + for (VMIEnsureMaskLayoutOp ensure : llvm::reverse(maybeDeadMaskEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + for (VMIEnsureLayoutOp ensure : llvm::reverse(maybeDeadEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutFoldConsumersPass() { + return std::make_unique(); +} diff --git 
a/lib/PTO/Transforms/VMILayoutRematerialize.cpp b/lib/PTO/Transforms/VMILayoutRematerialize.cpp new file mode 100644 index 0000000000..4f230d4189 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -0,0 +1,172 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutRematerialize.cpp - Rematerialize VMI producers ----------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTREMATERIALIZE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool hasConcreteLayout(VMIVRegType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static bool hasConcreteLayout(VMIMaskType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static std::optional rematerializeDataProducer(Value value, + VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto 
constant = value.getDefiningOp()) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && denseAttr.isSplat()) + return builder + .create(loc, resultType, constant.getValue()) + .getResult(); + } + + if (auto broadcast = value.getDefiningOp()) + return builder.create(loc, resultType, + broadcast.getValue()) + .getResult(); + + if (auto iota = value.getDefiningOp()) + return builder + .create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) + .getResult(); + + return std::nullopt; +} + +static std::optional rematerializeMaskProducer(Value value, + VMIMaskType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto createMask = value.getDefiningOp()) + return builder + .create(loc, resultType, createMask.getActiveLanes()) + .getResult(); + + if (auto createGroupMask = value.getDefiningOp()) + return builder + .create( + loc, resultType, createGroupMask.getActiveElemsPerGroup(), + createGroupMask.getNumGroupsAttr(), createGroupMask.getGroupSizeAttr()) + .getResult(); + + if (auto constantMask = value.getDefiningOp()) + return builder + .create(loc, resultType, + constantMask.getValueAttr()) + .getResult(); + + return std::nullopt; +} + +static bool tryReplaceDataEnsure(VMIEnsureLayoutOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder builder(ensure); + auto result = rematerializeDataProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +template +static bool tryReplaceMaskEnsure(EnsureOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder builder(ensure); + auto result = rematerializeMaskProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + 
ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +struct VMILayoutRematerializePass + : public mlir::pto::impl::VMILayoutRematerializeBase< + VMILayoutRematerializePass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutRematerializePass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector helpers; + module.walk([&](Operation *op) { + if (isa(op)) + helpers.push_back(op); + }); + + for (Operation *op : helpers) { + if (op->getBlock() == nullptr) + continue; + + if (auto ensure = dyn_cast(op)) { + tryReplaceDataEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) { + tryReplaceMaskEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) + tryReplaceMaskEnsure(ensure); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutRematerializePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp new file mode 100644 index 0000000000..c3bbf67731 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -0,0 +1,363 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutSinkMaterialization.cpp - Sink VMI layout helpers --------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTSINKMATERIALIZATION +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct BinaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct UnaryVRegOperand { + OpOperand *source = nullptr; +}; + +struct BinaryMaskOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct UnaryMaskOperand { + OpOperand *source = nullptr; +}; + +static std::optional getSinkableBinaryOperands(Operation *op) { + if (auto addf = dyn_cast(op)) + return BinaryVRegOperands{&addf.getLhsMutable(), &addf.getRhsMutable()}; + if (auto addi = dyn_cast(op)) + return BinaryVRegOperands{&addi.getLhsMutable(), &addi.getRhsMutable()}; + if (auto subf = dyn_cast(op)) + return BinaryVRegOperands{&subf.getLhsMutable(), &subf.getRhsMutable()}; + if (auto subi = dyn_cast(op)) + return BinaryVRegOperands{&subi.getLhsMutable(), &subi.getRhsMutable()}; + if (auto mulf = dyn_cast(op)) + return BinaryVRegOperands{&mulf.getLhsMutable(), &mulf.getRhsMutable()}; + if (auto muli = dyn_cast(op)) + return BinaryVRegOperands{&muli.getLhsMutable(), &muli.getRhsMutable()}; + if (auto divf = dyn_cast(op)) + return BinaryVRegOperands{&divf.getLhsMutable(), &divf.getRhsMutable()}; + if (auto minf = dyn_cast(op)) + return BinaryVRegOperands{&minf.getLhsMutable(), &minf.getRhsMutable()}; + if (auto maxf = dyn_cast(op)) + 
return BinaryVRegOperands{&maxf.getLhsMutable(), &maxf.getRhsMutable()}; + if (auto andi = dyn_cast(op)) + return BinaryVRegOperands{&andi.getLhsMutable(), &andi.getRhsMutable()}; + if (auto ori = dyn_cast(op)) + return BinaryVRegOperands{&ori.getLhsMutable(), &ori.getRhsMutable()}; + if (auto xori = dyn_cast(op)) + return BinaryVRegOperands{&xori.getLhsMutable(), &xori.getRhsMutable()}; + if (auto shli = dyn_cast(op)) + return BinaryVRegOperands{&shli.getLhsMutable(), &shli.getRhsMutable()}; + if (auto shrui = dyn_cast(op)) + return BinaryVRegOperands{&shrui.getLhsMutable(), &shrui.getRhsMutable()}; + return std::nullopt; +} + +static std::optional getSinkableUnaryOperand(Operation *op) { + if (auto negf = dyn_cast(op)) + return UnaryVRegOperand{&negf.getSourceMutable()}; + if (auto absf = dyn_cast(op)) + return UnaryVRegOperand{&absf.getSourceMutable()}; + if (auto absi = dyn_cast(op)) + return UnaryVRegOperand{&absi.getSourceMutable()}; + if (auto sqrt = dyn_cast(op)) + return UnaryVRegOperand{&sqrt.getSourceMutable()}; + if (auto exp = dyn_cast(op)) + return UnaryVRegOperand{&exp.getSourceMutable()}; + if (auto ln = dyn_cast(op)) + return UnaryVRegOperand{&ln.getSourceMutable()}; + if (auto relu = dyn_cast(op)) + return UnaryVRegOperand{&relu.getSourceMutable()}; + if (auto notOp = dyn_cast(op)) + return UnaryVRegOperand{&notOp.getSourceMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableBinaryMaskOperands(Operation *op) { + if (auto maskAnd = dyn_cast(op)) + return BinaryMaskOperands{&maskAnd.getLhsMutable(), + &maskAnd.getRhsMutable()}; + if (auto maskOr = dyn_cast(op)) + return BinaryMaskOperands{&maskOr.getLhsMutable(), + &maskOr.getRhsMutable()}; + if (auto maskXor = dyn_cast(op)) + return BinaryMaskOperands{&maskXor.getLhsMutable(), + &maskXor.getRhsMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableUnaryMaskOperand(Operation *op) { + if (auto maskNot = dyn_cast(op)) + return
UnaryMaskOperand{&maskNot.getSourceMutable()}; + return std::nullopt; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp ensure, + VMIVRegType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +template +static bool isSameMaskMaterialization(EnsureOp ensure, VMIMaskType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +template +static bool isSameMaskMaterialization(EnsureOp lhsEnsure, EnsureOp rhsEnsure, + VMIMaskType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = 
dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool trySinkBinaryMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +template +static bool trySinkBinaryMaskMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryMaskOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaskMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + OpBuilder 
builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkUnaryMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if (!isSameMaterialization(sourceEnsure, resultType)) + return false; + + auto sourceType = cast(sourceEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +template +static bool trySinkUnaryMaskMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryMaskOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if 
(!isSameMaskMaterialization(sourceEnsure, resultType)) + return false; + + auto sourceType = cast(sourceEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +static bool trySinkMaskMaterialization(Operation *op) { + return trySinkBinaryMaskMaterialization(op) || + trySinkBinaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op); +} + +struct VMILayoutSinkMaterializationPass + : public mlir::pto::impl::VMILayoutSinkMaterializationBase< + VMILayoutSinkMaterializationPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + VMILayoutSinkMaterializationPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector candidates; + module.walk([&](Operation *op) { + if (getSinkableBinaryOperands(op) || getSinkableUnaryOperand(op) || + getSinkableBinaryMaskOperands(op) || getSinkableUnaryMaskOperand(op)) + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + if (op->getBlock() == nullptr) + continue; + if (!trySinkBinaryMaterialization(op)) { + if (!trySinkUnaryMaterialization(op)) + trySinkMaskMaterialization(op); + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutSinkMaterializationPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILegalizeArithSelect.cpp b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp new file mode 100644 index 0000000000..471215985f --- /dev/null +++ b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp @@ 
-0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILegalizeArithSelect.cpp - Legalize VMI arith.select ------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILEGALIZEARITHSELECT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isVMIValueType(Type type) { + return isa(type); +} + +static bool hasScalarI1Condition(arith::SelectOp select) { + return select.getCondition().getType().isSignlessInteger(1); +} + +static void rewriteSelectToIf(arith::SelectOp select) { + OpBuilder builder(select); + auto ifOp = builder.create( + select.getLoc(), TypeRange{select.getResult().getType()}, + select.getCondition(), /*withElseRegion=*/true); + + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + builder.create(select.getLoc(), 
select.getTrueValue()); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(select.getLoc(), select.getFalseValue()); + } + + select.getResult().replaceAllUsesWith(ifOp.getResult(0)); + select.erase(); +} + +struct VMILegalizeArithSelectPass + : public mlir::pto::impl::VMILegalizeArithSelectBase< + VMILegalizeArithSelectPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILegalizeArithSelectPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector selects; + module.walk([&](arith::SelectOp select) { + if (isVMIValueType(select.getResult().getType()) && + hasScalarI1Condition(select)) + selects.push_back(select); + }); + + for (arith::SelectOp select : llvm::reverse(selects)) { + if (select->getBlock() != nullptr) + rewriteSelectToIf(select); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILegalizeArithSelectPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp new file mode 100644 index 0000000000..7364084028 --- /dev/null +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -0,0 +1,1006 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILocalRecipeRegistry.cpp - VMI local recipe queries --------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VMILocalRecipeRegistry.h" + +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "llvm/ADT/Twine.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult failWithReason(const Twine &message, std::string *reason) { + if (reason) + *reason = message.str(); + return failure(); +} + +static LogicalResult checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return failWithReason("requires known physical lanes per part", reason); + + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failWithReason("requires computable physical arity", reason); + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failWithReason("requires assigned layout", reason); + int64_t factor = layout.isDeinterleaved() ? 
layout.getFactor() : 1; + if (factor <= 0 || *arity % factor != 0) + return failWithReason("requires arity divisible by layout factor", reason); + + int64_t chunksPerPart = *arity / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failWithReason("failed to map physical padding lane", reason); + if (*padding) + return failWithReason("found padding lane in physical chunk", reason); + } + } + } + + return success(); +} + +static bool hasX2MemoryDistToken(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (constant.getType().isIndex()) + return constant.value(); + } + return std::nullopt; +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + assert(lhs >= 0 && rhs > 0); + return (lhs + rhs - 1) / rhs; +} + +static FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +static FailureOr getVMITypeLayoutFactor(Type type) { + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return failure(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? 
layout.getFactor() : 1; +} + +static FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +static FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +static LogicalResult checkFullVMIPhysicalChunks(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + +static FailureOr +getContiguousMaterializationPartCount(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + 
}; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return fail("requires VMI data or mask type"); + + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous()) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +static LogicalResult checkLayoutMaterializationShape(Type sourceType, + Type resultType, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + 
return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if (succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); +} + +static FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +static FailureOr getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? 
layout.getFactor() : 1; +} + +static FailureOr> +getPhysicalLogicalBitFootprint(VMIVRegType type) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (elementBits == 0) + return failure(); + + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = getVMIPhysicalArity(type); + if (failed(factor) || failed(lanesPerPart) || failed(arity) || *factor <= 0) + return failure(); + + SmallVector bits; + bits.reserve(*arity); + for (int64_t part = 0; part < *factor; ++part) { + for (int64_t chunk = 0; chunk < *arity; ++chunk) { + int64_t activeLanes = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++activeLanes; + } + if (activeLanes > 0) + bits.push_back(activeLanes * static_cast(elementBits)); + } + } + if (static_cast(bits.size()) != *arity) + return failure(); + return bits; +} + +static FailureOr +getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMILayoutMaterializationRecipe{ + VMILayoutMaterializationRecipeKind::Identity}; + if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMILayoutMaterializationRecipe{ + VMILayoutMaterializationRecipeKind::ContiguousToDeinterleaved}; + if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMILayoutMaterializationRecipe{ + 
VMILayoutMaterializationRecipeKind::DeinterleavedToContiguous}; + return fail("unsupported source/result layout pair"); +} + +} // namespace + +FailureOr +VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout) + return fail("requires assigned value layout"); + if (layout.isContiguous()) + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::ContiguousVsts}; + if (!layout.isDeinterleaved()) + return fail("requires contiguous or deinterleaved value layout"); + if (layout.getBlockElems() != 1) + return fail("requires block_elems=1 deinterleaved value layout"); + if (failed(checkFullDataPhysicalChunks(valueType, reason))) + return failure(); + + if (layout.getFactor() == 2) { + if (!hasX2MemoryDistToken(valueType.getElementType())) + return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2}; + } + + if (layout.getFactor() == 4) + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::DeinterleavedMaterializeThenVsts}; + + return fail("requires deinterleaved factor 2 or 4"); +} + +LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (sourceType.getElementType() != resultType.getElementType()) + return failWithReason("source/result element types must match", reason); + if (sourceType.getElementCount() != resultType.getElementCount()) + return failWithReason("source/result element counts must match", reason); + + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return failWithReason("result layout must be contiguous", reason); + + FailureOr recipe = + 
getContiguousStoreRecipe(sourceType, reason); + if (failed(recipe)) + return failure(); + if (recipe->kind == VMIContiguousStoreRecipeKind::ContiguousVsts) + return failWithReason("source layout is already contiguous", reason); + + return success(); +} + +FailureOr +VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementType() != resultType.getElementType()) + return fail("source/result element types must match"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result element counts must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr recipe = + getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); + if (failed(recipe)) + return failure(); + if (failed(checkLayoutMaterializationShape(sourceType, resultType, + sourceLayout, resultLayout, + reason))) + return failure(); + return recipe; +} + +FailureOr +VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getGranularity() != resultType.getGranularity()) + return fail("source/result mask granularities must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr recipe = + getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); + if (failed(recipe)) + return failure(); + if 
(failed(checkLayoutMaterializationShape(sourceType, resultType, + sourceLayout, resultLayout, + reason))) + return failure(); + return recipe; +} + +FailureOr +VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("source/result mask layouts must match"); + if (!VMIMaskType::isConcreteGranularity(sourceType.getGranularity()) || + !VMIMaskType::isConcreteGranularity(resultType.getGranularity())) + return fail("requires concrete b8/b16/b32 source and result granularities"); + if (sourceType.getGranularity() == resultType.getGranularity()) + return VMIMaskGranularityMaterializationRecipe{ + VMIMaskGranularityMaterializationRecipeKind::Identity}; + + return VMIMaskGranularityMaterializationRecipe{ + VMIMaskGranularityMaterializationRecipeKind::PredicateCast}; +} + +FailureOr +VMILocalRecipeRegistry::getGroupSlotLoadRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt() || + layout.getSlots() <= 0) + return fail("requires explicit group_slots result layout matching " + "num_groups"); + + if (layout.getSlots() != 8 && layout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_slot_load layouts"); + + if 
(!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for vsldb lowering"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (layout.getSlots() == 8) { + if (!stride || *stride != 1) + return fail("slots=8 group_slot_load requires constant unit " + "source_group_stride"); + return VMIGroupSlotLoadRecipe{ + VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb}; + } + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_slot_load requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_slot_load currently lowers as one " + "lane-0 vsldb per group and requires constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B load alignment; packed or unaligned " + "scalar load lowering is not implemented"); + + return VMIGroupSlotLoadRecipe{ + VMIGroupSlotLoadRecipeKind::Slots1AlignedLane0Vsldb}; +} + +FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return fail("requires deinterleaved block8 f32 result layout"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), + reason); + if 
(failed(groupSize)) + return failure(); + + if ((*groupSize != 16 || layout.getFactor() != 2) && + (*groupSize != 32 || layout.getFactor() != 4)) + return fail("block8 strided group_load requires S=16/factor=2 or " + "S=32/factor=4"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("block8 strided group_load requires !pto.ptr source"); + + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return fail("block8 strided group_load requires num_groups multiple of 8"); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return fail("block8 strided group_load requires constant positive " + "row_stride divisible by 8 f32 elements"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("block8 strided group_load requires full physical " + "result chunks; ") + + fullChunkReason); + + if (*groupSize == 16) + return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S16Block8Vsldb}; + return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S32Block8Vsldb}; +} + +FailureOr +VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return fail("requires group_slots value layout"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMICapabilityResult elementCapability = 
capabilities.supportsElementType( + valueType.getElementType(), VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(valueType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + if (layout.getSlots() == 1) { + if (*arity != numGroups) + return fail("slots=1 group_store requires one physical part per " + "group"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_store requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_store currently lowers as one " + "lane-0 vsts per group and requires constant " + "positive row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B store alignment; packed or unaligned " + "contiguous store lowering is not implemented"); + return VMIGroupSlotsStoreRecipe{ + VMIGroupSlotsStoreRecipeKind::Slots1AlignedLane0Vsts}; + } + + if (layout.getSlots() == 8) { + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride != 1) + return fail("slots=8 group_store currently requires constant unit " + "row_stride"); + if (*arity != ceilDivNonNegative(numGroups, 8)) + return fail("slots=8 group_store arity must equal ceil(num_groups / " + "8)"); + return VMIGroupSlotsStoreRecipe{ + VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts}; + } + + return fail("group_slots group_store currently supports only slots=1 or " + "unit-stride slots=8"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddFRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, + std::string 
*reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point " + "reduction"); + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups) + return fail("requires group_slots result layout matching num_groups"); + if (resultLayout.getSlots() != 8 && resultLayout.getSlots() != 1) { + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && + *groupSize != 8 && *groupSize != 16 && *groupSize != 32) + return fail("stable group_reduce_addf slots=8 recipes support group " + "size 8, 16, or 32"); + return fail("stable group_reduce_addf local recipes currently require " + "result layout slots=8 or slots=1"); + } + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(VMIReductionKind::AddF, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("stable group_reduce_addf local recipes require f32 " + "source/result"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + + FailureOr groupSize = + 
getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source/mask/result physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires matching non-empty source/mask physical arity"); + + if (resultLayout.getSlots() == 1) { + if (*groupSize != 64) + return fail("stable group_reduce_addf slots=1 recipes support group " + "size 64"); + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("s64 group_reduce_addf requires contiguous source/mask " + "layouts"); + if (*resultArity != *sourceArity) + return fail("s64 group_reduce_addf requires source/result physical " + "arity to match"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("s64 group_reduce_addf requires full source chunks; ") + + sourceFullReason); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S64ContiguousVcaddRows}; + } + + if (*groupSize == 8) { + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("s8 group_reduce_addf requires contiguous source/mask " + "layouts"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("s8 group_reduce_addf requires full source chunks; ") + + sourceFullReason); + if (*resultArity != *sourceArity) + return fail("s8 group_reduce_addf requires source/result physical " + "arity to match"); + return VMIGroupReduceAddFRecipe{VMIGroupReduceAddFRecipeKind::S8Vcgadd}; + } + + if (*groupSize == 16) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || + (sourceLayout.getBlockElems() != 
1 && + sourceLayout.getBlockElems() != 8)) + return fail("s16 group_reduce_addf requires source layout " + "deinterleaved=2 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s16 group_reduce_addf requires matching mask layout " + "deinterleaved=2 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || + *sourceArity != *resultArity * 2) + return fail("s16 group_reduce_addf requires two source/mask parts per " + "result part"); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S16Deinterleaved2VcgaddVadd}; + } + + if (*groupSize == 32) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || + (sourceLayout.getBlockElems() != 1 && + sourceLayout.getBlockElems() != 8)) + return fail("s32 group_reduce_addf requires source layout " + "deinterleaved=4 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s32 group_reduce_addf requires matching mask layout " + "deinterleaved=4 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || + *sourceArity != *resultArity * 4) + return fail("s32 group_reduce_addf requires four source/mask parts per " + "result part"); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S32Deinterleaved4VcgaddTree}; + } + + return fail("stable group_reduce_addf slots=8 recipes support group size " + "8, 16, or 32"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupBroadcastRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason) const { + (void)capabilities; + auto fail = [&](const Twine &message) -> FailureOr { + if 
(reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType() || + sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result shape and element type to match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != numGroups) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a 
multiple of " + "physical lanes per part"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; +} + +FailureOr +VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + 
sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceType.getElementType().isF32() || resultBits != 16 || + *sourceArity != *resultArity) + return fail("group-slot truncf requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "f32 source, f16 result, and matching physical arity"); + return VMITruncFRecipe{VMITruncFRecipeKind::GroupSlots1F32ToF16}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + !sourceType.getElementType().isF32() || *resultArity != 1) + return fail("requires f32 deinterleaved source and contiguous result"); + + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return VMITruncFRecipe{ + VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16}; + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) + return VMITruncFRecipe{ + VMITruncFRecipeKind::Deinterleaved4F32ToContiguousF8}; + + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); +} + +FailureOr +VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + !resultType.getElementType().isF32()) + return fail("requires contiguous source layout and deinterleaved f32 " + "result layout"); + + unsigned 
sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + if (sourceBits == 16 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return VMIExtFRecipe{ + VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32}; + if (sourceBits == 8 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return VMIExtFRecipe{ + VMIExtFRecipeKind::ContiguousF8ToDeinterleaved4F32}; + + return fail("unsupported extf source element width, result factor, or " + "physical arity"); +} + +FailureOr +VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source and result layouts"); + if (sourceLayout != resultLayout) + return fail("requires matching source and result layouts"); + if (sourceLayout.isGroupSlots()) + return fail("does not support group_slots layouts"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source and result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + FailureOr> sourceBits = + getPhysicalLogicalBitFootprint(sourceType); + FailureOr> resultBits = + getPhysicalLogicalBitFootprint(resultType); + if (failed(sourceBits) || failed(resultBits)) + return fail("requires computable physical logical bit footprints"); + if (sourceBits->size() != resultBits->size()) + return fail("requires source and result physical footprint counts to " 
+ "match"); + for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { + if (source != result) + return fail("requires matching logical bit footprint in every physical " + "chunk"); + } + + return VMIBitcastRecipe{VMIBitcastRecipeKind::PerPartVbitcast}; +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 5b050d640a..a59b5dbadb 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -18,6 +18,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1200,24 +1201,9 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - if ((*groupSize != 16 || resultLayout.getFactor() != 2) && - (*groupSize != 32 || resultLayout.getFactor() != 4)) - return fail("block8 strided group_load requires S=16/factor=2 or " - "S=32/factor=4"); - if (!isa(op.getSource().getType())) - return fail("block8 strided group_load requires !pto.ptr source"); - if (op.getNumGroupsAttr().getInt() % 8 != 0) - return fail("block8 strided group_load requires num_groups multiple " - "of 8"); - std::optional rowStride = getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) - return fail("block8 strided group_load requires constant positive " - "row_stride divisible by 8 f32 elements"); - std::string fullChunkReason; - if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) - return fail(Twine("block8 strided group_load requires full physical " - "result chunks; ") + - fullChunkReason); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupLoadRecipe(capabilities, op, reason))) + return failure(); return success(); } @@ 
-1227,55 +1213,10 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, LogicalResult checkSupportedGroupSlotLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupSlotLoadRecipe(capabilities, op, reason))) return failure(); - }; - - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != op.getNumGroupsAttr().getInt() || - layout.getSlots() <= 0) - return fail("requires explicit group_slots result layout matching " - "num_groups"); - - if (layout.getSlots() != 8 && layout.getSlots() != 1) - return fail("supports only slots=8 or slots=1 group_slot_load layouts"); - - if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") - .isSupported()) - return fail("requires supported direct memory source"); - if (!isa(op.getSource().getType())) - return fail("requires !pto.ptr source for vsldb lowering"); - if (layout.getSlots() == 8) { - std::optional stride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride != 1) - return fail("slots=8 group_slot_load requires constant unit " - "source_group_stride"); - return success(); - } - if (layout.getSlots() == 1) { - unsigned elementBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (elementBits == 0 || 256 % elementBits != 0) - return fail("slots=1 group_slot_load requires an 8/16/32-bit element " - "type"); - int64_t alignedStrideElems = 256 / elementBits; - std::optional stride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) - return fail(Twine("slots=1 group_slot_load currently lowers as one " - "lane-0 vsldb per group and requires 
constant " - "positive source_group_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B load alignment; packed or unaligned " - "scalar load lowering is not implemented"); - return success(); - } - llvm_unreachable("unsupported group_slot_load slots should be rejected"); + return success(); } LogicalResult @@ -1301,46 +1242,10 @@ checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); - if (failed(checkSupportedMaskableVReg(capabilities, valueType, reason))) + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, op, reason))) return failure(); - - FailureOr arity = getVMIPhysicalArity(valueType); - if (failed(arity)) - return fail("requires computable physical arity"); - if (layout.getSlots() == 1) { - if (*arity != numGroups) - return fail("slots=1 group_store requires one physical part per " - "group"); - unsigned elementBits = - pto::getPTOStorageElemBitWidth(valueType.getElementType()); - if (elementBits == 0 || 256 % elementBits != 0) - return fail("slots=1 group_store requires an 8/16/32-bit element " - "type"); - int64_t alignedStrideElems = 256 / elementBits; - std::optional rowStride = - getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) - return fail(Twine("slots=1 group_store currently lowers as one " - "lane-0 vsts per group and requires constant " - "positive row_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B store alignment; packed or unaligned " - "contiguous store lowering is not implemented"); - return success(); - } - if (layout.getSlots() == 8) { - std::optional rowStride = - getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride != 1) - return fail("slots=8 group_store currently requires constant unit " - "row_stride"); - if (*arity != 
ceilDivNonNegative(numGroups, 8)) - return fail("slots=8 group_store arity must equal ceil(num_groups / " - "8)"); - return success(); - } - return fail("group_slots group_store currently supports only slots=1 or " - "unit-stride slots=8"); + return success(); } FailureOr groupSize = getGroupSizeFromNumGroups( @@ -3309,6 +3214,13 @@ struct OneToNVMIEnsureLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getDataLayoutMaterializationRecipe( + sourceType, resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_layout has no registered materialization recipe: ") + + recipeReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3336,6 +3248,14 @@ struct OneToNVMIEnsureMaskLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getMaskLayoutMaterializationRecipe( + sourceType, resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, + Twine("ensure_mask_layout has no registered materialization recipe: ") + + recipeReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3367,6 +3287,14 @@ struct OneToNVMIEnsureMaskGranularityOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getMaskGranularityMaterializationRecipe( + sourceType, 
resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_mask_granularity has no registered materialization " + "recipe: ") + + recipeReason); if (sourceType.getLayout() != resultType.getLayout()) return rewriter.notifyMatchFailure( op, "mask granularity helper cannot also change layout"); @@ -4549,7 +4477,6 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { op, "store requires known physical lanes per part"); bool fullPhysicalChunks = succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); - FailureOr destination = getSingleValue(op, adaptor.getDestination(), "store destination must convert to one value", rewriter); @@ -4560,9 +4487,12 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { return failure(); ValueRange valueParts = adaptor.getValue(); - VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); - if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && - valueLayout.getFactor() == 2) { + VMILocalRecipeRegistry localRecipes; + FailureOr storeRecipe = + localRecipes.getContiguousStoreRecipe(valueVMIType); + if (succeeded(storeRecipe) && + storeRecipe->kind == + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -5007,7 +4937,6 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { op, "tile_write requires known physical lanes per part"); bool fullPhysicalChunks = succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); - FailureOr destination = getSingleValue( op, adaptor.getDestination(), "tile_write destination must convert to one value", rewriter); @@ -5016,9 +4945,12 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { ValueRange valueParts = adaptor.getValue(); Value zero = rewriter.create(op.getLoc(), 0); - VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); - if 
(fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && - valueLayout.getFactor() == 2) { + VMILocalRecipeRegistry localRecipes; + FailureOr storeRecipe = + localRecipes.getContiguousStoreRecipe(valueVMIType); + if (succeeded(storeRecipe) && + storeRecipe->kind == + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -6862,145 +6794,26 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { return failure(result.wasInterrupted()); } -LogicalResult checkSupportedExtFShape(VMIExtFOp op) { - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity) || !sourceLayout.isContiguous() || - !resultLayout.isDeinterleaved() || !resultType.getElementType().isF32()) +LogicalResult checkSupportedExtFShape(VMIExtFOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtFRecipe(op, reason))) return failure(); - - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - if (sourceBits == 16 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return success(); - if (sourceBits == 8 && resultLayout.getFactor() == 4 && - *resultArity == 4 * *sourceArity) - return success(); - return failure(); + return success(); } LogicalResult checkSupportedTruncFShape(VMITruncFOp op, std::string *reason = nullptr) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if 
(failed(recipes.getTruncFRecipe(op, reason))) return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity)) - return fail("requires assigned source/result layouts and computable " - "physical arity"); - - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - - if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { - if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || - sourceLayout.getNumGroups() != resultLayout.getNumGroups() || - sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || - !sourceType.getElementType().isF32() || resultBits != 16 || - *sourceArity != *resultArity) - return fail("group-slot truncf requires matching " - "group_slots(num_groups=G, slots=1) source/result layouts, " - "f32 source, f16 result, and matching physical arity"); - - return success(); - } - - if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - !sourceType.getElementType().isF32() || *resultArity != 1) - return fail("requires f32 deinterleaved source and contiguous result"); - - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return success(); - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) - return success(); - return fail("unsupported deinterleaved truncf factor, arity, or result " - "element width"); -} - -FailureOr> -getPhysicalLogicalBitFootprint(VMIVRegType type) { - unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); - if (elementBits == 0) - return failure(); - - FailureOr factor = getDataLayoutFactor(type); - FailureOr 
lanesPerPart = getDataLanesPerPart(type.getElementType()); - if (failed(factor) || failed(lanesPerPart)) - return failure(); - - SmallVector bits; - for (int64_t part = 0; part < *factor; ++part) { - FailureOr chunks = getDataChunksInPart(type, part); - if (failed(chunks)) - return failure(); - for (int64_t chunk = 0; chunk < *chunks; ++chunk) { - int64_t activeLanes = 0; - for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { - FailureOr padding = isPaddingLane(type, part, chunk, lane); - if (failed(padding)) - return failure(); - if (!*padding) - ++activeLanes; - } - bits.push_back(activeLanes * static_cast(elementBits)); - } - } - return bits; + return success(); } LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getBitcastRecipe(op, reason))) return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout) - return fail("requires assigned source and result layouts"); - if (sourceLayout != resultLayout) - return fail("requires matching source and result layouts"); - - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(resultArity)) - return fail("requires computable source and result physical arity"); - if (*sourceArity != *resultArity) - return fail("requires source and result to have the same physical arity"); - - FailureOr> sourceBits = - getPhysicalLogicalBitFootprint(sourceType); - FailureOr> resultBits = - getPhysicalLogicalBitFootprint(resultType); - if (failed(sourceBits) || failed(resultBits)) - return fail("requires computable physical logical bit 
footprints"); - if (sourceBits->size() != resultBits->size()) - return fail("requires source and result physical footprint counts to " - "match"); - for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { - if (source != result) - return fail("requires matching logical bit footprint in every physical " - "chunk"); - } - return success(); } @@ -7315,6 +7128,10 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (!sourceLayout || !resultLayout || !maskLayout) return fail("requires assigned source, mask, and result layouts"); + VMILocalRecipeRegistry recipes; + if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + return success(); + FailureOr groupSize = getGroupSizeFromNumGroups( sourceType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) @@ -7381,6 +7198,9 @@ LogicalResult checkSupportedGroupBroadcastShape( VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); + VMILocalRecipeRegistry recipes; + if (succeeded(recipes.getGroupBroadcastRecipe(capabilities, op, nullptr))) + return success(); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) return fail("requires matching num_groups source layout"); @@ -8045,7 +7865,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, } if (auto extf = dyn_cast(op)) { - if (succeeded(checkSupportedExtFShape(extf))) + std::string reason; + if (succeeded(checkSupportedExtFShape(extf, &reason))) return WalkResult::advance(); extf.emitError() @@ -8053,7 +7874,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << "pto.vmi.extf supports contiguous 16-bit float-like or fp8-like " "physical source chunks to f32 deinterleaved=2/4 results; " "partial/tail is allowed only when source padding maps to result " - "padding"; + "padding (" + << reason << ")"; return WalkResult::interrupt(); } diff --git 
a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto index af6623a995..ec29b4387a 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -18,13 +18,8 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.store operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store operand #0 has type !pto.vmi.vreg<64xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<64xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' (unsupported source/result layout pair) pto.vmi.store %sum, %dst[%off] : !pto.vmi.vreg<64xf32>, !pto.ptr return diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto index c928df5320..6ed4e7f9e7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -26,13 +26,8 @@ module { -> !pto.vmi.vreg<128xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot 
materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %dense_dst[%off] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index 3bea54d83f..eccb4e0007 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -15,9 +15,8 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf lowers through pto.vcgadd - // CHECK-SAME: num_groups deriving a group size aligned to physical chunks - // CHECK-SAME: found padding lane in physical chunk + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto index c66ff0eb3c..cface43bab 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -15,13 
+15,12 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf operand #0 has type + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf operand #0 has type // CHECK-SAME: #pto.vmi.layout // CHECK-SAME: requires // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe // CHECK: requires source and result to have the same physical arity - // CHECK-SAME: partial/tail layout materialization requires an explicit packing plan %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 9f629f55f2..94cd55c58c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -18,9 +18,10 @@ module { } func.func @vmi_layout_assignment_group_slot_load_slots1( - %src: !pto.ptr, %off: index, %stride: index) + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<512xf32> { - %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<512xf32> return %out : !pto.vmi.vreg<512xf32> } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto index e6e459c435..b8cd439d23 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( %src: 
!pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto index f8d7bc8af8..b432d7c68c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( %src: !pto.ptr, %off: index) { %c2 = arith.constant 2 : index - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto index 452ee085ac..996760ed66 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<512xf32> - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_store + // CHECK: VMI-LAYOUT-CONTRACT: 
pto.vmi.group_store has no registered group_slots local recipe // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group // CHECK-SAME: requires constant positive row_stride divisible by 8 elements // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index 3005e53c0a..e57954b16e 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -19,13 +19,8 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<128xf32> - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} diff --git a/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto b/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto new file mode 100644 index 0000000000..84ba3b5b1e --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %value_c, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_layout_fold_consumers_masked_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: 
pto.vmi.store %[[VALUE]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto b/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto new file mode 100644 index 0000000000..8f31b78f7b --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_masked_store( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// LOWER-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[M0:[^,]+]]: !pto.mask +// LOWER-SAME: %[[M1:[^,]+]]: !pto.mask +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], 
%[[V1]] +// LOWER: %[[ML:.*]], %[[MH:.*]] = pto.pintlv_b32 %[[M0]], %[[M1]] +// LOWER: pto.vsts %[[LOW]] +// LOWER-SAME: %[[ML]] +// LOWER: pto.vsts %[[HIGH]] +// LOWER-SAME: %[[MH]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_consumers_store.pto b/test/lit/vmi/vmi_layout_fold_consumers_store.pto new file mode 100644 index 0000000000..281d737861 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_store.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_store( + %src: !pto.vmi.vreg<128xf16>, + %scale: f32, + %out1: !pto.ptr, + %out2: !pto.ptr, + %offset: index) { + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wide, %scale_v + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %prod, %out1[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %wide, %out2[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_layout_fold_consumers_tile_write( + %src: !pto.vmi.vreg<128xf16>, + %dst: memref<128xf32>) { + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %wide, %dst + : !pto.vmi.vreg<128xf32>, memref<128xf32> + return + } + +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store( +// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// FOLD: %[[SCALE:.*]] = pto.vmi.broadcast +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[PROD:.*]] = pto.vmi.mulf %[[WIDE]], %[[SCALE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[PROD]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[WIDE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store( +// LOWER: %[[SCALE0:.*]] = pto.vdup +// LOWER: %[[SCALE1:.*]] = 
pto.vdup +// LOWER: %[[WIDE0:.*]] = pto.vcvt +// LOWER: %[[WIDE1:.*]] = pto.vcvt +// LOWER: %[[PROD0:.*]] = pto.vmul %[[WIDE0]], %[[SCALE0]] +// LOWER: %[[PROD1:.*]] = pto.vmul %[[WIDE1]], %[[SCALE1]] +// LOWER-NOT: pto.vintlv +// LOWER: pto.vstsx2 %[[PROD0]], %[[PROD1]] +// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_tile_write( +// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.tile_write %[[WIDE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_tile_write( +// LOWER: %[[WIDE0:.*]] = pto.vcvt +// LOWER: %[[WIDE1:.*]] = pto.vcvt +// LOWER-NOT: pto.vintlv +// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto new file mode 100644 index 0000000000..e63567e48d --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_group_slots_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK-SAME: does not support group_slots layouts + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto new file mode 100644 index 0000000000..806aaa26dd --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_recipe_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK-SAME: requires matching logical bit footprint in every physical chunk + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto new file mode 100644 index 0000000000..7bda214fed --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_extf_recipe_invalid( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered local recipe + // CHECK-SAME: requires contiguous source layout and deinterleaved f32 result layout + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" + %out = pto.vmi.extf %source + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto new file mode 100644 index 0000000000..224858064c --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_broadcast_recipe_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered local recipe + // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" + %out = pto.vmi.group_broadcast %source {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto new file mode 100644 index 0000000000..8f9fb2c809 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_load_recipe_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 local recipe + // CHECK-SAME: block8 strided group_load requires constant positive row_stride divisible by 8 f32 elements + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_load" + %out = pto.vmi.group_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto new file mode 100644 index 0000000000..673f3ee47b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_recipe_invalid( + %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + !pto.vmi.mask<96xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<96xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto new file mode 100644 index 0000000000..f4071e4c47 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_slots1_recipe_invalid( + %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group size 64 + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto new file mode 100644 index 0000000000..31e7f13c3e --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_slot_load_recipe_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK-SAME: slots=8 group_slot_load requires constant unit source_group_stride + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto new file mode 100644 index 0000000000..b8576fe3b7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_slots2_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_group_reduce_slots2_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf local recipes currently require result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto new file mode 100644 index 0000000000..c7003a887d --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_recipe_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride + // CHECK: note: see current operation: "pto.vmi.group_store" + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto new file mode 100644 index 0000000000..53cc5c2a12 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_ensure_layout_shape_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_ensure_mask_layout_shape_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization recipe + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto new file mode 100644 index 0000000000..871e14eb5b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_helper_recipe_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %bad = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK-SAME: unsupported source/result layout pair diff --git a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto new file mode 100644 index 0000000000..3877eb1a3a --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_store_deint_tail_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory local recipe + // CHECK-SAME: requires arity divisible by layout factor + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_tile_write_deint_tail_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: memref<129xf32>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory local recipe + // CHECK-SAME: requires arity divisible by layout factor + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + memref<129xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto new file mode 100644 index 0000000000..68e7963b1b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_truncf_recipe_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered local recipe + // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" + %out = pto.vmi.truncf %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_rematerialize_data.pto b/test/lit/vmi/vmi_layout_rematerialize_data.pto new file mode 100644 index 0000000000..29faa34fb1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_data.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_data( + %scalar: f32, + %base: f32) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %broadcast_deint = pto.vmi.ensure_layout %broadcast + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %iota_deint = pto.vmi.ensure_layout %iota + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %constant_deint = pto.vmi.ensure_layout %constant + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + return %broadcast_deint, %iota_deint, %constant_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_data( +// CHECK: %[[BCAST:.*]] = pto.vmi.broadcast %arg0 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[IOTA:.*]] = pto.vmi.iota %arg1 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[CONST:.*]] = "pto.vmi.constant"(){{.*}}dense<1.000000e+00> : tensor<128xf32>{{.*}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_layout +// CHECK: return %[[BCAST]], %[[IOTA]], %[[CONST]] diff --git a/test/lit/vmi/vmi_layout_rematerialize_mask.pto b/test/lit/vmi/vmi_layout_rematerialize_mask.pto new file mode 100644 index 0000000000..6c3bb60053 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_mask.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_mask(%active: index) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_b16 = pto.vmi.ensure_mask_granularity %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %group_mask = pto.vmi.create_group_mask %active + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %group_mask_deint = pto.vmi.ensure_mask_layout %group_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %constant_mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %constant_mask_b16 = pto.vmi.ensure_mask_granularity %constant_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + + return %mask_b16, %mask_deint, %group_mask_deint, %constant_mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + 
!pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_mask( +// CHECK: %[[M16:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MDEINT:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[GMDEINT:.*]] = pto.vmi.create_group_mask %arg0{{.*}}group_size = 16{{.*}}num_groups = 8{{.*}}index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"(){{.*}}dense : tensor<128xi1>{{.*}}!pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity +// CHECK: return %[[M16]], %[[MDEINT]], %[[GMDEINT]], %[[CM16]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto new file mode 100644 index 0000000000..9db3fcb22b --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto @@ -0,0 +1,202 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_materialization_addf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_muli( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %prod = pto.vmi.muli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %prod : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_single_ensure_kept( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 
!pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %neg = pto.vmi.negf %src_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %neg : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary_integer( + %src: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %abs = pto.vmi.absi %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %abs : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_bitwise( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %and = pto.vmi.andi %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %and : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_shift( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : 
!pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %shift = pto.vmi.shli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %shift : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_not( + %src: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %not = pto.vmi.not %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %not : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_addf( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SUM:.*]] = pto.vmi.addf %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SUM_DEINT:.*]] = pto.vmi.ensure_layout %[[SUM]] +// CHECK-SAME: #pto.vmi.layout +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SUM_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_muli( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[PROD:.*]] = pto.vmi.muli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[PROD_DEINT:.*]] = pto.vmi.ensure_layout %[[PROD]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[PROD_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_single_ensure_kept( +// CHECK: %[[LHS_DEINT:.*]] = pto.vmi.ensure_layout %arg0 +// 
CHECK: %[[SUM2:.*]] = pto.vmi.addf %[[LHS_DEINT]], %arg1 +// CHECK: return %[[SUM2]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NEG:.*]] = pto.vmi.negf %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[NEG_DEINT:.*]] = pto.vmi.ensure_layout %[[NEG]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NEG_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary_integer( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[ABS:.*]] = pto.vmi.absi %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[ABS_DEINT:.*]] = pto.vmi.ensure_layout %[[ABS]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[ABS_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_bitwise( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[AND:.*]] = pto.vmi.andi %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[AND_DEINT:.*]] = pto.vmi.ensure_layout %[[AND]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[AND_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_shift( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SHIFT:.*]] = pto.vmi.shli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[SHIFT_DEINT:.*]] = pto.vmi.ensure_layout %[[SHIFT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SHIFT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_not( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NOT:.*]] = pto.vmi.not %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xi32, 
#pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[NOT_DEINT:.*]] = pto.vmi.ensure_layout %[[NOT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NOT_DEINT]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_mask.pto b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto new file mode 100644 index 0000000000..0effb48323 --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_mask_layout_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_mask_layout %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_mask_layout %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_and %lhs_deint, %rhs_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_granularity_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> { + %lhs_b16 = pto.vmi.ensure_mask_granularity %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %rhs_b16 = pto.vmi.ensure_mask_granularity %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %out = pto.vmi.mask_or %lhs_b16, %rhs_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb16, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_layout_unary( + %source: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %source_deint = pto.vmi.ensure_mask_layout %source + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_not %source_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, 
#pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_binary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_and %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[OUT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_granularity_binary( +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_or %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_B16:.*]] = pto.vmi.ensure_mask_granularity %[[OUT]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: return %[[OUT_B16]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_unary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_not %arg0 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[OUT_DEINT]] diff --git a/test/lit/vmi/vmi_legalize_arith_select.pto b/test/lit/vmi/vmi_legalize_arith_select.pto new file mode 100644 index 0000000000..0661b6764e --- /dev/null +++ b/test/lit/vmi/vmi_legalize_arith_select.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-legalize-arith-select -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_legalize_arith_select_vreg( + %cond: i1, + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_legalize_arith_select_mask( + %cond: i1, + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %selected : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_legalize_arith_select_vreg( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] + +// CHECK-LABEL: func.func @vmi_legalize_arith_select_mask( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.mask<128xb32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto index 8957bb1f40..e49dba60c3 100644 --- a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -22,6 +22,19 @@ module attributes {pto.target_arch = "a5"} { : !pto.vmi.vreg<128xf32>, 
!pto.ptr return } + + func.func @vmi_ptoas_cli_fold_consumers_pipeline( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) { + %x16 = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } } } @@ -34,6 +47,16 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast +// CHECK-LABEL: func.func @vmi_ptoas_cli_fold_consumers_pipeline +// CHECK: pto.vlds +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} +// CHECK: pto.vcvt {{.*}} {part = "ODD"} +// CHECK-NOT: pto.vintlv +// CHECK: pto.vstsx2 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + // ATTR-LABEL: func.func @vmi_ptoas_cli_pipeline // ATTR: pto.vecscope // ATTR: pto.vdup diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto new file mode 100644 index 0000000000..fa1a5524dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_deint_tail( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<129xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> + return %cast : !pto.vmi.vreg<129xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V2:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B2:.*]] = pto.vbitcast %[[V2]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: return %[[B0]], %[[B1]], %[[B2]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto new file mode 100644 index 0000000000..2d7b904af1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_footprint_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts +// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints +// CHECK-SAME: requires matching logical bit footprint in every physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto new file mode 100644 index 0000000000..49d728f73d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_group_slots_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts +// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints +// CHECK-SAME: does not support group_slots layouts diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index 5297123e5a..f78e4ef5f2 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -17,9 +17,9 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion +// CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe // CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: requires source and result to have the same physical arity diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 80f8c22469..ea3c73fac1 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1738,8 +1738,22 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.enableVerifier(); pm.addPass(pto::createPTOValidateVMIIRPass()); pm.addPass(pto::createVMILayoutAssignmentPass()); + pm.addPass(createCanonicalizerPass()); + 
pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutFoldConsumersPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutRematerializePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutSinkMaterializationPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILegalizeArithSelectPass()); pm.addPass(pto::createPTOValidateVMILayoutIRPass()); pm.addPass(pto::createVMIToVPTOPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); if (failed(applyConfiguredPassManagerCLOptions(pm, "VMI-to-VPTO pipeline"))) return failure(); From 76860281dea6941dcc906d9804e0d501a6d22808 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:36:42 +0800 Subject: [PATCH 18/31] Support multi-chunk VMI group reduce slots --- docs/designs/vmi-implementation-manual.md | 4 +- .../vmi-layout-assignment-implementation.md | 20 ++++--- .../vmi-layout-assignment-lowering-design.md | 7 +++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 2 +- lib/PTO/Transforms/VMILayoutAssignment.cpp | 5 +- lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 20 ++++--- lib/PTO/Transforms/VMIToVPTO.cpp | 55 +++++++++++++++---- ...mi_layout_assignment_group_reduce_s256.pto | 28 ++++++++++ ...ate_group_reduce_slots1_recipe_invalid.pto | 2 +- .../vmi/vmi_to_vpto_group_broadcast_deint.pto | 4 +- ...mi_to_vpto_group_reduce_s256_broadcast.pto | 44 +++++++++++++++ 11 files changed, 159 insertions(+), 32 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 04da993699..04d32aea51 100644 --- a/docs/designs/vmi-implementation-manual.md +++ 
b/docs/designs/vmi-implementation-manual.md @@ -3111,7 +3111,9 @@ pto.vmi.group_broadcast: result may be contiguous with full physical chunks result may also be deinterleaved when S is large enough that every physical result chunk stays inside one logical group, for example N=512, G=2, S=256, - L=64, deinterleaved=4 + L=64, deinterleaved=4. If the source is + #pto.vmi.layout, the source physical part is + selected by group id rather than by source chunk id. derived group size S must divide or be a multiple of L for canonical group-slot addressing if result is contiguous and S < L, each physical chunk contains multiple group diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 03f22ffd42..31cc618335 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -466,7 +466,9 @@ group_slots group_reduce_addf semantic recipes: S=8 vcgadd S=16 deinterleaved=2 vcgadd+vadd S=32 deinterleaved=4 vcgadd+vadd tree - S=64 contiguous slots=1 vcadd/vadd/vsel row-local reduction + S>=physical_chunk_lanes contiguous slots=1 vcadd/vadd/vsel row-local + reduction, with one physical result part per group. For f32 this covers + S=64, S=128, S=256, ... 
explicit-slots group_broadcast semantic recipes: slots=8/slots=1 vselr materialization to contiguous or supported @@ -1031,7 +1033,7 @@ Target local recipe matrix: load, recipe=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths - covers dense store users and S=64 row-local reduce input + covers dense store users and full-chunk row-local reduce input load, recipe=load_dintlv2: result layout deinterleaved=2, block_elems=1 @@ -1088,8 +1090,9 @@ group_reduce_addf, recipe=s32_reduce_block8_stride: produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=s64_reduce_row_local: - consumes contiguous f32 with group size 64 +group_reduce_addf, recipe=full_chunk_reduce_row_local: + consumes contiguous f32 with group size that is a multiple of one physical + chunk produces group_slots(G, slots=1) target lowering emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while preserving the same @@ -1143,9 +1146,10 @@ group_reduce_addf: #pto.vmi.layout, result #pto.vmi.layout; vmi-to-vpto lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed result block. - S=64 row-local assignment uses #pto.vmi.layout - and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit - slots=1 generic VCADD row-local path is registered and selected locally. + Full-chunk row-local assignment, including S=64 and S=256 f32 cases, uses + #pto.vmi.layout and has focused + layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic + VCADD row-local path is registered and selected locally. group_broadcast: explicit slots=8/1 source layouts select @@ -2089,7 +2093,7 @@ Current evidence for the case-catalog objective: 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 340/340 +5. 
the latest full VMI lit sweep passed: 342/342 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 4c13b07ef8..13588bce3b 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -498,6 +498,13 @@ group_reduce f32 S=64: input contiguous result group_slots(G, slots=1) +group_reduce f32 S=128/S=256/...: + input contiguous + result group_slots(G, slots=1) + lowering reduces each full physical chunk with vcadd, accumulates all chunks + in the same logical group with lane0 vadd, and writes one physical result + part per group + group_slot_load: result group_slots(G, slots=8) for packed slots result group_slots(G, slots=1) for row-local slots diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h index 7356be9e92..10cde1dc96 100644 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -88,7 +88,7 @@ enum class VMIGroupReduceAddFRecipeKind { S8Vcgadd, S16Deinterleaved2VcgaddVadd, S32Deinterleaved4VcgaddTree, - S64ContiguousVcaddRows, + ContiguousVcaddRows, }; struct VMIGroupReduceAddFRecipe { diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 85a57e4ac1..2ff9e50ae2 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -249,7 +249,10 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); if (groupSize == 32) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 64) + FailureOr lanesPerPart = + 
getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart) && groupSize >= *lanesPerPart && + groupSize % *lanesPerPart == 0) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } return getGroupSlotsLayout(numGroups); diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp index 7364084028..7cd5281353 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -719,21 +719,25 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return fail("requires matching non-empty source/mask physical arity"); if (resultLayout.getSlots() == 1) { - if (*groupSize != 64) + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *groupSize < *lanesPerPart || + *groupSize % *lanesPerPart != 0) return fail("stable group_reduce_addf slots=1 recipes support group " - "size 64"); + "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("s64 group_reduce_addf requires contiguous source/mask " + return fail("slots=1 group_reduce_addf requires contiguous source/mask " "layouts"); - if (*resultArity != *sourceArity) - return fail("s64 group_reduce_addf requires source/result physical " - "arity to match"); + if (*resultArity != numGroups) + return fail("slots=1 group_reduce_addf requires one physical result " + "part per group"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("s64 group_reduce_addf requires full source chunks; ") + + return fail(Twine("slots=1 group_reduce_addf requires full source " + "chunks; ") + sourceFullReason); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S64ContiguousVcaddRows}; + VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; } if (*groupSize == 8) { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp 
b/lib/PTO/Transforms/VMIToVPTO.cpp index a59b5dbadb..1c92be4018 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -5725,10 +5725,17 @@ struct OneToNVMIGroupReduceAddFOpPattern &lanesPerPart, &groupCount, &chunksPerGroup, rewriter))) return failure(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + bool rowLocalSlots1Result = + resultLayout && resultLayout.isGroupSlots() && + resultLayout.getNumGroups() == groupCount && + resultLayout.getSlots() == 1; + int64_t expectedResultParts = + rowLocalSlots1Result ? groupCount : groupCount * chunksPerGroup; if (sourceParts.size() != maskParts.size() || static_cast(sourceParts.size()) != groupCount * chunksPerGroup || - resultTypes.size() != sourceParts.size()) + static_cast(resultTypes.size()) != expectedResultParts) return rewriter.notifyMatchFailure( op, "group_reduce_addf requires matching source/mask/result arity"); @@ -5782,7 +5789,7 @@ struct OneToNVMIGroupReduceAddFOpPattern .getResult(); } - int64_t destChunk = group * chunksPerGroup; + int64_t destChunk = rowLocalSlots1Result ? 
group : group * chunksPerGroup; results[destChunk] = rewriter .create(op.getLoc(), resultType, *accumulator, @@ -5857,6 +5864,18 @@ struct OneToNVMIGroupBroadcastOpPattern resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart) selectionGroupSize = resultLayout.getBlockElems(); + auto resolveLargeGroupSource = [&](int64_t group, int64_t chunksPerGroup, + int64_t &sourceChunk, + int64_t &baseGroupSlot) { + int64_t slots = sourceLayout.getSlots(); + if (slots > 0) { + sourceChunk = group / slots; + baseGroupSlot = group % slots; + return; + } + sourceChunk = group * chunksPerGroup; + baseGroupSlot = 0; + }; SmallVector results; results.resize(resultTypes.size()); @@ -5871,7 +5890,8 @@ struct OneToNVMIGroupBroadcastOpPattern if (*groupSize >= lanesPerPart) { int64_t chunksPerGroup = *groupSize / lanesPerPart; int64_t group = flatIndex / chunksPerGroup; - sourceChunk = group * chunksPerGroup; + resolveLargeGroupSource(group, chunksPerGroup, sourceChunk, + baseGroupSlot); } else { VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); int64_t slots = sourceLayout.getSlots(); @@ -5953,7 +5973,8 @@ struct OneToNVMIGroupBroadcastOpPattern return rewriter.notifyMatchFailure( op, "group_broadcast result chunk crosses logical groups"); int64_t chunksPerGroup = *groupSize / lanesPerPart; - sourceChunk = firstGroup * chunksPerGroup; + resolveLargeGroupSource(firstGroup, chunksPerGroup, sourceChunk, + baseGroupSlot); found = true; break; } @@ -5968,12 +5989,26 @@ struct OneToNVMIGroupBroadcastOpPattern sourceChunk >= static_cast(sourceParts.size())) return rewriter.notifyMatchFailure( op, "group_broadcast source chunk is out of range"); - results[flatIndex] = - rewriter - .create(op.getLoc(), resultType, - sourceParts[sourceChunk], *allMask, - rewriter.getStringAttr("LOWEST")) - .getResult(); + if (sourceLayout.getSlots() > 1) { + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, selectionGroupSize, 
baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } else { + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) + .getResult(); + } } else { bool blockFragmentSmallGroup = resultLayout && resultLayout.isDeinterleaved() && diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto new file mode 100644 index 0000000000..15fba5a1de --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s256( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s256( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto index f4071e4c47..6e0b04e8f6 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group size 64 + // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto index 078b61b5bf..9c2aff3759 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto +++ 
b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_deint( - %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) -> !pto.vmi.vreg<512xf32> { %src_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32> %out = pto.vmi.mulf %sum_vec, %src_f32 : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto new file mode 100644 index 0000000000..f2681f3359 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s256_broadcast( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%broadcast) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s256_broadcast( +// CHECK: pto.vcadd +// CHECK: pto.vadd +// CHECK: pto.vsel +// CHECK: pto.vdup {{.*}} {position = "LOWEST"} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast From 58787c2f6d9341c413d58fbbfb06142c43356337 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 23:48:11 +0800 Subject: [PATCH 19/31] Implement typed VMI group reduce lowering --- docs/designs/vmi-implementation-manual.md | 4 +- .../vmi-layout-assignment-implementation.md | 168 ++++-- .../vmi-layout-assignment-lowering-design.md | 125 ++-- docs/designs/vmi-layout-lowering-cases.md | 552 ++++++++++++++++++ include/PTO/IR/VMIOps.td | 34 ++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 44 +- .../PTO/Transforms/VMITargetCapabilities.h | 23 +- lib/PTO/IR/VMI.cpp | 104 ++++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 158 ++++- lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 225 +++++-- lib/PTO/Transforms/VMIToVPTO.cpp | 429 ++++++++++++-- .../vmi/vmi_group_reduce_addi_i16_invalid.pto | 24 + .../vmi/vmi_group_reduce_addi_i8_invalid.pto | 24 + ...ut_assignment_group_reduce_s12_invalid.pto | 2 +- ...i_layout_assignment_group_reduce_typed.pto | 56 ++ ...ayout_gate_group_reduce_recipe_invalid.pto | 2 +- ...ate_group_reduce_slots1_recipe_invalid.pto | 2 +- ..._group_slots_unsupported_slots_invalid.pto | 2 +- .../vmi/vmi_to_vpto_group_reduce_typed.pto | 80 +++ .../vmi/vmi_to_vpto_integer_cast_reduce.pto | 44 ++ test/lit/vmi/vmi_to_vpto_integer_casts.pto | 64 ++ ...id.pto => vmi_to_vpto_reduce_addf_f16.pto} | 27 +- .../vmi_to_vpto_trunci_i8_signed_invalid.pto | 29 + .../group-reduce-f16-addf-store/compare.py | 37 ++ .../vmi/group-reduce-f16-addf-store/golden.py | 43 ++ .../group-reduce-f16-addf-store/kernel.pto | 51 ++ .../group-reduce-f16-addf-store/launch.cpp | 34 ++ .../vmi/group-reduce-f16-addf-store/main.cpp | 86 +++ .../group-reduce-f16-addf-store/ptoas.flags | 1 + .../compare.py | 37 ++ .../golden.py | 43 ++ .../kernel.pto | 52 ++ .../launch.cpp | 36 ++ .../main.cpp | 88 +++ .../ptoas.flags | 1 + .../group-reduce-i32-addi-store/compare.py | 37 ++ 
.../vmi/group-reduce-i32-addi-store/golden.py | 42 ++ .../group-reduce-i32-addi-store/kernel.pto | 51 ++ .../group-reduce-i32-addi-store/launch.cpp | 35 ++ .../vmi/group-reduce-i32-addi-store/main.cpp | 86 +++ .../group-reduce-i32-addi-store/ptoas.flags | 1 + .../compare.py | 37 ++ .../golden.py | 44 ++ .../kernel.pto | 52 ++ .../launch.cpp | 36 ++ .../main.cpp | 88 +++ .../ptoas.flags | 1 + 47 files changed, 3039 insertions(+), 202 deletions(-) create mode 100644 test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto create mode 100644 test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto create mode 100644 test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto create mode 100644 test/lit/vmi/vmi_to_vpto_integer_casts.pto rename test/lit/vmi/{vmi_to_vpto_reduce_addf_f16_invalid.pto => vmi_to_vpto_reduce_addf_f16.pto} (55%) create mode 100644 test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp create mode 100644 
test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 04d32aea51..497e951e73 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -3867,8 +3867,8 @@ Slice 4 完成条件: per-feature negative tests. 9. Same-family reduction ops reject unsupported direct-lowering shapes consistently. Covered by vmi_to_vpto_reduce_shape_invalid.pto together with the existing reduce add/min/max positive and - per-feature negative tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto and - vmi_to_vpto_reduce_addf_f16_invalid.pto. + per-feature tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto for narrow integer rejection and + vmi_to_vpto_reduce_addf_f16.pto for f16 floating-point reduction lowering. 10. Target-specific element contracts are checked before OneToN rewriting for direct VPTO ops. 
Covered by vmi_to_vpto_bf16_arith.pto, vmi_to_vpto_math_element_type_invalid.pto, vmi_to_vpto_cmp_select.pto, vmi_to_vpto_cmp_element_type_invalid.pto, diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 31cc618335..a6583a3d8b 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -98,7 +98,7 @@ pto-validate-vmi-layout-ir: fail before `vmi-to-vpto`. It also checks the first semantic local-recipe families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots - `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_addf`, + `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. @@ -294,7 +294,7 @@ Local-decision table for the current implementation: op local decision inputs group_load result layout, num_groups, row_stride, source type group_slot_load result group_slots layout and source_group_stride -group_reduce_addf source/mask/result layouts, num_groups, reassoc +group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths ensure_layout always carries source/result layouts instead of recipe @@ -329,8 +329,8 @@ error. Examples of forbidden recovery in `vmi-to-vpto`: ```text -group_reduce_addf cannot walk to a load/group_load producer to choose S=16 - parity versus block8. +group_reduce_add{f|i} cannot walk to a load/group_load producer to choose + two-vlane parity versus block8. group_store cannot inspect the group_reduce producer; it consumes only the assigned source layout and explicit stride. group_broadcast cannot inspect sibling users to decide whether to rematerialize. 
@@ -354,12 +354,17 @@ create_group_mask extf truncf +extsi +extui +trunci addf +addi mulf select broadcast group_reduce_addf +group_reduce_addi group_broadcast group_store @@ -368,6 +373,30 @@ ensure_mask_layout // internal ensure_mask_granularity // internal ``` +Type policy before lowering: + +```text +storage / memory boundary: + f8-like, i8, f16, i16, f32, i32 may appear as load/store element types when + the target memory instruction supports the physical width. + +cast boundary: + f8-like may appear as extf/truncf source or destination. + i8 may appear as extsi/extui/trunci source or destination. Signedness is an + op semantic, not a VMI type spelling. + Current VPTO lowering supports 32-bit integer narrowing to unsigned i8 + storage, matching the available VCVTII s32/u32 -> u8 forms; signed i8 + narrowing needs a separate target recipe. + +compute / accumulator: + floating compute baseline: f16/f32, with reassoc required for reductions + that lower through pair-wise VPTO reductions. + integer compute baseline: i32 for grouped reduction; i8/i16 storage must + first cast to i32 because integer reduction instructions widen narrow inputs. + f8/i8 are not baseline accumulator/compute types. Supporting direct 8-bit + compute requires a target capability entry and a separate recipe family. +``` + Important semantic split: ```text @@ -462,13 +491,20 @@ group_slots group_store semantic recipes: slots=8 unit-stride vsts slots=1 aligned lane-0 vsts per group -group_slots group_reduce_addf semantic recipes: - S=8 vcgadd - S=16 deinterleaved=2 vcgadd+vadd - S=32 deinterleaved=4 vcgadd+vadd tree - S>=physical_chunk_lanes contiguous slots=1 vcadd/vadd/vsel row-local - reduction, with one physical result part per group. For f32 this covers - S=64, S=128, S=256, ... +group_slots group_reduce_add{f|i} semantic recipes: + define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. + T is the accumulator/reduce element type after any required storage cast. 
+ f8 storage reduces through f32; i8 storage reduces through an explicit + signed/unsigned integer cast to an accumulator type such as i32. In the + baseline contract, f8/i8 are cast-boundary storage types rather than + accumulator/compute types. + S=VLaneElems contiguous vcgadd + S=2*VLaneElems deinterleaved=2 vcgadd+vadd + S=4*VLaneElems deinterleaved=4 vcgadd+vadd tree + S>=L && S%L==0 contiguous slots=1 vcadd/vadd/vsel row-local reduction, + with one physical result part per group. For 32-bit element types this covers + S=64, S=128, S=256, ...; for 16-bit element types this covers S=128, S=256, + ... explicit-slots group_broadcast semantic recipes: slots=8/slots=1 vselr materialization to contiguous or supported @@ -481,6 +517,14 @@ extf/truncf semantic recipes: deinterleaved=4 f32 -> contiguous f8-like group_slots(G, slots=1) f32 -> f16 +extsi/extui/trunci semantic recipes: + contiguous i8 -> deinterleaved=2 i16 through VCVTII.{s,u}82{s,u}16 #part + contiguous i8 -> deinterleaved=4 i32 through VCVTII.{s,u}82{s,u}32 #pp + deinterleaved=2 i16 -> contiguous i8 through VCVTII.*162*8 #part + deinterleaved=4 i32 -> contiguous ui8 through VCVTII.*322u8 #pp + packed group_slots integer width-changing cast is unsupported until a + slot-wise cast recipe is defined. 
+ bitcast semantic recipes: per-part vbitcast for contiguous/deinterleaved layouts when source/result layouts match, physical arity matches, and every physical chunk carries the @@ -651,13 +695,15 @@ buildCastRequests: exists buildGroupReduceRequests: - derive S = logical_lanes / num_groups - S=8 -> contiguous source, group_slots(G,8) result - S=16 -> deinterleaved=2/block_elems=1 or block_elems=8 source, - group_slots(G,8) result - S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, - group_slots(G,8) result - S=64 -> contiguous source, group_slots(G,1) result + derive E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems -> contiguous source, group_slots(G,8) result + S=2*VLaneElems -> deinterleaved=2/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=4*VLaneElems -> deinterleaved=4/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S>=L && S%L==0 -> contiguous source, group_slots(G,1) result + 8-bit storage must be cast to an accumulator type before this request builder other S -> diagnostic unless an explicit fallback recipe is enabled buildGroupMemoryRequests: @@ -866,10 +912,10 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous recipe -3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 recipe -3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 recipe -3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local recipe +3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous recipe +3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 recipe +3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 recipe +3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local recipe 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks 3.19.1 S=16 block_elems 
choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks @@ -1065,34 +1111,34 @@ group_load, recipe=group_load_contiguous_chunks: emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_addf, recipe=s8_reduce_contiguous: - consumes contiguous f32 with group size 8 +group_reduce_add{f|i}, recipe=one_vlane_reduce_contiguous: + consumes contiguous accumulator type T with group size VLaneElems(T) produces group_slots(G, slots=8) emits one vcgadd -group_reduce_addf, recipe=s16_reduce_parity: +group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: consumes deinterleaved=2, block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, recipe=s16_reduce_block8: +group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, recipe=s32_reduce_dintlv4: +group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=s32_reduce_block8_stride: +group_reduce_add{f|i}, recipe=four_vlane_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=full_chunk_reduce_row_local: - consumes contiguous f32 with group size that is a multiple of one physical - chunk +group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: + consumes contiguous accumulator type T with group size that is a multiple of + one physical chunk L(T) produces group_slots(G, slots=1) target lowering emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while 
preserving the same @@ -1150,6 +1196,9 @@ group_reduce_addf: #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic VCADD row-local path is registered and selected locally. + group_reduce_addi is implemented for i32 accumulator values. i8/i16 storage + must be widened explicitly before grouped reduction because narrow integer + reduction instructions widen their result. group_broadcast: explicit slots=8/1 source layouts select @@ -1199,18 +1248,69 @@ group_store: design target unless a strided packed-lane store recipe is made explicit. ``` +Current implementation contract for type-generic grouped reduction: + +```text +ODS/verifiers: + pto.vmi.group_reduce_addi is the integer counterpart to group_reduce_addf. + group_reduce_addi accepts i32 accumulator element types; i8/i16 direct + grouped reduction is rejected with a diagnostic that points users to + extsi/extui. + extsi/extui/trunci carry integer signedness across storage/accumulator + boundaries without overloading add semantics. + +Layout assignment: + compute VLaneElems and L from the accumulator/reduce element type: + VLaneElems = 32B / sizeof(accumulator T) + L = 256B / sizeof(accumulator T) + use the same S formula for f16/f32/i32 once the typed reduce op and target + capability say the type is legal. + route f8 storage through extf to f32 before group_reduce_addf. + route i8/i16 storage through extsi/extui to i32 before group_reduce_addi. + route integer narrowing to i8 through trunci; direct i8 compute remains + illegal unless the target capability registry exposes an explicit recipe. + diagnose direct f8/i8 compute use with a message that points at the offending + op and suggests inserting the explicit cast when the op is meant to consume + storage data. 
+ +Local recipe registry: + replace f32-shaped recipe keys with width-parametric recipe classes: + one_vlane_reduce + two_vlane_reduce_deinterleaved + four_vlane_reduce_deinterleaved + full_chunk_row_local_reduce + key legality on accumulator byte width, source/mask layout, result + group_slots layout, num_groups, and target instruction capability. + +VMI-to-VPTO: + lower group_reduce_addi through the same VCGADD/VADD skeleton used for + floating-point where the target supports the integer accumulator type. + materialize integer casts explicitly before reduction; direct i8 group reduce + and direct i16 group reduce must not silently become a widening reduction in + this pass. + keep VPTO lowering local: it consumes assigned layouts and registered local + recipes, but does not invent a new global layout plan. + +Tests: + cover f16 direct and i16-storage-to-i32 grouped reductions. + add i32 S=8/S=16/S=32/S=64 group-reduce cases. + add f8 storage -> extf -> f32 group_reduce_addf cases. + add i8/i16 storage -> extsi/extui -> i32 group_reduce_addi cases. + add invalid direct f8/i8/i16 grouped-reduce diagnostics. 
+``` + Examples: ```text -group_reduce_addf, recipe=s16_reduce_parity: +group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_addf, recipe=s16_reduce_block8: +group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_addf, recipe=s32_reduce_dintlv4: +group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 13588bce3b..82a84082c6 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -94,11 +94,18 @@ dense cast: f16 -> f32 -> store f32 -> f16 -> store f8 -> f32 -> compute -> f8 + f8 -> f32 accumulator -> group_reduce_addf + i8/i16 -> signed/unsigned integer cast to i32 accumulator + -> group_reduce_addi + f8/i8 appear as cast source or cast destination at compute boundaries + integer narrowing back to i8 is an explicit cast, not implicit arithmetic f16 -> f32 shared by dense store and S=16 reduce f32 shared by f8 store and S=32 reduce group reduce: - S=8, S=16, S=32, S=64 + 32-bit accumulator: S=8, S=16, S=32, S=64 + 16-bit accumulator: S=16, S=32, S=64, S=128 + 8-bit storage reduces only through an explicit accumulator cast reduce -> group_store reduce -> group_slot_load/elemwise -> group_store reduce -> group_broadcast -> elemwise -> reduce -> store @@ -205,6 +212,29 @@ diagnostic-only cases: Layout is a property of a layout-assigned VMI value, not a property inferred by the final lowering pattern. +Type policy: + +```text +storage boundary: + f8-like/i8/f16/i16/f32/i32 may appear in load/store values when the target + memory instruction supports the physical width. + +cast boundary: + f8-like participates through extf/truncf. 
+ i8 participates through extsi/extui/trunci. Signedness is carried by the + cast op semantics, not by a separate layout. + On the current VPTO target, 32-bit to 8-bit integer narrowing is only a + baseline recipe for unsigned i8 results because the available VCVTII forms + are s32/u32 -> u8. + +compute boundary: + baseline floating compute uses f16/f32. + baseline integer grouped reduction compute uses i32 accumulators. i8/i16 + storage must be widened first because integer reduction instructions widen + narrow inputs. + f8/i8 are not baseline accumulator/compute element types. +``` + ### 2.1 Dense Layouts ```text @@ -348,10 +378,37 @@ group_slot_load: rematerialized into two ops when different users require different result layouts; each clone is then locally deterministic. -group_reduce_addf: +group_reduce_add{f|i}: source/mask layout, result group_slots layout, num_groups, element type, and - reassoc decide S=8 contiguous vcgadd, S=16/S=32 deinterleaved vcgadd trees, - and S=64 row-local vcadd/vsel lowering. + the typed reduce semantics decide the local reduction recipe. The recipe is + not keyed by f32 shape names. It is derived from the element byte width. + Floating-point `group_reduce_addf` carries `reassoc`; integer + `group_reduce_addi` does not. + + VLaneElems = 32B / sizeof(T) + L = 256B / sizeof(T) + S = logical_lane_count / num_groups + + S == VLaneElems -> contiguous vcgadd, result slots=8 + S == 2 * VLaneElems -> deinterleaved=2 vcgadd tree, result slots=8 + S == 4 * VLaneElems -> deinterleaved=4 vcgadd tree, result slots=8 + S >= L && S % L == 0 -> contiguous row-local vcadd/vsel, result slots=1 + + Type support is controlled by the typed reduce op semantics and target + capability, not by separate per-type shape rules. Once a type is legal for a + reduce op, the same formula above selects its layout and local recipe. The + current checked-in implementation may lag this design target; that is staged + implementation status, not a design boundary. 
+ + The formula is applied to the accumulator/reduce element type, not + necessarily the storage element type. 8-bit floating-point storage first + casts to f32 for `group_reduce_addf`; 8-bit and 16-bit integer storage first + casts to a signed/unsigned i32 accumulator for + `group_reduce_addi`. In the baseline VMI contract, f8/i8 are storage and + cast-boundary types: they may be the source or destination of cast, load, and + store, but they are not accumulator/compute types for group reduce. Direct + 8-bit grouped reduction is illegal unless the target exposes an explicit + 8-bit compute recipe. group_broadcast: source group_slots layout, result dense layout, num_groups, and element type @@ -480,30 +537,15 @@ load: ### 5.3 Group Recipes From Cases ```text -group_reduce f32 S=8: - input contiguous - result group_slots(G, slots=8) - -group_reduce f32 S=16: - legal input layout A: deinterleaved=2, block_elems=1 - legal input layout B: deinterleaved=2, block_elems=8 - result group_slots(G, slots=8) - -group_reduce f32 S=32: - legal input layout A: deinterleaved=4, block_elems=1 - legal input layout B: deinterleaved=4, block_elems=8 - result group_slots(G, slots=8) - -group_reduce f32 S=64: - input contiguous - result group_slots(G, slots=1) - -group_reduce f32 S=128/S=256/...: - input contiguous - result group_slots(G, slots=1) - lowering reduces each full physical chunk with vcadd, accumulates all chunks - in the same logical group with lane0 vadd, and writes one physical result - part per group +group_reduce_add{f|i} typed shape classification: + define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. + S=VLaneElems uses contiguous input and group_slots(G, slots=8). + S=2*VLaneElems uses deinterleaved=2 input/mask and group_slots(G, slots=8). + S=4*VLaneElems uses deinterleaved=4 input/mask and group_slots(G, slots=8). 
+ S>=L && S%L==0 uses contiguous input/mask and group_slots(G, slots=1); + lowering reduces each full physical chunk, accumulates all chunks in the + same logical group through lane0, and writes one physical result part per + group. group_slot_load: result group_slots(G, slots=8) for packed slots @@ -585,21 +627,18 @@ truncf f32 -> f8: requests source deinterleaved=4, block_elems=1 requests result contiguous f8 -group_reduce S=8: - requests source contiguous - requests result group_slots(num_groups, slots=8) - -group_reduce S=16: - requests source deinterleaved=2, block_elems=1 or block_elems=8 - requests result group_slots(num_groups, slots=8) - -group_reduce S=32: - requests source deinterleaved=4, block_elems=1 or block_elems=8 - requests result group_slots(num_groups, slots=8) - -group_reduce S=64: - requests source contiguous - requests result group_slots(num_groups, slots=1) +group_reduce_add{f|i}: + computes E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems requests source contiguous and result group_slots(G, slots=8) + S=2*VLaneElems requests source deinterleaved=2 and result + group_slots(G, slots=8) + S=4*VLaneElems requests source deinterleaved=4 and result + group_slots(G, slots=8) + S>=L && S%L==0 requests source contiguous and result + group_slots(G, slots=1) + 8-bit storage reaches this request only after an explicit cast to the + accumulator type group_broadcast: requests source group_slots(num_groups, slots=K) diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e084ad58c0..e17c14844b 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -5467,3 +5467,555 @@ store itself can locally prove the same contiguous memory effect from the source layout. vmi-to-vpto must not scan the `%w` producer or both store users to decide this. 
``` + +### 3.47 Type-Parametric Group Reduce Rule + +The group-reduce layout rule is parameterized by the element width, not by f32 +case names. + +```text +E = sizeof(T) +VLaneElems = 32B / E +L = 256B / E +S = logical_lane_count / num_groups +``` + +The canonical grouped-reduce layouts are: + +```text +S == VLaneElems: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=8) + +S == 2 * VLaneElems: + source/mask layout = deinterleaved=2 + result layout = group_slots(num_groups=G, slots=8) + +S == 4 * VLaneElems: + source/mask layout = deinterleaved=4 + result layout = group_slots(num_groups=G, slots=8) + +S >= L && S % L == 0: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=1) +``` + +Concrete shape table: + +```text +T VLaneElems L packed cases row-local cases +f32 8 64 S=8, S=16, S=32 S=64, S=128, ... +i32 8 64 S=8, S=16, S=32 S=64, S=128, ... +f16 16 128 S=16, S=32, S=64 S=128, S=256, ... +i16 16 128 S=16, S=32, S=64 S=128, S=256, ... +f8 32 256 cast to f32 before grouped reduce +i8 32 256 cast to i16/i32 before grouped reduce +``` + +These non-f32 cases are part of the type-generic layout/lowering design. If a +typed reduce op admits the element type and the target capability registry +accepts it, assignment must use the same `VLaneElems/L/S` formula instead of +adding per-type shape special cases. Any f32-only behavior in the current +implementation is staged implementation status, not the intended design limit. +For the current baseline, `f8/i8` are storage and cast-boundary types: they are +valid as load/store element types and as cast source/destination, but compute +ops such as group reduce consume the post-cast accumulator type. + +### 3.48 16-bit Typed Group Reduce, `S = VLaneElems = 16` + +This case covers both `f16` and `i16`. The element width is the same, so the +layout and VPTO instruction skeleton are identical. 
The VMI op name carries the +semantic difference: + +```text +f16: pto.vmi.group_reduce_addf ... {reassoc} +i16 storage: pto.vmi.extsi/extui ... -> i32 group_reduce_addi ... +``` + +VMI-shaped input: + +```text +// Floating form. +%xf = pto.vmi.load %base_f16[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%mf = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumf = pto.vmi.group_reduce_addf %xf, %mf {num_groups = 8, reassoc} +pto.vmi.group_store %sumf, %out_f16[%group_off], %c1 {num_groups = 8} + +// Integer form. +%xi = pto.vmi.load %base_i16[%off] + : memref<128xi16> -> !pto.vmi.vreg<128xi16> +%mi = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumi = pto.vmi.group_reduce_addi %xi, %mi {num_groups = 8} +pto.vmi.group_store %sumi, %out_i16[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%xf, %mf, %xi, %mi: + #pto.vmi.layout + +%sumf: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%sumi: + !pto.vmi.vreg<128xi16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%sum0 = pto.vcgadd %x0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 16 + 0 .. 15]) +``` + +### 3.49 16-bit Typed Group Reduce, `S = 2 * VLaneElems = 32` + +This case covers both `f16` and `i16`. Each logical row is 64B and must be +split into two 32B VLane fragments before `vcgadd`. 
+ +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xT16> -> !pto.vmi.vreg<256xT16> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B16" + : !pto.ptr, index -> !pto.vreg<128xT16>, !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum0 = pto.vadd %s0, %s1, %slot8_b16 + : !pto.vreg<128xT16>, !pto.vreg<128xT16>, !pto.mask + -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 32 + 0 .. 31]) +``` + +### 3.50 16-bit Typed Group Reduce, `S = 4 * VLaneElems = 64` + +This is the four-fragment packed case for both `f16` and `i16`. 
+ +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<512xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4 input + : four !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 : !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 : !pto.vreg<128xT16> +%s2 = pto.vcgadd %x_p2, %all_b16 : !pto.vreg<128xT16> +%s3 = pto.vcgadd %x_p3, %all_b16 : !pto.vreg<128xT16> + +%s01 = pto.vadd %s0, %s1, %slot8_b16 : !pto.vreg<128xT16> +%s23 = pto.vadd %s2, %s3, %slot8_b16 : !pto.vreg<128xT16> +%sum0 = pto.vadd %s01, %s23, %slot8_b16 : !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 64 + 0 .. 63]) +``` + +### 3.51 16-bit Typed Group Reduce, `S = L = 128` + +This is the first row-local full-physical-chunk case for both `f16` and `i16`. +The canonical result is row-local `slots = 1`, not packed `slots = 8`. + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<1024xT16> -> !pto.vmi.vreg<1024xT16> +%mask = pto.vmi.create_group_mask %c128 {num_groups = 8, group_size = 128} + : index -> !pto.vmi.mask<1024xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<1024xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" +%slot1_b16 = pto.pge_b16 "PAT_VL1" + +// Repeated for r = 0..7. 
+%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> +%partial_r = pto.vcgadd %x_r, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum_r = pto.vcadd %partial_r, %slot8_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum_r, %out[%group_off_plus_r], %slot1_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 128 + 0 .. 127]) +``` + +### 3.52 32-bit Typed Group Reduce + +This case covers both `f32` and `i32`. The element width is the same, so +`VLaneElems = 8` and `L = 64` for both. Floating-point uses +`group_reduce_addf` with `reassoc`; integer uses `group_reduce_addi`. + +Example for `S = 2 * VLaneElems = 16`: + +```text +%x: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xT32>, !pto.vreg<64xT32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%sum0 = pto.vadd %s0, %s1, %slot8_b32 + : !pto.vreg<64xT32>, !pto.vreg<64xT32>, !pto.mask + -> !pto.vreg<64xT32> + +pto.vsts %sum0, %out[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xT32>, !pto.ptr, !pto.mask +``` + +The same formula gives: + +```text +S=8: + contiguous, slots=8, one vcgadd. + +S=32: + deinterleaved=4, slots=8, four vcgadd plus vadd tree. + +S=64: + contiguous, slots=1, row-local vcgadd plus vcadd. + +S=128: + contiguous, slots=1, row-local multi-chunk accumulation. 
+``` + +### 3.53 Integer Semantics And Invalid Typed Reductions + +Integer group reduction is not a variant of `group_reduce_addf`; it requires a +typed integer op: + +```text +%sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = G} +``` + +Required semantics: + +```text +inactive lanes contribute integer zero +addition uses the target's normal integer add behavior +wrap/saturating variants must be represented by distinct ops if both are needed +signedness does not affect add, but does affect future max/min integer reduces +``` + +Required invalid cases: + +```text +pto.vmi.group_reduce_addf with integer element type -> verifier error +pto.vmi.group_reduce_addi with floating-point element type -> verifier error +pto.vmi.group_reduce_addi i8 -> invalid direct 8-bit accumulator reduce; + cast to i16/i32 first unless target exposes i8 vcgadd +S not in {VLaneElems, 2*VLaneElems, 4*VLaneElems} and not a full-chunk multiple + -> layout-contract diagnostic +``` + +### 3.54 8-bit Floating Group Reduce + +There is no direct f8 `vcgadd` grouped reduction in the current target model, +but f8 supports cast to an accumulator type. The semantic path is: + +```text +f8 storage -> cast/extf to f32 accumulator -> group_reduce_addf on f32 +``` + +Here `f8` is only the cast source and the memory element type. The reduction +itself is an f32 accumulator operation. + +The group size remains a logical-lane property. For example, reducing eight +rows of 32 f8 elements produces the same logical result as reducing eight rows +of 32 f32 accumulator elements after extension. 
+ +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_f8[%off] + : memref<256xf8> -> !pto.vmi.vreg<256xf8> +%x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %out_f32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x8_packed = pto.vlds %base_f8[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%all_b8 = pto.pge_b8 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%x32_p0 = pto.vcvt %x8_packed, %all_b8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_packed, %all_b8 {part = "P1"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_packed, %all_b8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_packed, %all_b8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xf32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xf32> + +pto.vsts %sum0, %out_f32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_f32[group_off + r] = + reduce_f32(f32(base_f8[off + r * 32 + 0 .. 
31])) +``` + +Direct f8 grouped reduction is invalid: + +```text +pto.vmi.group_reduce_addf %x8, %mask + : !pto.vmi.vreg<256xf8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +### 3.55 8-bit Integer Group Reduce + +The current target model has no i8 `vcgadd`. It does have widening `vcadd` for +full-vector reductions, but grouped reduction needs one partial result per +32B VLane. Since 8-bit integers support cast to wider integer types, the +baseline grouped path casts before reducing: + +```text +i8/i16 storage -> signed/unsigned cast to i32 accumulator + -> group_reduce_addi on the accumulator type +``` + +Here `i8`/`i16` are only cast sources and memory element types. The reduction +itself is an i32 accumulator operation, with signedness handled by the cast. + +The integer cast operation must carry signedness. This document uses +`extsi/extui` as the widening spelling and `trunci` as the narrowing spelling: + +```text +%x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg<Nxi8> -> !pto.vmi.vreg<Nxi32> +%x32 = pto.vmi.extui %x8 : !pto.vmi.vreg<Nxui8> -> !pto.vmi.vreg<Nxui32> +%x8 = pto.vmi.trunci %x32 : !pto.vmi.vreg<Nxi32> -> !pto.vmi.vreg<Nxui8> +``` + +The last form is unsigned i8 on the current VPTO target: VISA exposes +VCVTII.s322u8/u322u8 for 32-bit to 8-bit narrowing, not a signed-i8 +destination form. 
+ +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_i8[%off] + : memref<256xi8> -> !pto.vmi.vreg<256xi8> +%x32 = pto.vmi.extsi %x8 + : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out_i32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xi8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> +``` + +VPTO lowering shape after integer cast materialization: + +```text +%x32_p0, %x32_p1, %x32_p2, %x32_p3 = + materialize signed cast i8 -> i32 with deinterleaved=4 layout + : four !pto.vreg<64xi32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xi32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xi32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xi32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xi32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xi32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xi32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xi32> + +pto.vsts %sum0, %out_i32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_i32[group_off + r] = + reduce_i32(sign_extend(base_i8[off + r * 32 + 0 .. 
31])) +``` + +Direct i8 grouped reduction without the cast is invalid: + +```text +pto.vmi.group_reduce_addi %x8, %mask + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +An optimized row-local i8 full-chunk recipe may be added later for +`S = 256` by using widening `vcadd`, but that requires a widening +`group_slots` result contract and must not change the baseline cast-to-accumulator +semantics above. + +If the final memory result is i8, narrowing is a separate cast after the +accumulator computation: + +```text +%sum32 = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +%sum8 = pto.vmi.trunci %sum32 +pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} +``` + +That packed group-slot `trunci` path is not a baseline recipe yet; the +implementation must either define a slot-wise VCVTII recipe or diagnose at +layout assignment. diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 80036f9946..d14b6fe8ee 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -418,6 +418,16 @@ def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { + let summary = "VMI masked integer add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { let summary = "VMI broadcast group-slot values back to each logical group"; let arguments = (ins VMI_VRegTypeConstraint:$source, @@ -443,6 +453,30 @@ def VMITruncFOp : VMI_Op<"truncf"> { let assemblyFormat = "$source attr-dict `:` type($source) 
`->` type($result)"; } +def VMIExtSIOp : VMI_Op<"extsi"> { + let summary = "VMI signed integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExtUIOp : VMI_Op<"extui"> { + let summary = "VMI unsigned integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncIOp : VMI_Op<"trunci"> { + let summary = "VMI saturating integer elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + def VMIBitcastOp : VMI_Op<"bitcast"> { let summary = "VMI bitwise vector reinterpretation"; let arguments = (ins VMI_VRegTypeConstraint:$source); diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h index 10cde1dc96..8472a32c4c 100644 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -85,14 +85,15 @@ struct VMIGroupSlotsStoreRecipe { }; enum class VMIGroupReduceAddFRecipeKind { - S8Vcgadd, - S16Deinterleaved2VcgaddVadd, - S32Deinterleaved4VcgaddTree, + OneVLaneVcgadd, + TwoVLaneDeinterleaved2VcgaddVadd, + FourVLaneDeinterleaved4VcgaddTree, ContiguousVcaddRows, }; struct VMIGroupReduceAddFRecipe { - VMIGroupReduceAddFRecipeKind kind = VMIGroupReduceAddFRecipeKind::S8Vcgadd; + VMIGroupReduceAddFRecipeKind kind = + VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd; }; enum class VMIGroupBroadcastRecipeKind { @@ -125,6 +126,27 @@ struct VMIExtFRecipe { 
VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; }; +enum class VMITruncIRecipeKind { + Deinterleaved2I32ToContiguousI16, + Deinterleaved4I32ToContiguousI8, + GroupSlots1I32ToI16, +}; + +struct VMITruncIRecipe { + VMITruncIRecipeKind kind = + VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16; +}; + +enum class VMIExtIRecipeKind { + ContiguousI16ToDeinterleaved2I32, + ContiguousI8ToDeinterleaved4I32, +}; + +struct VMIExtIRecipe { + VMIExtIRecipeKind kind = + VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32; +}; + enum class VMIBitcastRecipeKind { PerPartVbitcast, }; @@ -178,6 +200,11 @@ class VMILocalRecipeRegistry { VMIGroupReduceAddFOp op, std::string *reason = nullptr) const; + FailureOr + getGroupReduceAddIRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + FailureOr getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, @@ -189,6 +216,15 @@ class VMILocalRecipeRegistry { FailureOr getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; + FailureOr + getExtSIRecipe(VMIExtSIOp op, std::string *reason = nullptr) const; + + FailureOr + getExtUIRecipe(VMIExtUIOp op, std::string *reason = nullptr) const; + + FailureOr + getTruncIRecipe(VMITruncIOp op, std::string *reason = nullptr) const; + FailureOr getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; }; diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index 15b4f19f1d..a96a73a6d0 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -42,6 +42,8 @@ enum class VMIElementPurpose { enum class VMIReductionKind { AddI, AddF, + GroupAddI, + GroupAddF, MaxF, MinF, }; @@ -229,11 +231,26 @@ class VMITargetCapabilityRegistry { "currently supports only 32-bit integer elements because narrow " "vcadd widens its result"); case VMIReductionKind::AddF: - if 
(elementType.isF32()) + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements for floating-point " + "reduction"); + case VMIReductionKind::GroupAddI: { + auto intType = dyn_cast(elementType); + if (intType && intType.getWidth() == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "grouped integer add reduction supports only i32 accumulator " + "elements because narrow integer reductions widen their result; " + "cast i8/i16 storage before grouped reduction"); + } + case VMIReductionKind::GroupAddF: + if (elementType.isF16() || elementType.isF32()) return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( - "currently supports only f32 elements; f16 requires an explicit " - "accumulator precision and rounding contract"); + "grouped floating-point add reduction supports f16/f32 accumulator " + "elements"); case VMIReductionKind::MaxF: case VMIReductionKind::MinF: if (elementType.isF16() || elementType.isF32()) diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index ff7170044e..b504de67f5 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -61,6 +61,16 @@ static bool isVMIIntegerLikeType(Type type) { return isa(type); } +static bool isVMISignedOrSignlessIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && !integerType.isUnsigned(); +} + +static bool isVMIUnsignedIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && integerType.isUnsigned(); +} + static bool isVMIIotaElementType(Type type) { if (auto intType = dyn_cast(type)) return intType.getWidth() == 8 || intType.getWidth() == 16 || @@ -1154,6 +1164,50 @@ LogicalResult VMIGroupReduceAddFOp::verify() { getNumGroupsAttr().getInt()); } +LogicalResult VMIGroupReduceAddIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto 
maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI source element type"); + auto intType = dyn_cast(sourceType.getElementType()); + if (!intType || intType.getWidth() != 32) + return emitOpError( + "requires i32 accumulator element type; cast i8/i16 storage to i32 " + "before grouped reduction because integer reduction widens narrow " + "inputs"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) + return emitOpError( + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + LogicalResult VMIGroupBroadcastOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); @@ -1212,6 +1266,56 @@ 
LogicalResult VMITruncFOp::verify() { return success(); } +LogicalResult VMIExtSIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMISignedOrSignlessIntegerType(sourceType.getElementType()) || + !isVMISignedOrSignlessIntegerType(resultType.getElementType())) + return emitOpError( + "requires signed or signless integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMIExtUIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIUnsignedIntegerType(sourceType.getElementType()) || + !isVMIUnsignedIntegerType(resultType.getElementType())) + return emitOpError( + "requires unsigned integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMITruncIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIIntegerLikeType(sourceType.getElementType()) || + !isVMIIntegerLikeType(resultType.getElementType())) + return emitOpError("requires integer source and 
result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) <= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be narrower than source element type"); + return success(); +} + LogicalResult VMIBitcastOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 2ff9e50ae2..5f30ba82e0 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -236,6 +236,13 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups); } + std::optional getVLaneElems(Type elementType) { + FailureOr lanesPerPart = getDataLanesPerPart(elementType); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return std::nullopt; + return *lanesPerPart / 8; + } + VMILayoutAttr getPreferredGroupSlotsLayout(VMIVRegType type, int64_t numGroups) { if (VMILayoutAttr existing = type.getLayoutAttr()) @@ -243,11 +250,10 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 8) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 16) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 32) + std::optional vlaneElems = getVLaneElems(type.getElementType()); + if (vlaneElems && (groupSize == *vlaneElems || + groupSize == 2 * *vlaneElems || + groupSize == 4 * *vlaneElems)) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); @@ -264,9 +270,10 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 16) + 
std::optional vlaneElems = getVLaneElems(type.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems) return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); - if (groupSize == 32) + if (vlaneElems && groupSize == 4 * *vlaneElems) return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); } return getContiguousLayout(); @@ -279,7 +286,9 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 64) + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart) && groupSize == *lanesPerPart) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); @@ -349,7 +358,8 @@ struct LayoutSolver { if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && solved.getSlots() > 0) return solved; - if (value.getDefiningOp()) + if (value.getDefiningOp() || + value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); if (value.getDefiningOp()) return getPreferredGroupSlotLoadLayout(type, numGroups); @@ -387,9 +397,10 @@ struct LayoutSolver { if (!resultType) continue; unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (groupSize == 16 && resultBits == 16) + std::optional vlaneElems = getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && resultBits == 16) return true; - if (groupSize == 32 && resultBits == 8) + if (vlaneElems && groupSize == 4 * *vlaneElems && resultBits == 8) return true; } return false; @@ -810,12 +821,16 @@ struct LayoutSolver { if (solvedSourceLayout && numGroups > 0 && sourceType.getElementCount() % numGroups == 0) { int64_t groupSize = sourceType.getElementCount() / numGroups; - if (groupSize == 16 && solvedSourceLayout.isDeinterleaved() && + std::optional vlaneElems = + 
getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && solvedSourceLayout.getFactor() == 2 && (solvedSourceLayout.getBlockElems() == 1 || solvedSourceLayout.getBlockElems() == 8)) sourceLayout = solvedSourceLayout; - if (groupSize == 32 && solvedSourceLayout.isDeinterleaved() && + if (vlaneElems && groupSize == 4 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && solvedSourceLayout.getFactor() == 4 && (solvedSourceLayout.getBlockElems() == 1 || solvedSourceLayout.getBlockElems() == 8)) @@ -825,10 +840,12 @@ struct LayoutSolver { int64_t groupSize = sourceType.getElementCount() / numGroups; if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), groupSize)) { - if (groupSize == 16) + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems) sourceLayout = VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); - if (groupSize == 32) + if (vlaneElems && groupSize == 4 * *vlaneElems) sourceLayout = VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); } @@ -846,6 +863,45 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( + sourceType, reduce.getNumGroupsAttr().getInt()); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + if (solvedSourceLayout && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && + 
solvedSourceLayout.getFactor() == 2 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + if (vlaneElems && groupSize == 4 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 4 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + getPreferredGroupSlotsLayout( + resultType, reduce.getNumGroupsAttr().getInt()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto broadcast = dyn_cast(op)) { auto sourceType = cast(broadcast.getSource().getType()); requestDataUse(broadcast.getSourceMutable(), @@ -873,6 +929,46 @@ struct LayoutSolver { } return WalkResult::advance(); } + if (auto extsi = dyn_cast(op)) { + auto sourceType = cast(extsi.getSource().getType()); + auto resultType = cast(extsi.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extsi.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extsi.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto extui = dyn_cast(op)) { + auto sourceType = 
cast(extui.getSource().getType()); + auto resultType = cast(extui.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extui.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extui.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extui.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extui.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } if (auto truncf = dyn_cast(op)) { auto sourceType = cast(truncf.getSource().getType()); auto resultType = cast(truncf.getResult().getType()); @@ -897,6 +993,30 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto trunci = dyn_cast(op)) { + auto sourceType = cast(trunci.getSource().getType()); + auto resultType = cast(trunci.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); + if (sourceBits == 32 && resultBits == 16 && sourceLayout && + sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + requestDataUse(trunci.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (sourceBits == 32 && resultBits == 16) + requestDataUse(trunci.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 2)); + else if (sourceBits == 32 && resultBits == 8) + requestDataUse(trunci.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 4)); + if 
(failed(setNaturalLayout(trunci.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto bitcast = dyn_cast(op)) { if (failed(unite(bitcast.getSource(), bitcast.getResult(), op))) return WalkResult::interrupt(); @@ -1463,6 +1583,14 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed(requestMaskUse( diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp index 7cd5281353..34b843737c 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -657,9 +657,12 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( } FailureOr -VMILocalRecipeRegistry::getGroupReduceAddFRecipe( - const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, - std::string *reason) const { +getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, + Operation *op, VMIVRegType sourceType, + VMIMaskType maskType, VMIVRegType resultType, + int64_t numGroups, bool requiresReassoc, + VMIReductionKind reductionKind, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -667,17 +670,13 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return failure(); }; - if (!op->hasAttr("reassoc")) + if (requiresReassoc && !op->hasAttr("reassoc")) return fail("requires reassoc attr for pair-wise floating-point " "reduction"); - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto 
resultType = cast(op.getResult().getType()); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); if (!sourceLayout || !maskLayout || !resultLayout) return fail("requires assigned source, mask, and result layouts"); if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups) @@ -685,23 +684,28 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( if (resultLayout.getSlots() != 8 && resultLayout.getSlots() != 1) { FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, numGroups, reason); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + int64_t vlaneElems = + succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 ? *lanesPerPart / 8 + : -1; if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && - *groupSize != 8 && *groupSize != 16 && *groupSize != 32) - return fail("stable group_reduce_addf slots=8 recipes support group " - "size 8, 16, or 32"); - return fail("stable group_reduce_addf local recipes currently require " + (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && + *groupSize != 4 * vlaneElems)) + return fail("stable group_reduce_add slots=8 recipes support group " + "sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems"); + return fail("stable group_reduce_add local recipes currently require " "result layout slots=8 or slots=1"); } VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType(VMIReductionKind::AddF, + capabilities.supportsReductionElementType(reductionKind, sourceType.getElementType()); if (!elementCapability.isSupported()) return fail(elementCapability.reason); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("stable group_reduce_addf local recipes require f32 " - "source/result"); + if (sourceType.getElementType() != 
resultType.getElementType()) + return fail("stable group_reduce_add local recipes require matching " + "source/result element types"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -709,6 +713,11 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( getGroupSizeFromNumGroups(sourceType, numGroups, reason); if (failed(groupSize)) return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + int64_t vlaneElems = *lanesPerPart / 8; FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); @@ -719,81 +728,105 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return fail("requires matching non-empty source/mask physical arity"); if (resultLayout.getSlots() == 1) { - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); if (failed(lanesPerPart) || *groupSize < *lanesPerPart || *groupSize % *lanesPerPart != 0) - return fail("stable group_reduce_addf slots=1 recipes support group " + return fail("stable group_reduce_add slots=1 recipes support group " "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("slots=1 group_reduce_addf requires contiguous source/mask " + return fail("slots=1 group_reduce_add requires contiguous source/mask " "layouts"); if (*resultArity != numGroups) - return fail("slots=1 group_reduce_addf requires one physical result " + return fail("slots=1 group_reduce_add requires one physical result " "part per group"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("slots=1 group_reduce_addf requires full source " + return fail(Twine("slots=1 group_reduce_add requires full source " "chunks; 
") + sourceFullReason); return VMIGroupReduceAddFRecipe{ VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; } - if (*groupSize == 8) { + if (*groupSize == vlaneElems) { if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("s8 group_reduce_addf requires contiguous source/mask " + return fail("one-vlane group_reduce_add requires contiguous source/mask " "layouts"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("s8 group_reduce_addf requires full source chunks; ") + + return fail(Twine("one-vlane group_reduce_add requires full source " + "chunks; ") + sourceFullReason); if (*resultArity != *sourceArity) - return fail("s8 group_reduce_addf requires source/result physical " + return fail("one-vlane group_reduce_add requires source/result physical " "arity to match"); - return VMIGroupReduceAddFRecipe{VMIGroupReduceAddFRecipeKind::S8Vcgadd}; + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd}; } - if (*groupSize == 16) { + if (*groupSize == 2 * vlaneElems) { if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s16 group_reduce_addf requires source layout " + return fail("two-vlane group_reduce_add requires source layout " "deinterleaved=2 with block_elems=1 or block_elems=8"); if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s16 group_reduce_addf requires matching mask layout " + return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2) - return fail("s16 group_reduce_addf requires two source/mask parts per " + return fail("two-vlane 
group_reduce_add requires two source/mask parts per " "result part"); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S16Deinterleaved2VcgaddVadd}; + VMIGroupReduceAddFRecipeKind::TwoVLaneDeinterleaved2VcgaddVadd}; } - if (*groupSize == 32) { + if (*groupSize == 4 * vlaneElems) { if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s32 group_reduce_addf requires source layout " + return fail("four-vlane group_reduce_add requires source layout " "deinterleaved=4 with block_elems=1 or block_elems=8"); if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s32 group_reduce_addf requires matching mask layout " + return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4) - return fail("s32 group_reduce_addf requires four source/mask parts per " + return fail("four-vlane group_reduce_add requires four source/mask parts per " "result part"); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S32Deinterleaved4VcgaddTree}; + VMIGroupReduceAddFRecipeKind::FourVLaneDeinterleaved4VcgaddTree}; } - return fail("stable group_reduce_addf slots=8 recipes support group size " - "8, 16, or 32"); + return fail("stable group_reduce_add slots=8 recipes support group sizes " + "VLaneElems, 2*VLaneElems, or 4*VLaneElems"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddFRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, + std::string *reason) const { + return getGroupReduceAddRecipeImpl( + capabilities, op.getOperation(), cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + 
op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/true, + VMIReductionKind::GroupAddF, reason); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddIRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, + std::string *reason) const { + return getGroupReduceAddRecipeImpl( + capabilities, op.getOperation(), cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupAddI, reason); } FailureOr @@ -964,6 +997,118 @@ VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, "physical arity"); } +template +static FailureOr getExtIRecipeImpl(OpT op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + !isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires contiguous integer source layout and deinterleaved " + "integer result layout"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return VMIExtIRecipe{ + VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32}; + 
if (sourceBits == 8 && resultBits == 32 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return VMIExtIRecipe{ + VMIExtIRecipeKind::ContiguousI8ToDeinterleaved4I32}; + + return fail("unsupported integer extension source/result element width, " + "result factor, or physical arity"); +} + +FailureOr +VMILocalRecipeRegistry::getExtSIRecipe(VMIExtSIOp op, + std::string *reason) const { + return getExtIRecipeImpl(op, reason); +} + +FailureOr +VMILocalRecipeRegistry::getExtUIRecipe(VMIExtUIOp op, + std::string *reason) const { + return getExtIRecipeImpl(op, reason); +} + +FailureOr +VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires integer source and result element types"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + sourceBits != 32 || resultBits != 16 || *sourceArity != 
*resultArity) + return fail("group-slot trunci requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "32-bit integer source, 16-bit integer result, and matching " + "physical arity"); + return VMITruncIRecipe{VMITruncIRecipeKind::GroupSlots1I32ToI16}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + sourceBits != 32 || *resultArity != 1) + return fail("requires 32-bit integer deinterleaved source and contiguous " + "integer result"); + + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return VMITruncIRecipe{ + VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16}; + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8 && + cast(resultType.getElementType()).isUnsigned()) + return VMITruncIRecipe{ + VMITruncIRecipeKind::Deinterleaved4I32ToContiguousI8}; + + return fail("unsupported deinterleaved trunci factor, arity, result element " + "width, or result signedness; 32-bit to 8-bit integer narrowing " + "requires unsigned i8 result"); +} + FailureOr VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, std::string *reason) const { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 1c92be4018..c44fc114ec 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/raw_ostream.h" #include +#include namespace mlir { namespace pto { @@ -2434,12 +2435,17 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, return failure(); }; - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("vcgadd group_reduce_addf path requires f32 source/result"); - if (groupSize != 8) - return fail("vcgadd group_reduce_addf path requires group size = 8 for " - "f32 32-byte VLane groups"); + if (sourceType.getElementType() != resultType.getElementType()) + return 
fail("vcgadd group_reduce_add path requires matching " + "source/result element types"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("vcgadd group_reduce_add path requires known VLane width"); + int64_t vlaneElems = *lanesPerPart / 8; + if (groupSize != vlaneElems) + return fail("vcgadd group_reduce_add path requires group size equal to " + "one 32-byte VLane"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2447,27 +2453,28 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, if (!sourceLayout || !resultLayout || !maskLayout || !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) - return fail("vcgadd group_reduce_addf path requires contiguous source/mask " + return fail("vcgadd group_reduce_add path requires contiguous source/mask " "layouts and matching num_groups result layout"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("vcgadd group_reduce_addf path requires full source " + return fail(Twine("vcgadd group_reduce_add path requires full source " "chunks; ") + sourceFullReason); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("vcgadd group_reduce_addf path requires computable physical " + return fail("vcgadd group_reduce_add path requires computable physical " "arity"); if (*sourceArity < 1 || *sourceArity != *maskArity || *sourceArity != *resultArity) - return fail("vcgadd group_reduce_addf path requires matching non-empty " + return fail("vcgadd group_reduce_add 
path requires matching non-empty " "source/mask/result physical arity"); return success(); } -LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, +template +LogicalResult checkS16Block8GroupReduceShape(OpTy op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -2478,14 +2485,18 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("s16 block8 group_reduce_addf requires f32 source/result"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("two-vlane group_reduce_add requires matching source/result " + "element types"); FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - if (failed(groupSize) || *groupSize != 16) - return fail("s16 block8 group_reduce_addf requires group size 16"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || + *groupSize != 2 * (*lanesPerPart / 8)) + return fail("two-vlane group_reduce_add requires group size equal to two " + "32-byte VLanes"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2494,33 +2505,34 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, if (!sourceLayout || !sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s16 group_reduce_addf requires source layout " + return fail("two-vlane group_reduce_add requires source layout " "deinterleaved=2 with block_elems=1 or block_elems=8"); if (!maskLayout || 
!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s16 group_reduce_addf requires matching mask layout " + return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); if (!resultLayout || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("s16 block8 group_reduce_addf requires " + return fail("two-vlane group_reduce_add requires " "group_slots(num_groups, slots=8) result layout"); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("s16 block8 group_reduce_addf requires computable physical " + return fail("two-vlane group_reduce_add requires computable physical " "arity"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || *maskArity != *sourceArity) - return fail("s16 block8 group_reduce_addf requires two source/mask " + return fail("two-vlane group_reduce_add requires two source/mask " "parts per result part"); return success(); } -LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, +template +LogicalResult checkS32Block8GroupReduceShape(OpTy op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -2531,14 +2543,18 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("s32 block8 group_reduce_addf requires f32 source/result"); 
+ if (sourceType.getElementType() != resultType.getElementType()) + return fail("four-vlane group_reduce_add requires matching source/result " + "element types"); FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - if (failed(groupSize) || *groupSize != 32) - return fail("s32 block8 group_reduce_addf requires group size 32"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || + *groupSize != 4 * (*lanesPerPart / 8)) + return fail("four-vlane group_reduce_add requires group size equal to four " + "32-byte VLanes"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2547,27 +2563,27 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, if (!sourceLayout || !sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s32 group_reduce_addf requires source layout " + return fail("four-vlane group_reduce_add requires source layout " "deinterleaved=4 with block_elems=1 or block_elems=8"); if (!maskLayout || !maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s32 group_reduce_addf requires matching mask layout " + return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); if (!resultLayout || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("s32 block8 group_reduce_addf requires " + return fail("four-vlane group_reduce_add requires " "group_slots(num_groups, slots=8) result layout"); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = 
getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("s32 block8 group_reduce_addf requires computable physical " + return fail("four-vlane group_reduce_add requires computable physical " "arity"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || *maskArity != *sourceArity) - return fail("s32 block8 group_reduce_addf requires four source/mask " + return fail("four-vlane group_reduce_add requires four source/mask " "parts per result part"); return success(); @@ -5551,13 +5567,12 @@ struct OneToNVMIReduceAddFOpPattern } }; -struct OneToNVMIGroupReduceAddFOpPattern - : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupReduceAddFOp>::OneToNOpConversionPattern; +template +struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(VMIGroupReduceAddFOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); auto maskVMIType = cast(op.getMask().getType()); @@ -6295,6 +6310,215 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { } }; +template +struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite( + OpT op, typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "integer extension requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); 
+ if (!sourceType) + return rewriter.notifyMatchFailure( + op, "expected physical integer extension source"); + for (Value sourcePart : sourceParts) { + auto currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "integer extension source physical parts must have matching " + "type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + !isa(resultVRegType.getElementType()) || + (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( + resultVRegType.getElementType()) != 32 + : resultVRegType != + resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + ArrayRef parts; + int64_t factor = 0; + if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (sourceBits == 8 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension source/result width " + "relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to build integer extension seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + 
resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()) != + 32 || + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr even = rewriter.getStringAttr("EVEN"); + FailureOr lane0Mask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", + rewriter); + if (failed(lane0Mask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot trunci lane0 mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = 
dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || + !resultType || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci physical type"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *lane0Mask, + /*rnd=*/nullptr, sat, even) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if ((sourceParts.size() != 2 && sourceParts.size() != 4) || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "only 32-bit integer deinterleaved=2/4 to 16/8-bit contiguous " + "trunci is supported"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + auto resultType = dyn_cast(resultTypes.front()); + if (!sourceType0 || !isa(sourceType0.getElementType()) || + !resultType || !isa(resultType.getElementType())) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "trunci source physical parts must have matching 32-bit " + "integer type"); + } + + if (pto::getPTOStorageElemBitWidth(sourceType0.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "trunci source physical element width must be 32-bit"); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + ArrayRef parts; + if (sourceParts.size() == 2 && resultBits == 16) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + } else if (sourceParts.size() == 4 && resultBits == 8) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + } else { + return 
rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result width relation"); + } + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(sourceMask) || failed(resultMask)) + return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); + + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector partials; + partials.reserve(parts.size()); + for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { + partials.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *sourceMask, + /*rnd=*/nullptr, sat, + rewriter.getStringAttr(part)) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + + rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -6782,10 +7006,14 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddFOpPattern, OneToNVMIGroupBroadcastOpPattern, + OneToNVMIGroupReduceAddOpPattern, + OneToNVMIGroupReduceAddOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, + OneToNVMIExtIOpPattern, + OneToNVMIExtIOpPattern, OneToNVMITruncIOpPattern, OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( typeConverter, patterns.getContext()); @@ -6845,6 +7073,30 @@ LogicalResult 
checkSupportedTruncFShape(VMITruncFOp op, return success(); } +LogicalResult checkSupportedExtSIShape(VMIExtSIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtSIRecipe(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedExtUIShape(VMIExtUIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtUIRecipe(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedTruncIShape(VMITruncIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getTruncIRecipe(op, reason))) + return failure(); + return success(); +} + LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { VMILocalRecipeRegistry recipes; if (failed(recipes.getBitcastRecipe(op, reason))) @@ -7143,8 +7395,9 @@ checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, return success(); } -LogicalResult checkSupportedGroupReduceAddFShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, +template +LogicalResult checkSupportedGroupReduceAddShape( + const VMITargetCapabilityRegistry &capabilities, OpTy op, std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -7152,8 +7405,10 @@ LogicalResult checkSupportedGroupReduceAddFShape( return failure(); }; - if (!op->hasAttr("reassoc")) + if constexpr (std::is_same_v) { + if (!op->hasAttr("reassoc")) return fail("requires reassoc attr for pair-wise floating-point reduction"); + } auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); auto maskType = cast(op.getMask().getType()); @@ -7164,8 +7419,13 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires assigned source, mask, and result layouts"); VMILocalRecipeRegistry recipes; - if 
(succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) - return success(); + if constexpr (std::is_same_v) { + if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + return success(); + } else { + if (succeeded(recipes.getGroupReduceAddIRecipe(capabilities, op, nullptr))) + return success(); + } FailureOr groupSize = getGroupSizeFromNumGroups( sourceType, op.getNumGroupsAttr().getInt(), reason); @@ -7181,7 +7441,9 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires contiguous source/mask layouts and matching " "num_groups result layout"); VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType(VMIReductionKind::AddF, + capabilities.supportsReductionElementType( + std::is_same_v ? VMIReductionKind::GroupAddF + : VMIReductionKind::GroupAddI, sourceType.getElementType()); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -7204,10 +7466,15 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (resultLayout.getSlots() <= 0) return success(); - if (!sourceLayout.isContiguous() || *groupSize != 64 || + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical chunk lane count"); + if (!sourceLayout.isContiguous() || *groupSize != *lanesPerPart || resultLayout.getSlots() != 1) - return fail("explicit group_slots group_reduce_addf chunk path requires " - "contiguous group size 64 source and slots=1 result layout"); + return fail("explicit group_slots group_reduce_add chunk path requires " + "contiguous full-physical-chunk group size source and slots=1 " + "result layout"); return success(); } @@ -7843,13 +8110,13 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedGroupReduceAddFShape(capabilities, reduce, - &reason))) + if (succeeded( + checkSupportedGroupReduceAddShape(capabilities, 
reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for f32 " - "32B groups or through pto.vcadd with reassoc, contiguous full " + << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " + "VLane groups or through pto.vcadd with reassoc, contiguous full " "source/mask chunks, #pto.vmi.layout result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" @@ -7857,6 +8124,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addi lowers through pto.vcgadd/vadd only " + "for i32 accumulator values; i8/i16 storage must be cast to i32 " + "before grouped reduction because narrow integer reductions " + "widen their result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( @@ -7930,6 +8212,51 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto extsi = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtSIShape(extsi, &reason))) + return WalkResult::advance(); + + extsi.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extsi supports contiguous signed/signless 8-bit or " + "16-bit integer physical source chunks to 32-bit integer " + "deinterleaved=4/2 results (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extui = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtUIShape(extui, &reason))) + return WalkResult::advance(); + + extui.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extui supports contiguous unsigned 8-bit or 16-bit " + 
"integer physical source chunks to unsigned 32-bit integer " + "deinterleaved=4/2 results (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto trunci = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedTruncIShape(trunci, &reason))) + return WalkResult::advance(); + + trunci.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.trunci supports only 32-bit integer deinterleaved=2 " + "source parts to one contiguous 16-bit integer result chunk, " + "32-bit integer deinterleaved=4 source parts to one contiguous " + "8-bit integer result chunk, or 32-bit integer " + "group_slots(num_groups=G, slots=1) to 16-bit integer " + "group_slots(num_groups=G, slots=1) (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto bitcast = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedBitcastShape(bitcast, &reason))) diff --git a/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..948dfe9c54 --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @direct_i16_group_reduce_invalid( + %source: !pto.vmi.vreg<128xi16>, + %mask: !pto.vmi.mask<128xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto new file mode 100644 index 0000000000..578acc00b9 --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @direct_i8_group_reduce_invalid( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xi8> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index eccb4e0007..b322e5700e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -16,7 +16,7 @@ module { %off: index) { %c1 = arith.constant 1 : index // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto new file mode 100644 index 0000000000..34bf1c9633 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @typed_group_reduce_assignment( + %f16: !pto.vmi.vreg<256xf16>, + %mf16: !pto.vmi.mask<256xpred>, + %i16: !pto.vmi.vreg<128xi16>, + %mi16: !pto.vmi.mask<128xpred>, + %i32: !pto.vmi.vreg<128xi32>, + %mi32: !pto.vmi.mask<128xpred>) { + %sum_f16 = pto.vmi.group_reduce_addf %f16, %mf16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf16> + %wide_i16 = pto.vmi.extsi %i16 + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum_i16 = pto.vmi.group_reduce_addi %wide_i16, %mi16 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + %sum_i32 = pto.vmi.group_reduce_addi %i32, %mi32 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK-LABEL: func.func @typed_group_reduce_assignment( +// CHECK: %[[F16_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MF16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MF16_B16:.*]] = pto.vmi.ensure_mask_granularity %[[MF16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addf %[[F16_SPLIT]], %[[MF16_B16]] +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[WIDE_I16:.*]] = pto.vmi.extsi %arg2 +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: 
pto.vmi.group_reduce_addi %[[WIDE_I16]], %[[MI16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[I32_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI32_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addi %[[I32_SPLIT]], %[[MI32_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto index 673f3ee47b..33a7bc0fae 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto index 6e0b04e8f6..d33315f88d 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group 
sizes that are multiples of one physical chunk + // CHECK-SAME: stable group_reduce_add slots=1 recipes support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto index b8576fe3b7..c787f57fea 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -29,7 +29,7 @@ module { %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf local recipes currently require result layout slots=8 or slots=1 + // CHECK-SAME: stable group_reduce_add local recipes currently require result layout slots=8 or slots=1 %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto new file mode 100644 index 0000000000..f01c6865a1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_group_reduce_addf_f16_vlane( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<128xi16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %out = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } + + func.func @vmi_group_reduce_addi_i32_two_vlane( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xi32, 
#pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_group_reduce_addf_f16_vlane( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( +// CHECK: %[[EVEN:.*]] = pto.vcvt %arg0, {{.*}} {part = "EVEN"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ODD:.*]] = pto.vcvt %arg0, {{.*}} {part = "ODD"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[EVEN]], %arg1 +// CHECK: %[[S1:.*]] = pto.vcgadd %[[ODD]], %arg2 +// CHECK: %[[SUM:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: return %[[SUM]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i32_two_vlane( +// CHECK: %[[MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: %[[SLO:.*]] = pto.vcgadd %arg0, %arg2 +// CHECK: %[[SHI:.*]] = pto.vcgadd %arg1, %arg3 +// CHECK: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[MASK]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto new file mode 100644 index 0000000000..c3e7403e91 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_extsi_i8_to_i32_group_reduce( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<256xi8> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %sum = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_extsi_i8_to_i32_group_reduce( +// CHECK: %[[P0:.*]] = pto.vcvt %arg0, {{.*}} {part = "P0"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P1:.*]] = pto.vcvt %arg0, {{.*}} {part = "P1"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P2:.*]] = pto.vcvt %arg0, {{.*}} {part = "P2"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P3:.*]] = pto.vcvt %arg0, {{.*}} {part = "P3"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[P0]] +// CHECK: %[[S1:.*]] = pto.vcgadd %[[P1]] +// CHECK: %[[S2:.*]] = pto.vcgadd %[[P2]] +// CHECK: %[[S3:.*]] = pto.vcgadd %[[P3]] +// CHECK: %[[A01:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: %[[A23:.*]] = pto.vadd %[[S2]], %[[S3]] +// CHECK: %[[SUM:.*]] = pto.vadd %[[A01]], %[[A23]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto new file mode 100644 index 0000000000..50051aab6d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extui_u8_to_u32( + %input: !pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) { + %wide = pto.vmi.extui %input + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xui32, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32> + } + + func.func @vmi_to_vpto_trunci_i32_to_ui8( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xui8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> !pto.vreg<256xui8> + return %p : !pto.vreg<256xui8> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xui8> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : 
!pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_to_ui8( +// CHECK: %[[P0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P1", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P2:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P2", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P3:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P3", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[M01:.*]] = pto.vor %[[P0]], %[[P1]] +// CHECK: %[[M012:.*]] = pto.vor %[[M01]], %[[P2]] +// CHECK: pto.vor %[[M012]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto similarity index 55% rename from test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto rename to test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto index 4e24ee12a8..fc4ebdc92a 100644 --- a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto @@ -6,21 +6,36 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_reduce_addf_f16_invalid( + func.func @vmi_to_vpto_reduce_addf_f16( %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, - %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.vmi.vreg<1xf16, #pto.vmi.layout>, !pto.vmi.mask<128xb16, #pto.vmi.layout> -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> - return + %p = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only with reassoc -// CHECK-SAME: currently supports only f32 elements +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_f16( +// CHECK-SAME: %arg0: !pto.vreg<128xf16> +// CHECK-SAME: %arg1: !pto.vreg<128xf16> +// CHECK-SAME: %arg2: !pto.mask +// CHECK: %[[LANE0:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 +// CHECK-SAME: !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[LANE0]] +// CHECK-SAME: !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] : !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto new file mode 100644 index 0000000000..145ef2a7b9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_trunci_i32_to_i8_invalid( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xi8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xi8, #pto.vmi.layout>) + -> !pto.vreg<256xi8> + return %p : !pto.vreg<256xi8> + } +} + +// CHECK: VMI-UNSUPPORTED +// CHECK: pto.vmi.trunci supports only +// CHECK: 32-bit to 8-bit integer narrowing requires unsigned i8 result diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py new file mode 100644 index 0000000000..fbba5d605b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py new file mode 100644 index 0000000000..beed48b5da --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.float16) + base = np.array([-3, -2, -1, 0, 1, 2, 3, 4], dtype=np.float16) + for row in range(ROWS): + src[row, :] = np.tile(np.roll(base, row), 2) + dst = np.full(ROWS, np.float16(-17), dtype=np.float16) + golden = np.sum(src, axis=1, dtype=np.float16).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto new file mode 100644 index 0000000000..b8d274c280 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_addf_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf16> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp new file mode 100644 index 0000000000..8cfb1e58b5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_addf_store_kernel(__gm__ half *src, __gm__ half *dst); + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream) { + vmi_group_reduce_f16_addf_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp new file mode 100644 index 0000000000..7a92e1a331 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t dstBytes = kOutputElems * sizeof(uint16_t); + uint16_t *srcHost = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_addf_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:  # compare v2.bin against golden_v2.bin; exits with status 2 on any mismatch
+    golden = np.fromfile("golden_v2.bin", dtype=np.int32)  # expected values, read from the current directory
+    output = np.fromfile("v2.bin", dtype=np.int32)  # values under test, read from the current directory
+    if golden.shape == output.shape and np.array_equal(golden, output):
+        print("[INFO] compare passed")
+        return
+    if golden.shape != output.shape:
+        print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}")
+        sys.exit(2)
+    diff = np.nonzero(golden != output)[0]  # indices of the elements that differ
+    idx = int(diff[0]) if diff.size else -1  # first differing index; -1 if none was found
+    print(
+        f"[ERROR] compare failed idx={idx} "
+        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
+        f"output={output[idx] if idx >= 0 else 'n/a'}"
+    )
+    sys.exit(2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py
new file mode 100644
index 0000000000..00097384f0
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under
+# the terms and conditions of CANN Open Software License Agreement Version 2.0
+# (the "License"). Please refer to the License for details. You may not use
+# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+# for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8  # number of rows; one reduced value is produced per row
+COLS = 16  # elements per row (ROWS * COLS = 128 input elements)
+
+
+def generate(output_dir: Path) -> None:  # write v1.bin (input), v2.bin (initial output) and golden_v2.bin
+    src = np.empty((ROWS, COLS), dtype=np.int16)  # every row is filled by the loop below
+    base = np.array([-5, -3, -1, 0, 2, 4, 6, 8], dtype=np.int16)
+    for row in range(ROWS):
+        src[row, :] = np.tile(np.roll(base, row), 2)  # base rotated by `row` positions, repeated twice to fill COLS
+    dst = np.full(ROWS, -777, dtype=np.int32)  # initial output buffer contents (main.cpp loads v2.bin before the launch)
+    golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32)  # widen to int32, then sum each row
+
+    output_dir.mkdir(parents=True, exist_ok=True)  # create the target directory if it does not exist yet
+    src.reshape(-1).tofile(output_dir / "v1.bin")  # flattened row after row, still int16
+    dst.tofile(output_dir / "v2.bin")
+    golden.tofile(output_dir / "golden_v2.bin")
+
+
+def main() -> None:  # CLI entry point; --output-dir selects where the .bin files go (default: current directory)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto
new file mode 100644
index 0000000000..da95759e3c
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto
@@ -0,0 +1,52 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under
+// the terms and conditions of CANN Open Software License Agreement Version 2.0
+// (the "License"). Please refer to the License for details. You may not use
+// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+// for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i16_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xi16> + %x32 = pto.vmi.extsi %x16 : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..255de845bd --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i16_extsi_i32_addi_store_kernel(__gm__ int16_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i16_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..277a78662f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int16_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int16_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int16_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + 
ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:  # compare v2.bin against golden_v2.bin; exits with status 2 on any mismatch
+    golden = np.fromfile("golden_v2.bin", dtype=np.int32)  # expected values, read from the current directory
+    output = np.fromfile("v2.bin", dtype=np.int32)  # values under test, read from the current directory
+    if golden.shape == output.shape and np.array_equal(golden, output):
+        print("[INFO] compare passed")
+        return
+    if golden.shape != output.shape:
+        print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}")
+        sys.exit(2)
+    diff = np.nonzero(golden != output)[0]  # indices of the elements that differ
+    idx = int(diff[0]) if diff.size else -1  # first differing index; -1 if none was found
+    print(
+        f"[ERROR] compare failed idx={idx} "
+        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
+        f"output={output[idx] if idx >= 0 else 'n/a'}"
+    )
+    sys.exit(2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py
new file mode 100644
index 0000000000..4153e74342
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under
+# the terms and conditions of CANN Open Software License Agreement Version 2.0
+# (the "License"). Please refer to the License for details. You may not use
+# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+# for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8  # number of rows; one reduced value is produced per row
+COLS = 8  # elements per row (ROWS * COLS = 64 input elements)
+
+
+def generate(output_dir: Path) -> None:  # write v1.bin (input), v2.bin (initial output) and golden_v2.bin
+    src = np.empty((ROWS, COLS), dtype=np.int32)  # every row is filled by the loop below
+    for row in range(ROWS):
+        src[row, :] = np.arange(COLS, dtype=np.int32) + row * 3 - 5  # COLS consecutive integers starting at 3 * row - 5
+    dst = np.full(ROWS, -777, dtype=np.int32)  # initial output buffer contents (main.cpp loads v2.bin before the launch)
+    golden = np.sum(src, axis=1, dtype=np.int32).astype(np.int32)  # expected per-row sums as int32
+
+    output_dir.mkdir(parents=True, exist_ok=True)  # create the target directory if it does not exist yet
+    src.reshape(-1).tofile(output_dir / "v1.bin")  # flattened row after row
+    dst.tofile(output_dir / "v2.bin")
+    golden.tofile(output_dir / "golden_v2.bin")
+
+
+def main() -> None:  # CLI entry point; --output-dir selects where the .bin files go (default: current directory)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto
new file mode 100644
index 0000000000..783658e453
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto
@@ -0,0 +1,51 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under
+// the terms and conditions of CANN Open Software License Agreement Version 2.0
+// (the "License"). Please refer to the License for details. You may not use
+// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+// for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xi32> + %sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..5783bfd5a8 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i32_addi_store_kernel(__gm__ int32_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream) { + vmi_group_reduce_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp new file mode 100644 index 0000000000..385f3ae909 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 64; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int32_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int32_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int32_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i32_addi_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. 
+
+import sys
+
+import numpy as np
+
+
+def main() -> None:  # compare v2.bin against golden_v2.bin; exits with status 2 on any mismatch
+    golden = np.fromfile("golden_v2.bin", dtype=np.int32)  # expected values, read from the current directory
+    output = np.fromfile("v2.bin", dtype=np.int32)  # values under test, read from the current directory
+    if golden.shape == output.shape and np.array_equal(golden, output):
+        print("[INFO] compare passed")
+        return
+    if golden.shape != output.shape:
+        print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}")
+        sys.exit(2)
+    diff = np.nonzero(golden != output)[0]  # indices of the elements that differ
+    idx = int(diff[0]) if diff.size else -1  # first differing index; -1 if none was found
+    print(
+        f"[ERROR] compare failed idx={idx} "
+        f"golden={golden[idx] if idx >= 0 else 'n/a'} "
+        f"output={output[idx] if idx >= 0 else 'n/a'}"
+    )
+    sys.exit(2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py
new file mode 100644
index 0000000000..76d46fff4c
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under
+# the terms and conditions of CANN Open Software License Agreement Version 2.0
+# (the "License"). Please refer to the License for details. You may not use
+# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+# for the full text of the License.
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+
+ROWS = 8  # number of rows; one reduced value is produced per row
+COLS = 32  # elements per row (ROWS * COLS = 256 input elements)
+
+
+def generate(output_dir: Path) -> None:  # write v1.bin (input), v2.bin (initial output) and golden_v2.bin
+    src = np.empty((ROWS, COLS), dtype=np.int8)  # every row is filled by the loop below
+    for row in range(ROWS):
+        src[row, :] = ((np.arange(COLS, dtype=np.int16) * 3 + row * 5) % 41 - 20).astype(
+            np.int8
+        )  # `% 41 - 20` keeps every value in [-20, 20], so the int8 cast loses nothing
+    dst = np.full(ROWS, -777, dtype=np.int32)  # initial output buffer contents (main.cpp loads v2.bin before the launch)
+    golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32)  # widen to int32, then sum each row
+
+    output_dir.mkdir(parents=True, exist_ok=True)  # create the target directory if it does not exist yet
+    src.reshape(-1).tofile(output_dir / "v1.bin")  # flattened row after row, still int8
+    dst.tofile(output_dir / "v2.bin")
+    golden.tofile(output_dir / "golden_v2.bin")
+
+
+def main() -> None:  # CLI entry point; --output-dir selects where the .bin files go (default: current directory)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", type=Path, default=Path("."))
+    args = parser.parse_args()
+    generate(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto
new file mode 100644
index 0000000000..97154d0dd6
--- /dev/null
+++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto
@@ -0,0 +1,52 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under
+// the terms and conditions of CANN Open Software License Agreement Version 2.0
+// (the "License"). Please refer to the License for details. You may not use
+// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON
+// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+// for the full text of the License.
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i8_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x8 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xi8> + %x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..1e046a8eb5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i8_extsi_i32_addi_store_kernel(__gm__ int8_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i8_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..cef9801b4d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 256; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int8_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int8_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int8_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 221f02e803c0199ad70306fa60be4c0ebfbdb1c3 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 09:11:32 +0800 Subject: [PATCH 20/31] Implement VMI layout support lowering --- docs/designs/vmi-implementation-manual.md | 137 ++-- .../vmi-layout-assignment-implementation.md | 649 +++++++++--------- .../vmi-layout-assignment-lowering-design.md | 522 +++++++------- docs/designs/vmi-layout-lowering-cases.md | 195 +++--- include/PTO/Transforms/Passes.h | 3 +- include/PTO/Transforms/Passes.td | 10 +- include/PTO/Transforms/VMILayoutSupport.h | 287 ++++++++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 234 ------- lib/PTO/Transforms/CMakeLists.txt | 2 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 157 +++-- lib/PTO/Transforms/VMILayoutAssignment.cpp | 311 +++------ lib/PTO/Transforms/VMILayoutFoldConsumers.cpp | 6 +- .../VMILayoutSinkMaterialization.cpp | 274 +++++++- ...ecipeRegistry.cpp => VMILayoutSupport.cpp} | 505 +++++++++----- lib/PTO/Transforms/VMIToVPTO.cpp | 399 +++-------- 
...assignment_broadcast_dense_group_users.pto | 6 +- .../vmi_layout_assignment_broadcast_remat.pto | 7 +- .../vmi_layout_assignment_constant_remat.pto | 8 +- ...ayout_assignment_create_group_mask_s16.pto | 6 +- ...signment_create_group_mask_s32_dynamic.pto | 6 +- ...ment_dense_group_reduce_multi_consumer.pto | 6 +- ...gnment_dense_store_group_slots_invalid.pto | 2 +- ..._layout_assignment_f32_f8_store_reduce.pto | 6 +- ...nment_group_load_block8_truncf_invalid.pto | 2 +- ...ut_assignment_group_reduce_s12_invalid.pto | 7 +- ...gnment_group_reduce_s32_tail_full_tile.pto | 12 +- ...p_reduce_s32_tail_no_full_tile_invalid.pto | 4 +- ...lot_load_slots1_dynamic_stride_invalid.pto | 2 +- ...t_load_slots1_unaligned_stride_invalid.pto | 2 +- ..._layout_assignment_group_slots_scf_for.pto | 6 +- ...group_store_slots1_unit_stride_invalid.pto | 2 +- .../vmi/vmi_layout_assignment_iota_remat.pto | 7 +- ...ignment_mask_granularity_f32_f16_store.pto | 8 +- .../vmi/vmi_layout_assignment_mask_remat.pto | 63 +- ...signment_masked_load_dense_group_users.pto | 4 +- ..._assignment_masked_load_group_tail_s32.pto | 4 +- ..._layout_assignment_non_load_s32_reduce.pto | 4 +- ...ment_packed_group_slots_truncf_invalid.pto | 2 +- ...yout_assignment_widen_f16_store_reduce.pto | 6 +- ...ayout_gate_bitcast_group_slots_invalid.pto | 2 +- ...i_layout_gate_bitcast_support_invalid.pto} | 4 +- ... 
vmi_layout_gate_extf_support_invalid.pto} | 4 +- ..._gate_group_broadcast_support_invalid.pto} | 4 +- ...ayout_gate_group_load_support_invalid.pto} | 4 +- ...e_group_reduce_slots1_support_invalid.pto} | 6 +- ...out_gate_group_reduce_support_invalid.pto} | 6 +- ..._gate_group_slot_load_support_invalid.pto} | 4 +- ..._group_slots_unsupported_slots_invalid.pto | 6 +- ...yout_gate_group_store_support_invalid.pto} | 4 +- ...e_helper_materialization_shape_invalid.pto | 4 +- ...mi_layout_gate_helper_support_invalid.pto} | 4 +- ...vmi_layout_gate_store_support_invalid.pto} | 4 +- ...recipe.pto => vmi_layout_gate_support.pto} | 4 +- ...mi_layout_gate_truncf_support_invalid.pto} | 4 +- .../lit/vmi/vmi_layout_rematerialize_data.pto | 17 + ...vmi_layout_sink_materialization_binary.pto | 122 ++++ ...mi_to_vpto_constant_mask_rematerialize.pto | 2 +- .../vmi_to_vpto_create_mask_rematerialize.pto | 2 +- ...o_vpto_group_broadcast_slots8_support.pto} | 4 +- .../vmi/vmi_to_vpto_group_broadcast_vselr.pto | 4 +- ...pto => vmi_to_vpto_group_load_support.pto} | 4 +- test/lit/vmi/vmi_to_vpto_group_ops.pto | 4 +- ...vpto_group_reduce_legacy_slots_invalid.pto | 27 + ... 
vmi_to_vpto_group_reduce_s64_support.pto} | 4 +- ...i_to_vpto_group_reduce_slots8_support.pto} | 4 +- .../vmi/vmi_to_vpto_group_reduce_vcgadd.pto | 4 +- ...to_vpto_group_reduce_vcgadd_multichunk.pto | 4 +- ...> vmi_to_vpto_group_slot_load_support.pto} | 4 +- ...vpto_group_slot_truncf_slots1_support.pto} | 4 +- test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 11 +- ...vpto_truncf_fp8_128_contiguous_invalid.pto | 4 +- 71 files changed, 2369 insertions(+), 1793 deletions(-) create mode 100644 include/PTO/Transforms/VMILayoutSupport.h delete mode 100644 include/PTO/Transforms/VMILocalRecipeRegistry.h rename lib/PTO/Transforms/{VMILocalRecipeRegistry.cpp => VMILayoutSupport.cpp} (73%) rename test/lit/vmi/{vmi_layout_gate_bitcast_recipe_invalid.pto => vmi_layout_gate_bitcast_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_layout_gate_extf_recipe_invalid.pto => vmi_layout_gate_extf_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_layout_gate_group_broadcast_recipe_invalid.pto => vmi_layout_gate_group_broadcast_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_load_recipe_invalid.pto => vmi_layout_gate_group_load_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto => vmi_layout_gate_group_reduce_slots1_support_invalid.pto} (85%) rename test/lit/vmi/{vmi_layout_gate_group_reduce_recipe_invalid.pto => vmi_layout_gate_group_reduce_support_invalid.pto} (85%) rename test/lit/vmi/{vmi_layout_gate_group_slot_load_recipe_invalid.pto => vmi_layout_gate_group_slot_load_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_store_recipe_invalid.pto => vmi_layout_gate_group_store_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_helper_recipe_invalid.pto => vmi_layout_gate_helper_support_invalid.pto} (92%) rename test/lit/vmi/{vmi_layout_gate_store_recipe_invalid.pto => vmi_layout_gate_store_support_invalid.pto} (95%) rename test/lit/vmi/{vmi_layout_gate_local_recipe.pto => 
vmi_layout_gate_support.pto} (92%) rename test/lit/vmi/{vmi_layout_gate_truncf_recipe_invalid.pto => vmi_layout_gate_truncf_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_broadcast_slots8_local_recipe.pto => vmi_to_vpto_group_broadcast_slots8_support.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_load_local_recipe.pto => vmi_to_vpto_group_load_support.pto} (94%) create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto rename test/lit/vmi/{vmi_to_vpto_group_reduce_s64_local_recipe.pto => vmi_to_vpto_group_reduce_s64_support.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_slots8_local_recipe.pto => vmi_to_vpto_group_reduce_slots8_support.pto} (91%) rename test/lit/vmi/{vmi_to_vpto_group_slot_load_local_recipe.pto => vmi_to_vpto_group_slot_load_support.pto} (91%) rename test/lit/vmi/{vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto => vmi_to_vpto_group_slot_truncf_slots1_support.pto} (96%) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 497e951e73..6bb7a7e0fe 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -143,8 +143,10 @@ values and ops. It is not part of the default PTOAS pipeline; existing PTO/VPTO unless the flag is set. The `ptoas --enable-vmi` user-facing entry also rejects public functions whose signature contains `!pto.vmi.*`. -Internal/private VMI-typed functions may still be specialized by `vmi-layout-assignment` and physicalized by -`vmi-to-vpto`, but a public VMI ABI requires an explicit materialization plan and must not be inferred from the +Internal/private VMI-typed functions are materialized at explicit boundary +helpers by baseline `vmi-layout-assignment` and physicalized by `vmi-to-vpto`. +A later optimization pass may specialize private signatures. A public VMI ABI +requires an explicit materialization plan and must not be inferred from the layout solver. 
CLI coverage: @@ -198,7 +200,7 @@ vmi-to-vpto: 写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 `unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 -源码级实现应该进一步拆成六个独立层次: +源码级实现应该进一步拆成七个独立层次: ```text IR layer: @@ -221,6 +223,20 @@ Layout solving layer: 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, 然后把结果写回 type 或 ensure_* helper。 +Layout support query layer: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + + 只放跨阶段共享的纯查询:cast layout fact、group_reduce layout fact、 + ensure_* materialization support、layout-aware store support 等。它可以被 + assignment、validation、layout optimization 和 vmi-to-vpto 调用,但不能保存 + per-value 状态,不能返回 VPTO 指令计划,不能决定 clone/rematerialize,也不能 + 通过 producer/user/control-flow context 恢复 lowering 决策。 + + 加新 query 的标准是:至少两个阶段需要同一个语义事实,并且重复实现会导致 + assignment、validation、lowering 对同一个 layout shape 得出不同结论。只有 + 一个 lowering pattern 自己使用的分支应该留在该 pattern 内。 + Layout optimization layer: lib/PTO/Transforms/VMILayoutFoldConsumers.cpp lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -257,7 +273,8 @@ Union-find + DenseMap: 用于 layout assignment 的 per-SSA-value 等价类求解。 IRRewriter/RewriterBase: - 用于 layout assignment 之后的 type rewrite、helper insertion、cheap producer rematerialization。 + 用于 layout assignment 之后的 type rewrite、helper insertion;cheap producer + rematerialization 属于后续 layout optimization pass。 OneToNTypeConverter + OneToNOpConversionPattern: 只用于 vmi-to-vpto,把一个 logical VMI value 展成多个 VPTO value。 @@ -1002,7 +1019,7 @@ SymbolTable: 解析 direct internal func.call;带 VMI type 的 external/indirect call 先拒绝。 IRRewriter: - 改写 function/block/result type,插入 ensure_*,必要时 rematerialize cheap producer。 + 改写 function/block/result type,插入 ensure_*。 verifyLayoutAssignedVMIIR: pass 末尾 hard gate,确认所有决策已经 materialize 到 IR。 @@ -1248,7 +1265,7 @@ The solver runs in phases: 3. add producer natural-layout constraints 4. add consumer layout/granularity requests 5. 
solve each equivalence class -6. insert ensure_* or rematerialize producers for non-class-compatible uses +6. insert ensure_* for non-class-compatible uses 7. rewrite value types and function signatures 8. run pto-validate-vmi-layout-ir ``` @@ -1300,9 +1317,10 @@ store/tile_write: consumer requests contiguous externally visible order ``` -If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless a -defined rematerialization path can split the value before the conflict. The first version should only rematerialize -trivially replayable producers: +If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless an +explicit use-site `ensure_*` can represent the requested materialization. Baseline layout assignment does not +clone/rematerialize producers. The separate `vmi-layout-rematerialize` optimization may replace an `ensure_*` +with a cloned trivially replayable producer after the materialization request is visible in IR: ```text constant @@ -1595,8 +1613,8 @@ Layout assignment completion checks: 2. No surface !pto.vmi.mask remains. 3. Every VMI function argument, result, block argument, branch operand, call operand, and return operand has the layout-assigned type selected by the solved equivalence class. -4. Every consumer-specific mismatch is represented either by a rematerialized cheap producer or by an explicit - pto.vmi.ensure_* op immediately before that consumer. +4. Every consumer-specific mismatch is represented by an explicit pto.vmi.ensure_* op immediately before that + consumer. Optional optimization passes may later replace selected helpers with rematerialized cheap producers. 5. External declarations with VMI types are rejected; they are not rewritten into an implicit ABI. 
``` @@ -2430,7 +2448,6 @@ allowed layouts: bitset {contiguous, deinterleaved2, deinterleaved4} required mask granularity: pred/b8/b16/b32 or unknown natural layout preference hard constraints -soft costs ``` No information required by later passes may live only in this data structure. After the pass, type/attr/op @@ -2495,10 +2512,12 @@ bitcast: bitcast contract is defined. load/tile_read: - result layout chosen by consumers unless memory plan has a cheaper registered sink/source + baseline result layout is deterministic from explicit layout attrs or the + producer natural layout; consumer-specific alternatives are represented by + ensure_layout and optimized later store/tile_write: - can consume any layout only if target registry has preserving store path + baseline requests contiguous source layout current implementation records a contiguous use-site request for vmi.store and inserts pto.vmi.ensure_layout when the stored value class solved to a non-contiguous layout. This makes externally visible memory order explicit in @@ -2508,7 +2527,8 @@ store/tile_write: the same physical chunk count and therefore forms complete intlv groups. shuffle/channel_split/channel_merge: - default result layout contiguous unless target registry provides direct layout-preserving path + default result layout contiguous unless the current op explicitly carries a + supported layout-preserving contract current implementation supports pto.vmi.shuffle when every result physical chunk forwards one source physical chunk with identical lane positions for all non-padding result lanes. Result padding lanes are ignored by the @@ -2538,12 +2558,12 @@ Implement deterministic solving: ```text 1. Collect region/SCC constraints, including scf/cf/function/call boundaries. 2. Propagate impossible layouts and required mask granularities. -3. Pick a layout per node using minimum cost. -4. Tie-break: explicit layout already present on the VMI type, then natural layout, then contiguous. +3. 
Pick one layout per node using deterministic priority, not a cost model. +4. Priority order: explicit layout already present on the VMI type, then unique natural layout, + then hard non-contiguous request, then contiguous. 5. Rewrite result/block/function types to layout-assigned VMI types. 6. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses that need conversion. -7. Clone rematerializable producers per use when cheaper than conversion. -8. Run verifier gate. +7. Run verifier gate. ``` Current implementation status: @@ -2591,7 +2611,8 @@ Do not implement a local greedy pattern pass that ignores block arguments or fun CFG 处理分两层。第一层是必须做的 layout equivalence:同一个控制流值在 result、yield、region/block argument 之间必须形成同一个 layout/mask 约束组。第二层才是 layout conflict resolution:当同一个 producer 的不同 consumers 希望不同 layout 时,插入 -`ensure_layout`、`ensure_mask_layout` 或 rematerialize producer。 +`ensure_layout` 或 `ensure_mask_layout`。后续 `vmi-layout-rematerialize` 可以把部分 helper +替换成重放的纯构造 producer。 当前可落地的最小实现先做第一层。它不尝试在 branch 边界自动插入 conversion,因此下面这些 关系一旦因为 natural layout 或 mask granularity 冲突无法合并,必须报 `VMI-LAYOUT-CONTRACT`, @@ -3051,17 +3072,13 @@ pto.vmi.group_reduce_addf: requires {reassoc} N = logical lane count; G = num_groups; S = N / G L = physical lanes per 256B chunk for the element type. - The result carries #pto.vmi.layout, a sparse group-slot - layout. It is not a dense vector layout: only group_slot(g) lanes have - semantic values. - group_slot(g) is canonical and derived from N, G, and L: - if S < L: - low_elems = L / S - chunk_stride = 1 - if S >= L: - low_elems = 1 - chunk_stride = S / L - group_slot(g) = (g / low_elems) * chunk_stride * L + (g % low_elems) + The result carries #pto.vmi.layout, a sparse + group-slot layout. It is not a dense vector layout: only slot lanes have + semantic values. Supported K values are: + K = 8 for VCGADD-style packed results, where group g is stored in + physical chunk floor(g / 8), lane g % 8.
+ K = 1 for row-local VCADD results, where group g is stored in physical + chunk g, lane 0. for each group g: result[group_slot(g)] = reduce_add(source[g * S .. (g + 1) * S), mask in same range) @@ -3069,10 +3086,10 @@ pto.vmi.group_reduce_addf: direct lowering materializes them as zero where the hardware path does not already define them. The result remains a VMI vector with the same element type and logical lane - count as the source, but its layout is #pto.vmi.layout. + count as the source, but its layout is an explicit group-slot layout. layout assignment: source use is requested as contiguous - result natural layout is #pto.vmi.layout + result natural layout is #pto.vmi.layout mask use is requested as contiguous with granularity derived from source element width current direct lowering: @@ -3085,8 +3102,8 @@ pto.vmi.group_reduce_addf: Otherwise: derived group size S must be a multiple of physical lanes per part lower each source chunk with pto.vcadd, combine chunks in the same group - with pto.vadd under PAT_VL1, then place group g at group_slot(g) in the - #pto.vmi.layout result. All other result chunks/lane values + with pto.vadd under PAT_VL1, then place group g in the slot lane defined by + K. All other result chunks/lane values are zero. unsupported cases: missing reassoc attr @@ -3097,17 +3114,17 @@ pto.vmi.group_reduce_addf: pto.vmi.group_broadcast: semantic: N = logical lane count; G = num_groups; S = N / G - source must carry #pto.vmi.layout. For each group g, the - source value is read from group_slot(g), using the same canonical group_slot - definition as pto.vmi.group_reduce_addf. The result broadcasts it back to + source must carry #pto.vmi.layout. For each group + g, the source value is read from the slot lane defined by K. 
The result broadcasts it back to each logical group: result[g * S + i] = source[group_slot(g)] layout assignment: - source use is requested as #pto.vmi.layout + source use is requested as #pto.vmi.layout result is consumer-driven. If no consumer requests another layout, it defaults to contiguous. current direct lowering: - source must carry #pto.vmi.layout with full physical chunks + source must carry #pto.vmi.layout with full + physical chunks result may be contiguous with full physical chunks result may also be deinterleaved when S is large enough that every physical result chunk stays inside one logical group, for example N=512, G=2, S=256, @@ -4011,23 +4028,27 @@ Slice 5 完成条件: writeMask fallback paths must report `VMI-UNSUPPORTED`. ``` -## 8. Target Capability Registry +## 8. Target Capabilities And Layout Fact Helpers -Add one explicit registry object, passed into layout assignment and VMI-to-VPTO: +Keep target capabilities separate from layout assignment policy. The shared +helpers expose target support and small layout/materialization facts; they do +not select a global lowering plan and are not a shared lowering-plan registry +between assignment and VMI-to-VPTO. 
```text supportsElementType(type, purpose) -getNaturalLayout(op) -supportsLayoutConversion(srcLayout, dstLayout, elementType) -getLayoutMaterializationPlan(srcLayout, dstLayout, elementType) +getPreferredCastLayoutFact(sourceType, resultType) +getPreferredGroupReduceLayoutFact(sourceType, numGroups) +canMaterializeDataLayout(sourceType, resultType) +canMaterializeMaskLayout(sourceType, resultType) supportsMaskGranularityConversion(srcG, dstG) -supportsMemoryAccessPlan(plan) +supportsMemoryAccessProof(proof) supportsPrefixPopcount(maskType) supportsReductionScanContract(op) getScratchResource(plan) ``` -The registry returns structured results: +Capability and materialization helpers return structured results: ```text supported @@ -4170,7 +4191,7 @@ If any answer is no, the slice is not ready to be treated as complete. ## 13. Adding One VMI Op End To End -新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的六个落点, +新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的七个落点, 否则很容易出现 verifier 能过、layout pass 不知道怎么约束、或控制流 physicalization 后残留 VMI type。 ```text @@ -4180,20 +4201,24 @@ If any answer is no, the slice is not ready to be treated as complete. 2. semantic verifier: lib/PTO/IR/VMI.cpp -3. layout facts: +3. layout assignment facts: lib/PTO/Transforms/VMILayoutAssignment.cpp -4. vmi-to-vpto preflight: +4. shared layout support, when the fact crosses stages: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + +5. vmi-to-vpto preflight: lib/PTO/Transforms/VMIToVPTO.cpp::verifySupportedVMIToVPTOOps -5. OneToN lowering pattern: +6. OneToN lowering pattern: lib/PTO/Transforms/VMIToVPTO.cpp::populateVMIOneToNConversionPatterns -6. focused lit tests: +7. 
focused lit tests: test/lit/vmi/ ``` -这六个落点的职责不同: +这七个落点的职责不同: ```text ODS: @@ -4211,6 +4236,12 @@ LayoutAssignment: - mask consumer required granularity 不能在 collect 阶段改 IR。 +VMILayoutSupport: + 只放跨 assignment、validation、optimization、lowering 中至少两个阶段共享的纯查询。 + 典型内容是 cast layout fact、group_reduce layout fact、ensure_* materialization support。 + 不能返回 VPTO instruction sequence、不能决定 clone/rematerialize、不能读取 producer/user context。 + 只有一个 lowering pattern 自己使用的判断不要抽到这里。 + VMIToVPTO preflight: 在 rewrite 前拒绝当前 lowering 不支持但语义合法的 case。 典型例子是 partial physical chunk、non-prefix mask constant、dynamic create_mask、unsupported shuffle。 diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index a6583a3d8b..e1fa19cc7e 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -41,8 +41,8 @@ pto-validate-vmi-ir: vmi-layout-assignment: solve hard value layout constraints - choose explicit layouts and local recipe carriers visible in IR - insert ensure/rematerialization helpers + choose explicit layouts visible in IR + insert ensure_layout / ensure_mask_layout / ensure_mask_granularity helpers make internal function boundary layouts explicit rewrite VMI types with layout attrs @@ -54,11 +54,12 @@ vmi-layout-fold-consumers: fold use-site materialization into consumers that can directly consume the source layout while preserving the same logical effect example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become - a store of deinterleaved=2 when the store has a local vstsx2 INTLV recipe + a store of deinterleaved=2 when the store has a layout-aware vstsx2 INTLV + lowering current implementation: pto.vmi.store, pto.vmi.tile_write, and the value operand of pto.vmi.masked_store when the existing mask arity matches, fed by ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. 
factor=2 - uses the store's vstsx2 INTLV recipe; factor=4 is still store-local, but it + uses the store's vstsx2 INTLV lowering; factor=4 is still store-local, but it materializes through physical interleave before vsts. vmi-layout-rematerialize: @@ -73,15 +74,17 @@ vmi-layout-rematerialize: vmi-layout-sink-materialization: move ensure_layout across pure layout-transparent elementwise chains when the - rewritten IR reduces materialization cost and keeps every op locally legal + rewritten IR reduces materialization overhead and keeps every op locally legal current implementation: sink two identical operand ensure_layout helpers - across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, or one - source ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, - producing one result ensure_layout. It also sinks matching - ensure_mask_layout or ensure_mask_granularity helpers across - mask_and/mask_or/mask_xor/mask_not, producing one result mask helper. It - does not sink through select, fma, cast, load, store, reduce, - group_broadcast, or control-flow ops + across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, three + identical operand ensure_layout helpers across fma, or one source + ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, producing + one result ensure_layout. It also sinks compare data helpers to one result + ensure_mask_layout, and sinks select only when both selected values and the + mask carry matching explicit helpers. Matching ensure_mask_layout or + ensure_mask_granularity helpers are sunk across mask_and/mask_or/mask_xor/ + mask_not, producing one result mask helper. It does not sink through cast, + load, store, reduce, group_broadcast, or control-flow ops. 
vmi-legalize-arith-select: restore scalar-condition arith.select with VMI result type back to scf.if @@ -92,11 +95,12 @@ pto-validate-vmi-layout-ir: verify every VMI data/mask value has layout verify every VMI value has an assigned layout and every non-local lowering choice has been serialized explicitly - verify helper ops have registered materialization recipes. Current + verify helper ops have supported materialization paths. Current implementation checks `ensure_layout`, `ensure_mask_layout`, and - `ensure_mask_granularity` at the layout gate, so unsupported helper recipes - fail before `vmi-to-vpto`. It also checks the first semantic local-recipe - families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 + `ensure_mask_granularity` at the layout gate, so unsupported helper + materializations fail before `vmi-to-vpto`. It also checks the first + semantic local lowering families, non-contiguous + `pto.vmi.store`/`pto.vmi.tile_write`, block8 `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, @@ -168,7 +172,7 @@ include/PTO/Transforms/Passes.td lib/PTO/Transforms/PTOValidateVMIIR.cpp lib/PTO/Transforms/VMILayoutAssignment.cpp lib/PTO/Transforms/VMIToVPTO.cpp -lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +small layout fact/materialization helpers under lib/PTO/Transforms test/lit/vmi/vmi_layout_assignment_*.pto test/lit/vmi/vmi_to_vpto_*.pto @@ -224,7 +228,7 @@ contiguous: deinterleaved: F > 1 B > 0 - direct full-chunk recipes require N % (F * B) == 0 + direct full-chunk lowerings require N % (F * B) == 0 group_slots: G > 0 @@ -236,12 +240,13 @@ group_slots: Parser compatibility during migration: ```text -#pto.vmi.layout +#pto.vmi.layout ``` -is accepted as a legacy spelling for the pre-design implicit group layout. New -`vmi-layout-assignment` output must not rely on that implicit form. 
It must -print one of: +is the lowering contract for group-slot values. The parser still accepts +`#pto.vmi.layout` as a legacy spelling for the pre-design +implicit group layout, but `vmi-to-vpto` support queries require explicit slots. +New `vmi-layout-assignment` output must print one of: ```text #pto.vmi.layout @@ -270,10 +275,10 @@ Layout-assigned: Surface VMI types are legal before assignment. Layout-assigned VMI types are required after assignment. -### 3.3 Explicit Recipe Carriers +### 3.3 Explicit Lowering Carriers Lowering decisions are carried by the current op and its types, not by a -separate recipe string. The allowed carriers are: +separate lowering-plan string. The allowed carriers are: ```text op attrs and operands @@ -297,9 +302,9 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths -ensure_layout always carries source/result layouts instead of recipe -ensure_mask_layout always carries source/result layouts instead of recipe -ensure_mask_granularity always carries source/result granularities instead of recipe +ensure_layout always carries source/result layouts +ensure_mask_layout always carries source/result layouts +ensure_mask_granularity always carries source/result granularities ``` Layout/attr-only decisions today: @@ -318,11 +323,12 @@ Implementation rule: validate-assigned-vmi validates assigned layouts, mask granularity, boundaries, and helper placement. vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. 
-If a layout/attr-only op later gains a second legal recipe that cannot be -distinguished from current-op information, that recipe must be represented by a +If a layout/attr-only op later gains a second legal lowering that cannot be +distinguished from current-op information, that lowering must be represented by a new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. -Unsupported shapes that have no registered recipe still diagnose through their -specific capability check rather than failing with a generic missing-recipe +Unsupported shapes that have no explicit materialization/lowering path still +diagnose through their specific capability check rather than failing with a generic +missing-lowering error. ``` @@ -386,7 +392,7 @@ cast boundary: op semantic, not a VMI type spelling. Current VPTO lowering supports 32-bit integer narrowing to unsigned i8 storage, matching the available VCVTII s32/u32 -> u8 forms; signed i8 - narrowing needs a separate target recipe. + narrowing needs a separate target lowering. compute / accumulator: floating compute baseline: f16/f32, with reassoc required for reductions @@ -394,7 +400,7 @@ compute / accumulator: integer compute baseline: i32 for grouped reduction; i8/i16 storage must first cast to i32 because integer reduction instructions widen narrow inputs. f8/i8 are not baseline accumulator/compute types. Supporting direct 8-bit - compute requires a target capability entry and a separate recipe family. + compute requires a target capability entry and a separate lowering family. ``` Important semantic split: @@ -412,129 +418,169 @@ group_slot_load: loads one scalar per group and produces group_slots ``` -## 5. Local Recipe Registry +## 5. Layout Fact Helpers And Ensure-Based Optimization Hooks -Create one target-aware local recipe registry shared by assignment and lowering. -It is not serialized as a separate recipe-selection attribute. 
It answers local legality -questions from op kind, explicit attrs/operands, layouts, and target capability. +Do not implement a target-aware lowering-plan registry shared by assignment and +lowering. The shared contract is the IR itself: assigned VMI layouts, explicit +`ensure_layout` / `ensure_mask_layout` / `ensure_mask_granularity` helpers, +semantic op attrs/operands, and target capability diagnostics. -```c++ -class VMILocalRecipeRegistry { -public: - SmallVector getProducerRecipes(Operation *op); - SmallVector getConsumerRecipes(OpOperand &use); - SmallVector getTransferRecipes(Operation *op); - FailureOr - getMaterializationRecipe(Type valueType, VMILayoutKey from, - VMILayoutKey to); - bool isCheaplyRematerializable(Operation *op); - bool hasTargetCapability(RecipeID recipe) const; -}; +Small pure helpers are still useful when they remove duplicated layout math. +They must return semantic layout facts, not VPTO instruction plans, costs, +clone decisions, or multi-user plans. + +Keep the support layer small. A query belongs in `VMILayoutSupport` only when +at least two stages need the same fact and a mismatch would create an +assignment-vs-lowering bug. Current valid shared facts are: + +```text +cast layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: f32->f8 must see deinterleaved=4 source and contiguous result in + every stage. + +group_reduce layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: S=2*VLaneElems means deinterleaved=2 source/mask and + group_slots(G, slots=8) result in every stage. + +layout materialization support: + shared by layout validation, vmi-to-vpto, and helper-based optimizations. + Example: ensure_layout from deinterleaved=2 f32 to contiguous f32 is the same + materialization whether it survives to lowering or is folded into a store. 
+ +contiguous store support: + shared by fold-consumers and vmi-to-vpto because both must preserve the same + row-major memory effect when consuming a non-contiguous value. ``` -Recipe record: +Do not add a support query for a single private branch such as "this exact op +uses this exact VPTO mnemonic". Keep that branch in the lowering pattern until +another stage needs the same semantic fact. This prevents `VMILayoutSupport` +from becoming a second copy of the lowering pass. ```c++ -struct VMILayoutRecipe { - RecipeID id; - SmallVector operandLayouts; - SmallVector resultLayouts; - int64_t cost; - bool requiresFullTileReadable; - bool mayReadInactivePhysicalLanes; - DiagnosticBuilder (*explainFailure)(...); +struct VMICastLayoutFact { + VMICastLayoutKind kind; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t factor; }; + +struct VMIGroupReduceLayoutFact { + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize; + int64_t vlaneElems; +}; + +FailureOr<VMICastLayoutFact> +getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType); + +FailureOr<VMIGroupReduceLayoutFact> +getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups); + +LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason); ``` -The registry must be target-aware but deterministic. It should not read global -mutable state. Pass options configure fallback availability: +Baseline assignment uses these helpers only to produce assigned layouts and +use-site helpers. It does not clone producers, rematerialize cheap ops, choose +memory-fused layouts by cost, or specialize private function signatures for +performance. + +Optimization passes are deliberately helper-driven: ```text -enableScratchFallback -enableGatherFallback -enablePublicVMIABI -diagnosticVerbosity +fold-consumers: + input shape: ensure_layout feeding a layout-aware consumer.
+ support query: can this consumer preserve the same logical memory effect from + the source layout? + output shape: the consumer directly uses the source value. + +rematerialize: + input shape: cheap producer feeding ensure_layout / ensure_mask_layout. + support query: can the cloned producer directly create the requested type? + output shape: a cloned producer at the use. + +sink-materialization: + input shape: pure elementwise op whose operands are matching ensure_* helpers. + support query: can the result helper be materialized if it remains? + output shape: the op runs in the source layout and one helper remains on the + result. ``` -Assignment and optimization passes may query the registry to decide which IR -shape to produce. `vmi-to-vpto` may query the same registry to verify the -current op is locally lowerable. If the same op, attrs, operands, and -operand/result layouts could map to two different physical recipes with -different observable preconditions, the IR is under-specified; add an explicit -attr, operand, helper op, or distinct VMI semantic op before implementing that -recipe. +These passes may improve multi-consumer cases without asking assignment to solve +a global cost problem. Assignment guarantees a legal baseline with helpers; +optimization removes or moves those helpers locally when the rewritten IR still +contains enough information for `vmi-to-vpto`. -Current implementation status: `VMILocalRecipeRegistry` exists and currently -owns nine local recipe families: +Implementation-relevant layout facts: ```text -contiguous store/tile_write consumer recipes: - contiguous vsts - deinterleaved=2 vstsx2 INTLV - deinterleaved=4 materialize-then-vsts +dense store/tile_write: + requests contiguous source. If the value is assigned deinterleaved, + assignment inserts ensure_layout at the store use. A later optimization may + fold ensure_layout + store into a layout-aware store lowering. 
-helper materialization recipes: - data/mask layout identity - data/mask contiguous <-> deinterleaved=2/4 when source/result physical - arity matches and the physical part shape can be materialized - mask granularity identity or b8/b16/b32 predicate cast +data/mask helper materialization: + identity conversions are always legal. + contiguous <-> deinterleaved=2/4 is legal only when source/result physical + arity and physical chunk shapes make the same logical value materializable. + unsupported conversions remain explicit diagnostics. -group_slot_load semantic recipes: - slots=8 unit-stride vsldb - slots=1 aligned lane-0 vsldb per group - -block8 group_load semantic recipes: - S=16 deinterleaved=2, block_elems=8 vsldb per row fragment - S=32 deinterleaved=4, block_elems=8 vsldb per row fragment +group_slot_load: + assigned result layout is group_slots(G, slots=8) for packed slots or + group_slots(G, slots=1) for row-local slots. -group_slots group_store semantic recipes: - slots=8 unit-stride vsts - slots=1 aligned lane-0 vsts per group +block8 group_load: + assigned result layout is deinterleaved=2/4 with block_elems=8 only when the + op carries the required constant stride and memory-safety proof. -group_slots group_reduce_add{f|i} semantic recipes: - define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. - T is the accumulator/reduce element type after any required storage cast. - f8 storage reduces through f32; i8 storage reduces through an explicit - signed/unsigned integer cast to an accumulator type such as i32. In the - baseline contract, f8/i8 are cast-boundary storage types rather than - accumulator/compute types. - S=VLaneElems contiguous vcgadd - S=2*VLaneElems deinterleaved=2 vcgadd+vadd - S=4*VLaneElems deinterleaved=4 vcgadd+vadd tree - S>=L && S%L==0 contiguous slots=1 vcadd/vadd/vsel row-local reduction, - with one physical result part per group. 
For 32-bit element types this covers - S=64, S=128, S=256, ...; for 16-bit element types this covers S=128, S=256, - ... +group_store: + consumes group_slots(G,K). Explicit output stride attrs/operands decide + whether slots=8 packed or slots=1 row-local stores are legal. + +group_reduce_add{f|i}: + define E = sizeof(accumulator T), VLaneElems = 32B / E, L = 256B / E, + S = N / G. T is the accumulator/reduce element type after any required + storage cast. + S=VLaneElems uses contiguous source/mask and group_slots(G, slots=8). + S=2*VLaneElems uses deinterleaved=2 source/mask and group_slots(G, slots=8). + S=4*VLaneElems uses deinterleaved=4 source/mask and group_slots(G, slots=8). + S>=L && S%L==0 uses contiguous source/mask and group_slots(G, slots=1). -explicit-slots group_broadcast semantic recipes: - slots=8/slots=1 vselr materialization to contiguous or supported - deinterleaved result layouts +group_broadcast: + consumes group_slots(G,K) and produces one assigned dense layout. If another + consumer wants a different dense layout, assignment inserts ensure_layout. + Optimization may clone/rematerialize group_broadcast per use. -extf/truncf semantic recipes: +extf/truncf: contiguous f16/bf16 -> deinterleaved=2 f32 contiguous f8-like -> deinterleaved=4 f32 deinterleaved=2 f32 -> contiguous f16 deinterleaved=4 f32 -> contiguous f8-like - group_slots(G, slots=1) f32 -> f16 + group_slots(G, slots=1) f32 -> f16 remains a slot-preserving transform. -extsi/extui/trunci semantic recipes: - contiguous i8 -> deinterleaved=2 i16 through VCVTII.{s,u}82{s,u}16 #part - contiguous i8 -> deinterleaved=4 i32 through VCVTII.{s,u}82{s,u}32 #pp - deinterleaved=2 i16 -> contiguous i8 through VCVTII.*162*8 #part - deinterleaved=4 i32 -> contiguous ui8 through VCVTII.*322u8 #pp +extsi/extui/trunci: + contiguous i8/i16 -> deinterleaved i32 according to widening factor. + deinterleaved i32 -> contiguous i8/i16 according to narrowing factor. 
packed group_slots integer width-changing cast is unsupported until a - slot-wise cast recipe is defined. + slot-wise transform is represented explicitly. -bitcast semantic recipes: - per-part vbitcast for contiguous/deinterleaved layouts when source/result - layouts match, physical arity matches, and every physical chunk carries the - same logical bit footprint; this does not require each deinterleaved part to - contain the same number of chunks. group_slots bitcast is unsupported until a - slot-wise bitcast contract is defined. +bitcast: + per-part vbitcast is valid for contiguous/deinterleaved layouts when + source/result layouts match, physical arity matches, and every physical chunk + carries the same logical bit footprint. group_slots bitcast is unsupported + until a slot-wise bitcast contract is defined. ``` -`vmi-layout-fold-consumers`, `pto-validate-vmi-layout-ir`, and `vmi-to-vpto` -query this registry for the decisions implemented above. +`vmi-layout-fold-consumers`, rematerialization, sink/hoist, and private +function specialization passes consume explicit helper IR. They may replace +helpers with cheaper equivalent IR, but they must not introduce hidden lowering +plans that `vmi-to-vpto` has to rediscover from producer/user context. ## 6. Layout Assignment Data Model @@ -544,23 +590,17 @@ query this registry for the decisions implemented above. struct ValueLayoutState { Value value; Type logicalType; - SmallVector candidates; std::optional chosen; + std::optional naturalLayout; SmallVector useRequests; }; struct UseRequest { OpOperand *operand; VMILayoutKey requestedLayout; - RecipeID requestingRecipe; + Operation *requestingOp; bool hard; }; - -struct OpRecipeState { - Operation *op; - SmallVector candidates; - std::optional chosen; -}; ``` ### 6.2 Collection Phase @@ -571,7 +611,7 @@ Walk the module and collect: 1. every VMI value 2. every VMI block argument 3. every VMI function argument/result -4. every VMI op with candidate local recipes +4. 
every VMI op with natural producer layouts or use-site layout requests 5. every branch/yield/call/return edge carrying VMI ``` @@ -602,13 +642,11 @@ truncf f32->f16: result contiguous group_reduce S=16: - source candidate deinterleaved=2, block_elems=1 - source candidate deinterleaved=2, block_elems=8 + source request deinterleaved=2, block_elems=1 result group_slots(G, slots=8) group_reduce S=32: - source candidate deinterleaved=4, block_elems=1 - source candidate deinterleaved=4, block_elems=8 + source request deinterleaved=4, block_elems=1 result group_slots(G, slots=8) group_reduce S=64: @@ -617,8 +655,8 @@ group_reduce S=64: group_broadcast: source request group_slots(G,K) - result candidate comes from each dense consumer request - op is rematerializable per use + result receives one assigned dense layout + incompatible dense uses are represented by ensure_layout ordinary dense add/mul/select: operands/results same dense layout @@ -634,33 +672,21 @@ group_store: source request group_slots(G,K) ``` -Consumer-driven adoption is limited to producers that are layout-transparent or -can produce the requested memory layout directly: +Baseline assignment does not perform consumer-driven adoption for performance. +It records natural producer layouts and hard use-site requests. If a request +does not match the assigned layout, the pass inserts an explicit helper at that +use. ```text -direct layout producer: - load, tile_read - -layout-transparent producer: - broadcast, constant, iota - add/sub/mul/fma/div/min/max/neg/abs/sqrt/exp/ln/relu - integer bitwise/shift/not - select, bitcast -``` +natural layout producer: + extf/truncf, group_reduce, group_slot_load, group_load when the op itself + carries a layout-producing contract -For a non-load layout-transparent producer, only non-contiguous consumer -requests may be adopted by the producer equivalence class. Contiguous requests -from ordinary stores are handled by use-site `ensure_layout` or -rematerialization instead. 
This prevents a dense store from overwriting a -natural `deinterleaved` cast layout while still allowing: - -```text -load -> broadcast -> addf -> S=32 group_reduce +layout equality producer: + dense add/mul/select and CFG-carried values tie operands/results but do not + pick a cheaper layout by cost ``` -to assign the whole producer chain as -`deinterleaved = 4, block_elems = 8` before `vmi-to-vpto`. - Memory legality constraints: ```text @@ -676,12 +702,12 @@ compact S=12 logical S=16: ### 6.3.1 Request Builders Implement request generation as small per-op builders. The builders produce -candidate recipes and use-site requests; they do not rewrite IR. +natural layouts, use-site requests, equality constraints, and diagnostics; they +do not choose optimization plans. ```text buildStoreRequests: - ordinary store -> dense contiguous request unless a layout-aware store recipe is - selected + ordinary store -> dense contiguous request group_store -> group_slots(G,K) request plus stride/alignment capability checks @@ -690,25 +716,25 @@ buildCastRequests: extf f8->f32 -> source contiguous, result deinterleaved=4 truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous - group_slots slots=1 f32->f16 -> slot-preserving recipe - group_slots slots=8 width-changing cast -> diagnostic unless a packed recipe - exists + group_slots slots=1 f32->f16 -> explicit slot-preserving transform + group_slots slots=8 width-changing cast -> diagnostic unless a packed + transform is explicitly represented buildGroupReduceRequests: derive E = sizeof(accumulator type), VLaneElems = 32B / E, L = 256B / E, and S = logical_lanes / num_groups S=VLaneElems -> contiguous source, group_slots(G,8) result - S=2*VLaneElems -> deinterleaved=2/block_elems=1 or block_elems=8 source, + S=2*VLaneElems -> deinterleaved=2/block_elems=1 source, group_slots(G,8) result - S=4*VLaneElems -> deinterleaved=4/block_elems=1 or 
block_elems=8 source, + S=4*VLaneElems -> deinterleaved=4/block_elems=1 source, group_slots(G,8) result S>=L && S%L==0 -> contiguous source, group_slots(G,1) result 8-bit storage must be cast to an accumulator type before this request builder - other S -> diagnostic unless an explicit fallback recipe is enabled + other S -> diagnostic unless an explicit fallback op/helper is enabled buildGroupMemoryRequests: - group_load S=16/S=32 with aligned constant stride -> block_elems=8 recipe - group_load row-local full chunks -> contiguous recipe + group_load S=16/S=32 with aligned constant stride -> natural block_elems=8 + group_load row-local full chunks -> natural contiguous group_slot_load unit stride -> group_slots(G,8) group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic @@ -723,8 +749,8 @@ buildElementwiseRequests: buildMaskRequests: mask layout follows each consuming data layout predicate granularity follows each consuming element type - create_mask/create_group_mask may be cloned for incompatible mask layout or - granularity requests + create_mask/create_group_mask produce one assigned mask layout and use + ensure_mask_layout / ensure_mask_granularity for incompatible uses masked_store requests source layout, mask layout, and store predicate granularity explicitly @@ -733,20 +759,22 @@ buildControlFlowRequests: create equality requests on the carried VMI layout variable buildFunctionBoundaryRequests: - private/internal function argument/result layouts are specialized or - materialized with callee-entry/return-site helpers + private/internal function argument/result layouts are materialized with + callee-entry/return-site helpers in baseline assignment; signature + specialization is an optimization pass public/external VMI arguments/results diagnose unless enablePublicVMIABI has - a real ABI recipe + a real ABI contract ``` Request builders must record the requesting op. 
Diagnostics and inserted helpers are use-site operations, so the user can see which consumer forced a layout. -### 6.3.2 Producer Classes +### 6.3.2 Optimization Producer Classes -The solver uses producer classes to decide whether a conflict can be solved by -cloning, equivalence propagation, or materialization. +Baseline assignment does not use producer classes to solve conflicts. It +inserts helpers. Later optimization passes may classify producers to replace +helpers with cheaper equivalent IR. ```text cheap rematerializable producers: @@ -757,7 +785,7 @@ cheap rematerializable producers: create_group_mask group_broadcast group_slot_load when the same address/no-alias/proof conditions as load hold - and the memory recipe is legal at the clone site + and the memory access remains legal at the clone site layout-transparent producers: add/sub/mul/fma/min/max/neg/abs @@ -766,13 +794,13 @@ layout-transparent producers: integer bitwise and shift ops fixed-layout producers: - extf/truncf physical conversion recipes - group_load block-fragment recipes + extf/truncf physical conversion layouts + group_load block-fragment layouts group_reduce result group_slots - masked_load when the physical memory-safety proof fixes a full-read recipe + masked_load when the physical memory-safety proof fixes a full-read lowering ``` -Conflict policy: +Optimization conflict policy: ```text cheap producer: @@ -784,31 +812,28 @@ layout-transparent producer: only at incompatible uses fixed-layout producer: - use registered materialization only; otherwise diagnose + use explicit helper materialization only; otherwise diagnose ``` -This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to -`deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 -`group_reduce`. It also keeps case 3.19.2 diagnostic: a strided `group_load` -that selected `block_elems=8` is fixed unless a block8-to-parity -materialization or rematerialized memory recipe is registered. 
+These classes are not assignment constraints. They are rewrite preconditions +for passes that consume `ensure_layout` and decide whether the helper can be +folded, sunk, hoisted, or replaced by rematerialization. ### 6.4 Solving And Rewriting Algorithm: ```text -1. Pick candidate recipe sets for every op. -2. Propagate hard constraints through SCCs. -3. Resolve transfer-equivalent dense values. -4. Choose multi-recipe ops by cost: - - S=16 parity vs block8 - - load memory-fused vs load+materialize - - group_slot_load slots=8 vs slots=1 -5. For conflicting uses: - - rematerialize cheap producer where legal - - otherwise insert ensure_layout at use - - otherwise diagnose +1. Collect natural layouts, use-site requests, equality constraints, and + memory-safety proofs. +2. Propagate equality constraints through SCCs. +3. Choose one deterministic assigned layout per value/equivalence class: + explicit user layout, then unique producer natural layout, then hard + non-contiguous layout, then contiguous. +4. For conflicting uses, insert ensure_layout / ensure_mask_layout / + ensure_mask_granularity at the use. +5. Emit diagnostics for unsupported semantic constraints or missing explicit + materialization/memory-safety proof. 6. Rewrite VMI result/block/function types with chosen layouts. 7. Insert helper ops with source/result layout attrs. ``` @@ -818,9 +843,12 @@ Rewrite invariants: ```text No VMI data/mask value after assignment has a null layout. Any non-local choice is represented by op attrs, operand/result layouts, a -helper op, a clone, or an explicit diagnostic. -Every ensure_* helper has a registered materialization recipe. -Every function/call signature carrying VMI is specialized or diagnosed. +helper op, or an explicit diagnostic. Cloned/rematerialized producers may +appear only after later layout optimization passes. +Every ensure_* helper has an explicit supported materialization path or a +diagnostic. 
+Every function/call boundary carrying VMI is materialized, kept in an explicit +ABI contract, or diagnosed. ``` ### 6.5 Rewrite Artifacts @@ -831,24 +859,21 @@ Assignment rewrites the IR so that later lowering has no hidden choices. type rewrite: every VMI data/mask result and block argument receives a layout attr -clone rewrite: - cheap producers are cloned before their divergent use sites - each clone receives its own layout and attrs - ensure rewrite: - non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use + mismatched uses get pto.vmi.ensure_layout or ensure_mask_layout at the use site, with source and target layouts visible in the types granularity rewrite: one semantic mask used by f32 and f16 consumers gets - ensure_mask_granularity or cloned mask producers + ensure_mask_granularity at the use site control-flow rewrite: scf.if/scf.for yields and block arguments are rewritten to one agreed layout; materialization is inserted before yield when branches differ function rewrite: - private VMI functions are specialized or get callee-entry ensure_layout + baseline private VMI functions get callee-entry/return-site ensure_layout; + signature specialization is an optimization pass public/external VMI functions are diagnosed ``` @@ -865,7 +890,8 @@ Canonical assigned IR shape for a conflicting load: pto.vmi.store %x_dense, ... ``` -Canonical assigned IR shape for a cloned cheap producer: +Optional future optimized IR shape for a cloned load with an explicit +safe-read/execution proof: ```text %x_s16 = pto.vmi.load ... 
@@ -878,12 +904,12 @@ Canonical assigned IR shape for a cloned cheap producer: Canonical assigned IR shape for `group_broadcast` multi-use: ```text -%b0 = pto.vmi.group_broadcast %slots +%b = pto.vmi.group_broadcast %slots : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -%b1 = pto.vmi.group_broadcast %slots - : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%b_c = pto.vmi.ensure_layout %b + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -912,10 +938,10 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous recipe -3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 recipe -3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 recipe -3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local recipe +3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous lowering +3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 layout +3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 layout +3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local lowering 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks 3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks @@ -931,12 +957,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load recipe +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load layout 3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan 3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan 3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned 
slots plan 3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan -3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load recipe +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load layout 3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof 3.39 strided load fanout conflict resolver preserving layout or materialization @@ -944,7 +970,7 @@ vmi-to-vpto contract: consume only explicit memory stride/alignment attrs, current op operands, and layouts. It must not infer safe read/write placement from neighboring compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes - stay diagnostics until a gather recipe is explicit in the current op. + stay diagnostics until a gather fallback is explicit in the current op. ``` ```text @@ -952,13 +978,13 @@ case family builder / owner assignment artifact 3.8 reduce->truncf->broadcast conflict resolver slot cast plus dense materialization 3.10 non-load S=32 producer buildElementwiseRequests transparent deinterleaved chain 3.17 broadcast deint consumer conflict resolver use-site group_broadcast layout -3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout -3.23 broadcast multi-user conflict resolver cloned group_broadcast -3.33 S=16 + S=32 users conflict resolver cloned load or materialization +3.18 dense + reduce users conflict resolver ensure_layout; optional remat/fold +3.23 broadcast multi-user conflict resolver per-op group_broadcast layout +3.33 S=16 + S=32 users conflict resolver use-site materialization; optional cloned load 3.34 S=64 slots=1 cast buildCastRequests group_slot_cast layout 3.35 slots fanout buildElementwiseRequests same group_slots layout on users -3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast -3.40 scalar dense + grouped conflict resolver cloned broadcast +3.36 scalar slots=8/slots=1 conflict resolver explicit slots=8/slots=1 producers +3.40 
scalar dense + grouped conflict resolver ensure_layout; optional broadcast remat 3.41 incompatible fixed value conflict resolver diagnostic or ensure_layout vmi-to-vpto contract: @@ -985,17 +1011,17 @@ vmi-to-vpto contract: ```text diagnostic family builder / owner required failure -3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store recipe +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store path 3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast 3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather -3.13 slots=8 width cast buildCastRequests no packed slot cast recipe -3.14 unsupported group size buildGroupReduceRequests no registered reduce recipe +3.13 slots=8 width cast buildCastRequests no packed slot cast transform +3.14 unsupported group size buildGroupReduceRequests no supported reduce layout/lowering 3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan -3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load recipe +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load path 3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan 3.19.2 invalid block_elems use conflict resolver no preserving materialization 3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI -3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback recipe +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback path 3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback vmi-to-vpto contract: @@ -1068,75 +1094,75 @@ adaptor physical values Each pattern rejects: ```text -missing current-op proof for an otherwise unsafe memory recipe +missing current-op proof for an otherwise unsafe memory lowering missing target capability unexpected group_slots dense consumer ``` -Target local recipe matrix: +Target local 
lowering matrix: ```text -load, recipe=dense_load_norm: +load, lowering=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths covers dense store users and full-chunk row-local reduce input -load, recipe=load_dintlv2: +load, lowering=load_dintlv2: result layout deinterleaved=2, block_elems=1 emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization covers f32->f16, S=16 parity reduce, f16->f32 widened values -load, recipe=load_dintlv4: +load, lowering=load_dintlv4: result layout deinterleaved=4, block_elems=1 emits two vldsx2 DINTLV_B32 plus vdintlv covers f32->f8, S=32 dintlv4 reduce -group_load, recipe=s16_group_load_block8_unit_stride: +group_load, lowering=s16_group_load_block8_unit_stride: result layout deinterleaved=2, block_elems=8 emits vldsx2/BDINTLV for 8 rows of 16xf32 covers compact logical S=16 when source_group_stride == 16 -group_load, recipe=s16_group_load_block8_stride: +group_load, lowering=s16_group_load_block8_stride: result layout deinterleaved=2, block_elems=8 emits two vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, recipe=s32_group_load_block8_stride: +group_load, lowering=s32_group_load_block8_stride: result layout deinterleaved=4, block_elems=8 emits four vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, recipe=group_load_contiguous_chunks: +group_load, lowering=group_load_contiguous_chunks: result layout contiguous emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_add{f|i}, recipe=one_vlane_reduce_contiguous: +group_reduce_add{f|i}, lowering=one_vlane_reduce_contiguous: consumes contiguous accumulator type T with group size VLaneElems(T) produces group_slots(G, slots=8) emits one vcgadd -group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: consumes deinterleaved=2, 
block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_add{f|i}, recipe=four_vlane_reduce_block8_stride: +group_reduce_add{f|i}, lowering=four_vlane_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: +group_reduce_add{f|i}, lowering=full_chunk_reduce_row_local: consumes contiguous accumulator type T with group size that is a multiple of one physical chunk L(T) produces group_slots(G, slots=1) @@ -1144,31 +1170,31 @@ group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract -group_slot_load, recipe=group_slot_load_slots8_unit_stride: +group_slot_load, lowering=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 emits one packed vsldb load -group_slot_load, recipe=group_slot_load_slots1_row_local: +group_slot_load, lowering=group_slot_load_slots1_row_local: result group_slots(G, slots=1) supports aligned non-unit source_group_stride requires constant positive source_group_stride divisible by 256 / elementBits emits one lane-0 vsldb per group -group_broadcast, recipe=group_broadcast_slots8_vselr: +group_broadcast, lowering=group_broadcast_slots8_vselr: source group_slots(G, slots=8) result dense layout selected per use emits vselr using assigned result layout 
-group_broadcast, recipe=group_broadcast_slots1_vselr: +group_broadcast, lowering=group_broadcast_slots1_vselr: source group_slots(G, slots=1) result dense layout selected per use emits vdup/vselr row-local materialization -truncf, recipe=group_slot_cast_slots1_f32_to_f16: +truncf, lowering=group_slot_cast_slots1_f32_to_f16: source/result group_slots(G, slots=1) emits one lane-0 vcvt per group slot block - rejects packed slots=8 unless another plan is registered + rejects packed slots=8 unless slot-preserving cast support exists ``` The target matrix is the implementation contract. The staged status below @@ -1195,14 +1221,15 @@ group_reduce_addf: Full-chunk row-local assignment, including S=64 and S=256 f32 cases, uses #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic - VCADD row-local path is registered and selected locally. + VCADD row-local lowering is selected locally from the current op attrs and + assigned layouts. group_reduce_addi is implemented for i32 accumulator values. i8/i16 storage must be widened explicitly before grouped reduction because narrow integer reduction instructions widen their result. group_broadcast: explicit slots=8/1 source layouts select - packed or row-local VSELR recipes locally. Deinterleaved block-fragment + packed or row-local VSELR lowerings locally. Deinterleaved block-fragment results use the result layout block_elems as the local vselr selection group, so `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each @@ -1219,9 +1246,9 @@ group_load: contiguous full-chunk path is selected from a contiguous result layout. S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one - vsldb per 32B row fragment and physical chunk. The explicit block8 recipe - is registered and checked by pto-validate-vmi-layout-ir before vmi-to-vpto. 
- The dedicated S=16 unit-stride vldsx2/BDINTLV recipe remains a local + vsldb per 32B row fragment and physical chunk. The explicit block8 support + is checked by pto-validate-vmi-layout-ir before vmi-to-vpto. + The dedicated S=16 unit-stride vldsx2/BDINTLV lowering remains a local peephole target. S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by vmi-layout-assignment because the stable gather @@ -1240,12 +1267,12 @@ group_store: multiple of the 32B store alignment in destination elements: 8 for f32, 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only the first row-local store is 32B-aligned; later `group_off + r` stores are - 4B apart. A future pack-to-slots=8 or unaligned-store recipe is required before + 4B apart. A future pack-to-slots=8 or unaligned-store lowering is required before contiguous `%c1` slots=1 group_store can be accepted. Packed group_slots(G, slots=8) group_store is implemented only when num_groups is a multiple of 8 and row_stride is constant 1; it emits one PAT_VL8 store per packed slot block. Non-unit packed group stores remain a - design target unless a strided packed-lane store recipe is made explicit. + design target unless a strided packed-lane store lowering is made explicit. ``` Current implementation contract for type-generic grouped reduction: @@ -1268,17 +1295,18 @@ Layout assignment: route f8 storage through extf to f32 before group_reduce_addf. route i8/i16 storage through extsi/extui to i32 before group_reduce_addi. route integer narrowing to i8 through trunci; direct i8 compute remains - illegal unless the target capability registry exposes an explicit recipe. + illegal unless target capability and explicit op semantics define that + lowering. diagnose direct f8/i8 compute use with a message that points at the offending op and suggests inserting the explicit cast when the op is meant to consume storage data. 
-Local recipe registry: - replace f32-shaped recipe keys with width-parametric recipe classes: - one_vlane_reduce - two_vlane_reduce_deinterleaved - four_vlane_reduce_deinterleaved - full_chunk_row_local_reduce +Layout fact helpers: + replace f32-shaped checks with width-parametric group-reduce classifiers: + one_vlane_reduce layout fact + two_vlane_reduce_deinterleaved layout fact + four_vlane_reduce_deinterleaved layout fact + full_chunk_row_local_reduce layout fact key legality on accumulator byte width, source/mask layout, result group_slots layout, num_groups, and target instruction capability. @@ -1288,8 +1316,8 @@ VMI-to-VPTO: materialize integer casts explicitly before reduction; direct i8 group reduce and direct i16 group reduce must not silently become a widening reduction in this pass. - keep VPTO lowering local: it consumes assigned layouts and registered local - recipes, but does not invent a new global layout plan. + keep VPTO lowering local: it consumes assigned layouts and current-op + attrs/operands, but does not invent a new global layout plan. Tests: cover f16 direct and i16-storage-to-i32 grouped reductions. @@ -1302,15 +1330,15 @@ Tests: Examples: ```text -group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree @@ -1346,7 +1374,7 @@ After assignment: Every VMI value has layout. Every VMI mask has layout and granularity plan. Every lowering choice is locally deterministic or explicit in attrs/layouts. -Every ensure_* helper has a materialization recipe. 
+Every ensure_* helper has a materialization path. Every control-flow edge has matching VMI layouts. ``` @@ -1364,8 +1392,8 @@ allowed: diagnostic not allowed: - walking from a consumer to a producer to decide a recipe - walking from a consumer to a mask producer to decide whether a recipe is legal + walking from a consumer to a producer to decide a lowering + walking from a consumer to a mask producer to decide whether a lowering is legal inspecting users to choose a result layout or materialization recovering full_tile_readable from surrounding MTE/caller context ``` @@ -1378,7 +1406,7 @@ Current audit result: lowering the deinterleaved create_group_mask itself, vmi-to-vpto first materializes contiguous grouped predicate chunks and then applies predicate pdintlv in the same tree shape as the data vdintlv. It still does not walk - from group_reduce_addf to the mask defining op to choose or reject the plan. + from group_reduce_addf to the mask defining op to choose or reject lowering. The dynamic active_elems_per_group form is also op-local: vmi-to-vpto lowers contiguous chunks with vci/vshrs/vshls/vsub/vcmps, then uses the same predicate pdintlv tree for S=32 deinterleaved masks. @@ -1413,7 +1441,7 @@ op name logical type actual layout requested layout -selected/missing plan +selected/missing support path recommended rewrite or option ``` @@ -1424,8 +1452,8 @@ VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is fixed to #pto.vmi.layout by the selected - strided group_load recipe. Register a rematerialization or preserving - materialization recipe, or avoid consuming this block-loaded value with truncf. + strided group_load layout. Register a rematerialization or preserving + materialization path, or avoid consuming this block-loaded value with truncf. ``` ## 11. Test And Simulator Acceptance @@ -1493,14 +1521,13 @@ the case catalog. 
Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-layout-gate CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-final CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-layout-gate/parallel-summary.tsv -log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-layout-gate.log -result: no matches +TOTAL_CASES=47 +PASS=47 FAIL=0 +summary: .tmp/vmi-runtime-batch-final/parallel-summary.tsv +result: all summary entries are PASS ``` The `find: Permission denied` messages printed while discovering CANN simulator @@ -1576,10 +1603,10 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 47 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 340/340 + latest broad VMI runtime sweep passed: PASS=47 FAIL=0 + latest full VMI lit sweep passed: 350/350 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1958,7 +1985,7 @@ Current checked-in lit coverage for the first `vmi-layout-sink-materialization` optimization is: ```text -test/lit/vmi/vmi_layout_sink_materialization_binary.pto +test/lit/vmi/vmi_layout_sink_materialization_binary.pto // unary, binary, fma, cmp, and select data ops test/lit/vmi/vmi_layout_sink_materialization_mask.pto ``` @@ -1969,27 +1996,27 @@ test/lit/vmi/vmi_legalize_arith_select.pto test/lit/vmi/vmi_ptoas_cli_control_flow.pto ``` -Current checked-in lit coverage for the first semantic local-recipe layout gate +Current checked-in lit coverage for the first semantic local-lowering layout gate is: ```text -test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto 
-test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto -test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_store_support_invalid.pto test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto -test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto +test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto ``` Current checked-in direct `vmi-to-vpto` preflight coverage for bitcast local -recipes is: +lowering is: ```text test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto @@ -2050,7 +2077,7 @@ Diagnostic-only cases: 3.16.1 group_slot_load slots=8 non-unit stride 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 f32 elements -3.19.2 block_elems=8 value consumed by truncf without materialization recipe +3.19.2 block_elems=8 value consumed by truncf without materialization path 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -2069,7 +2096,7 @@ entries: ```text lit: - 
test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto + test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -2141,15 +2168,15 @@ group_store ```text 3.8 cast commute through group_broadcast 3.18 dense/group-reduce multi-consumer -3.19 block_elems recipe selection +3.19 block_elems layout selection 3.23 group_broadcast multi-consumer 3.32 f32 feeding f8 store and S=32 reduce 3.33 S=16/S=32 reduce multi-consumer rematerialization 3.34 slots=1 group-slot f32->f16 cast 3.35 group_slots fanout to group_store and group_broadcast -3.36 group_slot_load rematerialized for slots=8/slots=1 +3.36 group_slot_load expressed as explicit slots=8/slots=1 producers 3.38 multi-tile group_slots arity -3.40 scalar broadcast rematerialized for dense/grouped users +3.40 scalar broadcast materialized for dense/grouped users 3.41 non-rematerializable value with ensure_layout ``` @@ -2192,12 +2219,12 @@ Current evidence for the case-catalog objective: checked-in runtime case directory 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py -4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 342/342 +4. the latest broad VMI runtime sweep passed: PASS=47 FAIL=0 +5. the latest full VMI lit sweep passed: 350/350 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics -8. no separate recipe string attr is emitted or consumed +8. no separate lowering-plan string attr is emitted or consumed 9. 
release docs remain untouched; this is still a design/implementation plan under docs/designs ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 82a84082c6..42e62e8b3a 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -35,7 +35,7 @@ vmi-to-vpto 不允许通过上下文猜 lowering。 必须在 vmi-layout-assignment 或后续 VMI layout optimization 阶段变成显式 IR: 1. vmi.vreg/vmi.mask 的 layout -2. current-op attrs/operands that make the local recipe deterministic +2. current-op attrs/operands that make the local lowering deterministic 3. use-site ensure_layout / ensure_mask_layout / ensure_mask_granularity 4. rematerialized or cloned producer 5. target capability diagnostic @@ -50,7 +50,7 @@ separates correctness from optimization: hard legalization: produces legal layout-assigned VMI IR for all supported semantics inserts conservative ensure_* helpers at incompatible uses - may choose a simple canonical layout even when a fused consumer recipe exists + may choose a simple canonical layout even when a fused consumer lowering exists must diagnose unsupported semantics before vmi-to-vpto has to guess layout optimization: @@ -80,7 +80,7 @@ A later optimization may replace that use with: pto.vmi.store %x : deinterleaved=2 ``` -only if the store op itself has a local deterministic recipe for preserving the +only if the store op itself has a local deterministic lowering for preserving the same row-major memory effect, such as a layout-aware `vstsx2 INTLV` lowering. Both forms are semantically complete. The second form is an optimization, not a hard requirement for correctness. 
@@ -115,10 +115,10 @@ group reduce: layout conflict: one value with dense and group-reduce consumers one value with S=16 and S=32 group-reduce consumers - one scalar broadcast rematerialized for dense and grouped users + one scalar broadcast materialized for dense and grouped users, with optional rematerialization one non-rematerializable value materialized with use-site ensure_layout - one scalar group-slot source rematerialized as slots=8 and slots=1 - S=16 block_elems=1/8 recipe selection + one scalar group-slot source expressed as explicit slots=8 and slots=1 producers + S=16 block_elems=1/8 layout selection dense consumer of group_slots diagnostic packed group-slot width-changing cast diagnostic S=64 slots=1 group-slot width-changing cast @@ -173,7 +173,9 @@ consumer-driven pressure: elementwise/select, masked_load/masked_store conflict resolution: - cheap rematerialization, explicit ensure_layout, explicit diagnostics + explicit ensure_layout, explicit ensure_mask_layout, explicit diagnostics + optimization passes may later replace the helpers with rematerialization or + layout-aware consumers control-flow propagation: scf.if, scf.for iter_args/results, internal/private function boundaries, @@ -185,7 +187,7 @@ memory legality: ``` No extra layout kind should be added unless a new case proves that the existing -layouts and recipes cannot express the logical behavior. The remaining open +layouts and explicit helper contracts cannot express the logical behavior. The remaining open items are not missing layout semantics: ```text @@ -224,7 +226,7 @@ cast boundary: i8 participates through extsi/extui/trunci. Signedness is carried by the cast op semantics, not by a separate layout. On the current VPTO target, 32-bit to 8-bit integer narrowing is only a - baseline recipe for unsigned i8 results because the available VCVTII forms + baseline lowering for unsigned i8 results because the available VCVTII forms are s32/u32 -> u8. 
compute boundary: @@ -289,7 +291,7 @@ slot_lane(g) = g % K All non-slot lanes are undefined and may only be read by group-aware operations. Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. -`K` is selected by the producer/consumer local recipe: +`K` is selected by the assigned producer/result contract: ```text S=8/16/32 packed VCG result -> slots=8 @@ -337,7 +339,7 @@ or explicit helper: pto.vmi.ensure_mask_granularity ``` -`vmi-to-vpto` is allowed to choose a deterministic recipe from local +`vmi-to-vpto` is allowed to choose a deterministic lowering from local information on the current op: ```text @@ -354,104 +356,80 @@ ops to recover a lowering decision or a memory-safety proof. If a decision cannot be made from that local information, layout assignment must rewrite the IR until the decision is explicit in attrs, operand/result -layouts, helper ops, cloned producers, or diagnostics. `vmi-to-vpto` must not -consume a separate string recipe attr. +layouts, helper ops, or diagnostics. Later optimization passes may replace +helpers with cloned/rematerialized producers, but `vmi-to-vpto` must not +consume a separate string lowering-plan attr. -### 4.1 Local Recipe Contract +### 4.1 Local Lowering Contract -The lowering recipe is derived from op + assigned operand/result layouts + -explicit attrs/operands. If two legal recipes cannot be distinguished from +The lowering path is derived from op + assigned operand/result layouts + +explicit attrs/operands. If two legal lowerings cannot be distinguished from that local information, the IR is missing a semantic carrier and must be -extended before the recipe is implemented. +extended before that lowering is implemented. -Locally deterministic decisions in the current implementation: +The shared abstraction is a layout fact classifier, not a central lowering-plan +registry. 
A classifier may answer questions such as: ```text -group_load: - result layout, num_groups, row_stride, source type, and target capability - decide contiguous chunks versus S=16/S=32 block8 vsldb lowering. Unit-stride - vldsx2/BDINTLV can be a local peephole for the same block8 layout. - -group_slot_load: - result group_slots layout and source_group_stride decide packed slots=8 - versus row-local slots=1 vsldb lowering. A single source op may still be - rematerialized into two ops when different users require different result - layouts; each clone is then locally deterministic. - -group_reduce_add{f|i}: - source/mask layout, result group_slots layout, num_groups, element type, and - the typed reduce semantics decide the local reduction recipe. The recipe is - not keyed by f32 shape names. It is derived from the element byte width. - Floating-point `group_reduce_addf` carries `reassoc`; integer - `group_reduce_addi` does not. - - VLaneElems = 32B / sizeof(T) - L = 256B / sizeof(T) - S = logical_lane_count / num_groups - - S == VLaneElems -> contiguous vcgadd, result slots=8 - S == 2 * VLaneElems -> deinterleaved=2 vcgadd tree, result slots=8 - S == 4 * VLaneElems -> deinterleaved=4 vcgadd tree, result slots=8 - S >= L && S % L == 0 -> contiguous row-local vcadd/vsel, result slots=1 - - Type support is controlled by the typed reduce op semantics and target - capability, not by separate per-type shape rules. Once a type is legal for a - reduce op, the same formula above selects its layout and local recipe. The - current checked-in implementation may lag this design target; that is staged - implementation status, not a design boundary. - - The formula is applied to the accumulator/reduce element type, not - necessarily the storage element type. 8-bit floating-point storage first - casts to f32 for `group_reduce_addf`; 8-bit and 16-bit integer storage first - casts to a signed/unsigned i32 accumulator for - `group_reduce_addi`. 
In the baseline VMI contract, f8/i8 are storage and - cast-boundary types: they may be the source or destination of cast, load, and - store, but they are not accumulator/compute types for group reduce. Direct - 8-bit grouped reduction is illegal unless the target exposes an explicit - 8-bit compute recipe. - -group_broadcast: - source group_slots layout, result dense layout, num_groups, and element type - decide vdup/vselr materialization. - -truncf: - source/result group_slots layouts and element widths decide the slots=1 - f32->f16 slot-preserving vcvt path. +cast layout fact: + f16/i16 -> f32/i32 requires contiguous source and deinterleaved=2 result + f8/i8 -> f32/i32 requires contiguous source and deinterleaved=4 result + f32/i32 -> f16/i16 requires deinterleaved=2 source and contiguous result + f32/i32 -> f8/i8 requires deinterleaved=4 source and contiguous result + +group_reduce layout fact: + define E = sizeof(accumulator T), VLaneElems = 32B / E, + L = 256B / E, S = N / G. + S == VLaneElems requires contiguous source/mask and + group_slots(G, slots=8) result. + S == 2 * VLaneElems requires deinterleaved=2 source/mask and + group_slots(G, slots=8) result. + S == 4 * VLaneElems requires deinterleaved=4 source/mask and + group_slots(G, slots=8) result. + S >= L && S % L == 0 requires contiguous source/mask and + group_slots(G, slots=1) result. + +memory safety fact: + full_read_elems, shaped safe-tail memref, or explicit fallback option + proves whether rounded-up physical reads are legal. ``` -Other layout-only or attr-only decisions in the current implementation: +These helpers return semantic layout requirements and capability diagnostics. +They do not return VPTO instruction names, cost decisions, clone decisions, or +multi-user plans. -```text -load: - result layout plus explicit memory attrs decide the lowering. full_read_elems - is the memory-safety proof; vmi-to-vpto may not recover that proof from MTE or - caller context. 
- -group_store: - source group_slots layout and explicit output stride decide packed slots=8 - versus row-local slots=1 store legality. If another legal store recipe - needs more information, assignment must make that information explicit in the - op or helper IR before vmi-to-vpto uses it. +The useful shared fact is the part that would otherwise be recomputed by two or +more stages and must stay identical for correctness: -masked_load: - explicit passthrough, mask layout, full physical read, shaped safe-tail memref, - or an explicit diagnostic decide legality. A future stable gather fallback - must be made explicit by assignment before vmi-to-vpto lowers it. - -masked_store/select/elementwise: - operand/result layouts and explicit mask granularity decide the lowering. - They remain transfer ops unless a future case introduces competing recipes. - -extf/truncf: - dense width-changing paths are layout-determined today. Any future - commute-through-group-broadcast or alternative VCVT recipe must have an - explicit IR carrier first. +```text +cast width ratio: + assignment uses it to request source/result layouts and insert ensure_layout. + validation uses it to reject unsupported assigned cast shapes. + lowering uses it to check the local op shape before emitting VPTO. + +group_reduce lane partition: + assignment uses N/G and accumulator element width to request source/mask and + result layouts. + validation uses the same math to reject legacy or incomplete group_slots. + lowering uses the already assigned layouts to select the local VPTO sequence. + +layout materialization shape: + assignment may insert ensure_layout without proving every physical sequence. + validation and lowering use one support query to decide whether that explicit + helper is materializable on the target. + optimization uses the same query only when it wants to fold/sink/remove an + explicit helper. 
``` -Forbidden non-local recipe recovery: +The helper is not useful when it only renames one local pattern. A single +`if (is this op with this attr)` that is not shared by assignment, validation, +lowering, or an optimization should stay local to that pass. The support layer +exists to prevent divergent layout math, not to move every branch into a table. + +Forbidden non-local lowering recovery: ```text -No pattern may synthesize a recipe or memory proof by: +No pattern may recover a lowering decision or memory proof by: - walking from group_reduce to the load/group_load producer - walking from store/broadcast/truncf to the group_reduce producer - scanning sibling users of a group_slots value @@ -463,48 +441,59 @@ If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, assigned layouts, and the missing decision class. -## 5. Local Recipe Registry - -The compiler owns a target-aware local recipe registry. Layout assignment and -layout optimization query this registry to decide which explicit IR shape to -produce. `vmi-to-vpto` queries the same registry only to verify and lower the -current op from local information. - -The registry is not serialized as a separate recipe-selection attribute. If -two legal physical recipes cannot be distinguished by the current op's name, -attrs, operands, operand/result layouts, helper ops, and target options, the -VMI IR is missing a carrier. Add an explicit attr, operand, helper op, or -semantic op before implementing that recipe. +## 5. Layout Requests, Helpers, And Optimization -### 5.1 Recipe Kinds +The compiler must not carry a target-aware lowering-plan registry as the shared +contract between assignment, optimization, validation, and lowering. 
The +shared contract is: ```text -ProducerRecipe: - op can produce result layout L - example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv - -ConsumerRecipe: - op can consume operand layout L - example: group_reduce S=32 consumes deinterleaved=4 - -TransferRecipe: - op ties operand/result layouts - example: addf requires same dense layout for operands/result - -MaterializationRecipe: - layout A -> layout B without changing logical value - example: deinterleaved=4 -> contiguous by vintlv tree +1. assigned layouts on VMI types +2. explicit use-site helpers: ensure_layout, ensure_mask_layout, + ensure_mask_granularity +3. explicit op attrs/operands that are part of the semantic op +4. small layout fact classifiers shared only where they remove duplicated + layout math +5. target capability diagnostics +``` -RematerializationRecipe: - cheap producer can be cloned for a use-site layout - example: broadcast/create_mask/group_broadcast +This split makes optimization simpler only when optimization is phrased as +rewriting explicit helper IR: -DiagnosticRecipe: - known unsupported semantic/capability boundary - example: compact S=12 requires gather materialization +```text +baseline: + %x_d2 = pto.vmi.extf %x_f16 + %a = pto.vmi.addf %x_d2, %k_d2 + %a_c = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + pto.vmi.store %a_c, %out0 + %x_c = pto.vmi.ensure_layout %x_d2 : deinterleaved=2 -> contiguous + pto.vmi.store %x_c, %out1 + +fold-consumers: + checks only each local ensure_layout + store use. + If VMILayoutSupport says the store can preserve row-major memory from the + source layout, rewrite that use to store the source directly. + It does not inspect sibling users of %x_d2 and does not recompute the layout + assignment. + +rematerialize: + checks only cheap producer + ensure_layout. + If the producer can directly create the requested layout, clone/rematerialize + that producer for the use. 
+ Memory producers such as group_slot_load are excluded until a separate proof + says cloning is semantically and economically valid. + +sink-materialization: + checks only explicit ensure_* operands of a layout-transparent op. + If every operand helper is compatible, rebuild the op in the source layout and + leave one ensure_* on the result. ``` -### 5.2 Dense Recipes From Cases +If an optimization needs a global cost decision, it should produce a new +explicit IR shape and then rely on canonicalize/CSE. It must not communicate a +private decision to `vmi-to-vpto`. + +### 5.1 Baseline Dense Layout Requests ```text f16 -> f32: @@ -526,47 +515,43 @@ f32 -> f8: elementwise dense: all dense operands/results share the same layout -broadcast scalar: - rematerializable to any dense layout requested by the consumer - -load: - may be rematerialized per use when two consumers request incompatible dense - layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 +dense store: + requests contiguous source + if the stored value is assigned deinterleaved, baseline assignment inserts + ensure_layout at the store use ``` -### 5.3 Group Recipes From Cases +### 5.2 Baseline Group Layout Requests ```text -group_reduce_add{f|i} typed shape classification: - define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. - S=VLaneElems uses contiguous input and group_slots(G, slots=8). - S=2*VLaneElems uses deinterleaved=2 input/mask and group_slots(G, slots=8). - S=4*VLaneElems uses deinterleaved=4 input/mask and group_slots(G, slots=8). - S>=L && S%L==0 uses contiguous input/mask and group_slots(G, slots=1); - lowering reduces each full physical chunk, accumulates all chunks in the - same logical group through lane0, and writes one physical result part per - group. +group_reduce_add{f|i}: + uses the group_reduce layout fact in section 4.1. + The source and mask operands request the computed dense layout. 
+ The result is assigned group_slots(G, slots=8) or group_slots(G, slots=1). + Floating-point `group_reduce_addf` carries `reassoc`; integer + `group_reduce_addi` does not. group_slot_load: result group_slots(G, slots=8) for packed slots result group_slots(G, slots=1) for row-local slots group_broadcast: - source group_slots(G,K) - result is dense layout requested by each consumer - rematerialize per use instead of forcing one result layout + source requests group_slots(G,K) + result requests one dense layout + incompatible dense consumers are represented by ensure_layout after the + broadcast result; a later optimization may clone/rematerialize the broadcast group_store: - source group_slots(G,K) + source requests group_slots(G,K) + explicit output stride attrs/operands decide store legality group_slot_cast f32 -> f16: - slots=1 row-local source/result is legal with - group_slot_cast_slots1_f32_to_f16 - slots=8 packed source is illegal unless a packed slot-preserving recipe is - registered + slots=1 row-local source/result is legal + slots=8 packed source is illegal unless a future explicit helper or semantic + op defines the packed slot-preserving transform ``` -### 5.4 Tail And Memory Safety Recipes +### 5.3 Tail And Memory Safety Mask semantics and memory legality are separate: @@ -607,7 +592,7 @@ one mask used by f32 and f16 consumers: vmi-to-vpto consumes the assigned per-use mask materialization ``` -### 5.5 Case-Driven Request Matrix +### 5.4 Case-Driven Request Matrix The first implementation should build requests from the following finite table. This table is deliberately case-derived; adding a new request kind requires a @@ -616,8 +601,9 @@ new catalog case or a proof that it is equivalent to one listed here. 
```text dense store: requests dense contiguous source - if source is deinterleaved, assignment must insert ensure_layout or select a - store recipe such as vstsx2 that consumes the assigned layout explicitly + if source is deinterleaved, baseline assignment inserts ensure_layout at the + store use. A later optimization may fold that helper into a layout-aware + store lowering such as vstsx2. truncf f32 -> f16: requests source deinterleaved=2, block_elems=1 @@ -642,8 +628,9 @@ group_reduce_add{f|i}: group_broadcast: requests source group_slots(num_groups, slots=K) - produces one dense result layout per consumer request - is cloned per incompatible dense consumer + produces one assigned dense result layout + incompatible dense consumers are represented by ensure_layout uses; a later + optimization may clone/rematerialize the group_broadcast per consumer group_store: requests source group_slots(num_groups, slots=K) @@ -666,7 +653,7 @@ group_slot_load: group_load: requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block - fragment recipes, or contiguous for row-local full-chunk recipes + fragments, or contiguous for row-local full chunks masked_load: requests result layout from its consumers @@ -674,18 +661,20 @@ masked_load: requires explicit passthrough; padding is not synthesized masked_store: - requests dense source layout selected by the store recipe + requests dense source layout required by the store op requests mask layout matching the source layout and store element granularity does not choose memory safety for an earlier load create_mask/create_group_mask: - produces whichever mask layout each consumer requests - may be cloned per incompatible mask layout or granularity + produces one assigned mask layout and granularity + incompatible mask consumers are represented by ensure_mask_layout or + ensure_mask_granularity; optimization may clone/rematerialize the mask op scf.if/scf.for/call/return: requests equality across carried VMI values, yielded 
values, call operands, callee arguments, and function results - private/internal functions may specialize or materialize at boundaries + baseline private/internal functions materialize at boundaries; optimization + may specialize signatures public/external VMI boundaries are diagnostics until an ABI is defined ``` @@ -694,48 +683,50 @@ Important negative requests: ```text ordinary dense add/mul/store/truncf cannot request group_slots packed group_slots(slots=8) cannot request width-changing cast unless a packed -slot-preserving cast recipe is registered +slot-preserving cast transform is explicitly represented slots=1 group_store cannot request unit-stride row-major output until a pack or -unaligned-store recipe exists +unaligned-store transform is explicitly represented ``` -### 5.6 Conflict Resolution Matrix +### 5.5 Optimization Hooks + +Baseline assignment resolves incompatible use-site requests by keeping one +assigned layout on the value and inserting explicit helpers at the use sites +that need another layout. It does not clone producers, rematerialize cheap +ops, choose memory-fused layouts by cost, or specialize private function +signatures for performance. -When one value receives incompatible requests, assignment resolves it using the -first legal row below. `vmi-to-vpto` never repeats this decision. +Those choices belong to later VMI layout optimization passes. 
They consume +the explicit helper IR and may rewrite it when the rewrite preserves the same +logical value and externally visible memory effect: ```text -cheap producer with multiple requested layouts: - clone the producer and assign each clone independently - examples: load, broadcast, create_mask, create_group_mask, group_broadcast - memory-read producers require the same explicit no-alias and safe-read proof - at each clone site +ensure_layout + store: + fold into a layout-aware store if the store can directly consume the source + layout and still write row-major memory -non-cheap value with registered materialization: - keep one chosen layout on the value and insert ensure_layout at the use site - examples: deinterleaved=4 -> contiguous before dense store +producer + ensure_layout: + clone/rematerialize the producer for that use only when the producer is cheap + or has an explicit safe-read proof -layout-transparent chain: - assign the whole equivalence class to the non-contiguous consumer request when - that avoids materialization - examples: broadcast -> addf -> S=32 group_reduce +elementwise chain + ensure_layout: + sink or hoist materialization through pure layout-transparent ops -control-flow join: - all incoming values must be materialized to one layout before yield/branch - examples: scf.if yielding group_slots, scf.for loop-carried group_slots +group_broadcast + incompatible dense consumers: + type each group_broadcast op for its consumer layout; do not force one result + layout across independent group_broadcast users -private function boundary: - specialize or materialize at call/callee-entry before vmi-to-vpto +create_mask/create_group_mask + incompatible mask consumers: + clone/rematerialize the mask producer per layout or predicate granularity -no clone/materialization/specialization recipe: - emit a diagnostic naming the requesting op and both layouts +private function boundary: + specialize function signatures only in an optimization pass; 
baseline + assignment materializes at boundary uses ``` -The cost model may choose between legal rows only when the observable contract -is identical. For example, S=16 `block_elems=1` and `block_elems=8` are both -valid reduce inputs, but `block_elems=8` is selected only when a producer recipe -such as strided `group_load` naturally creates 32B row fragments or when cost -proves it cheaper without breaking another consumer such as `truncf`. +If no helper materialization or optimization rewrite is legal, the diagnostic +must name the value's assigned layout, the use-site requested layout, and the +op that requested it. ## 6. Layout Assignment Algorithm @@ -758,7 +749,7 @@ Create a use-site request for: ```text 1. every operand use that requires a specific layout 2. every control-flow yield/branch/call/return edge -3. every memory operation that requires a memory legality recipe +3. every memory operation that requires an explicit memory legality proof ``` ### 6.2 Constraints @@ -767,14 +758,14 @@ Hard constraints: ```text group_slots cannot feed ordinary dense consumers -direct group-slot width-changing cast requires a slot-preserving recipe +direct group-slot width-changing cast requires an explicit slot-preserving transform public/external VMI function boundary requires a stable ABI or diagnostic S=32 fast tail load requires full_tile_readable or gather fallback ``` -`slots = 1` row-local cast may satisfy the slot-preserving recipe requirement. +`slots = 1` row-local cast may satisfy the slot-preserving transform requirement. Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast -or unpack/materialization recipe is registered. +or unpack/materialization transform is represented explicitly. 
Equivalence constraints: @@ -788,22 +779,23 @@ scf.if/scf.for: as the region result/iter_arg ``` -Candidate constraints: +Canonical baseline constraints: ```text S=16 group_reduce: - choose block_elems=1 or block_elems=8 by cost and explicit assignment constraints + request deinterleaved=2; baseline uses block_elems=1 unless the producer + result already carries block_elems=8 as an explicit layout one dense value feeding S=16 and S=32 group_reduce: - rematerialize a cheap producer per consumer layout, or insert an explicit - materialization recipe; the final lowering pass must not pick one layout after - seeing both users + keep the value's assigned layout and insert ensure_layout at both use sites + that need deinterleaved=2 or deinterleaved=4 load/group_load: - choose memory recipe and result layout together + use the op's assigned result layout and explicit memory-safety attrs only group_broadcast: - rematerialize per dense consumer layout + keep one assigned dense result layout and communicate other dense use layouts + through ensure_layout ``` ### 6.3 Solving @@ -812,26 +804,24 @@ Recommended solving order: ```text 1. Build function/control-flow SCCs. -2. Collect candidate recipes for every op. -3. Propagate hard required layouts from consumers. -4. Propagate producer natural layouts where they are unique. -5. Resolve multi-recipe ops by cost. -6. Insert use-site materialization where a value has multiple incompatible uses. -7. Rematerialize cheap producers instead of materializing when cheaper. -8. Specialize internal function signatures. -9. Emit diagnostics for unsatisfied hard constraints. -10. Rewrite VMI types and insert explicit helper/rematerialized ops. +2. Collect natural producer layouts and hard use-site layout requests. +3. Propagate equality constraints through dense elementwise ops and CFG edges. +4. Choose one deterministic assigned layout for each value or equivalence + class. +5. 
Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses + whose requested layout differs from the assigned layout. +6. Emit diagnostics for unsupported semantic constraints or missing explicit + memory-safety proofs. +7. Rewrite VMI types and insert explicit helper ops. ``` -Tie-breaking must be deterministic. Suggested priority: +Tie-breaking must be deterministic and deliberately simple. Suggested priority: ```text -1. Avoid unsupported recipes. -2. Prefer rematerializing cheap producers over register materialization. -3. Prefer layouts accepted by all consumers without conversion. -4. Prefer memory-fused layout recipes over load + register rearrange. -5. Prefer fewer VPTO instructions. -6. Prefer contiguous only when cost ties and no consumer requests a special layout. +1. Preserve an explicit user-provided layout attr. +2. Preserve a unique producer natural layout when present. +3. Preserve an equality-class non-contiguous layout when required by an op with a hard layout constraint. +4. Otherwise choose contiguous. ``` ## 7. Control Flow And Functions @@ -889,7 +879,7 @@ For each op, the pattern: 1. reads operand/result layouts 2. reads current op attrs and operand values 3. asks TypeConverter for ordered physical values -4. emits the locally implied VPTO recipe +4. emits the locally implied VPTO lowering 5.
fails if target capability or required local proof is absent ``` @@ -916,7 +906,7 @@ current VMI op body/attrs: helper materialization chain: allowed only to strip ensure_mask_layout / ensure_mask_granularity for - static predicate analysis that does not choose a different layout or recipe + static predicate analysis that does not choose a different layout or lowering diagnostic embellishment: allowed only to improve an already-failed capability message, such as naming @@ -930,7 +920,7 @@ grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the recipe. Dynamic `active_elems_per_group` follows the same rule: the +the lowering. Dynamic `active_elems_per_group` follows the same rule: the `create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps for contiguous chunks before any predicate deinterleave. @@ -952,8 +942,8 @@ group_slots(G,K): slot_block0, slot_block1, ... ``` -Two physical bundle entries may alias the same VPTO SSA value when the local -recipe proves they have the same contents, such as group_broadcast feeding both +Two physical bundle entries may alias the same VPTO SSA value when the current +op semantics prove they have the same contents, such as group_broadcast feeding both parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; aliasing is not a different layout. @@ -989,15 +979,87 @@ public VMI function boundary: make function internal, inline before assignment, or define ABI layout ``` -## 11. Design Completion Criteria +## 11. Implementation Migration Checks + +The design is useful only if the implementation removes duplicated decision +points instead of renaming them. 
The migration target is: + +```text +assignment: + computes assigned layouts, records use-site requests, inserts ensure_* helpers, + and diagnoses unsupported semantics + does not clone/rematerialize producers + does not choose memory-fused layouts by cost + does not inspect sibling users to optimize a value + +layout optimization: + consumes explicit ensure_* helpers + may fold ensure_layout into layout-aware consumers + may clone/rematerialize cheap producers + may sink/hoist materialization through pure elementwise chains + may specialize private function signatures + +vmi-to-vpto: + consumes current op attrs/operands, assigned operand/result layouts, and + explicit helper ops + performs local physical shape and target-capability checks + does not recover layout plans from producers, sibling users, CFG regions, or + callees/callers +``` + +Concrete implementation debt to remove: + +```text +1. Move assignment-side data/mask rematerialization into + vmi-layout-rematerialize. Baseline assignment should insert ensure_* for + mismatched uses. +2. Keep `VMILayoutSupport` as target capability and layout-shape queries, not + as a shared plan table. Group-reduce layout math now lives in + `getPreferredGroupReduceLayoutFact`. Dense cast layout shape now lives in + `getPreferredCastLayoutFact`. Helper materialization gates use + `canMaterializeDataLayout`, `canMaterializeMaskLayout`, and + `canMaterializeMaskGranularity`. +3. Assignment, validation, and lowering may call layout fact helpers, but must + not each independently derive VLaneElems/groupSize/factor/slots rules. +4. Keep store-fold, rematerialization, and sink/hoist as local rewrites over + explicit ensure_* IR. They must not walk sibling users to rediscover why the + helper exists. +5. Update pass descriptions, diagnostics, and tests so "assignment only" output + is legal with helpers, and optimized output is a separate, equivalent IR + form. 
+``` + +Regression tests should prove the boundary: + +```text +assignment only: + multi-consumer values keep one assigned layout and use ensure_* at mismatched + uses + +fold-consumers: + ensure_layout + store becomes a layout-aware store only when the consumer can + preserve the same row-major memory effect + +rematerialize: + cheap producer + ensure_layout becomes a cloned/rematerialized producer; with + the pass disabled, the ensure_layout form remains legal + +vmi-to-vpto: + rejects any residual need for producer/user context with VMI-LAYOUT-CONTRACT +``` + +## 12. Design Completion Criteria The design is complete only when: ```text -1. every case in vmi-layout-lowering-cases.md maps to a local recipe -2. every local recipe can be emitted without looking at producer/user context +1. every case in vmi-layout-lowering-cases.md maps to assignment requests, + explicit helpers, or a precise diagnostic +2. every VMI-to-VPTO lowering can be emitted without looking at producer/user + context 3. every unsupported case has a precise capability diagnostic -4. every control-flow/function boundary either specializes layout or diagnoses +4. every control-flow/function boundary materializes, specializes in an + optimization pass, or diagnoses 5. every mask has explicit data layout and predicate granularity 6. every positive case has end-to-end lit coverage 7. every simulator-supported positive case has simulator validation diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e17c14844b..d2d7b3835d 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -83,7 +83,7 @@ G % K == 0 K must fit in the physical vreg element count ``` -`K` is selected by the producer/consumer local recipe. It is not always 8. For +`K` is selected by the producer/consumer layout support rule. It is not always 8. 
For `VCGADD`-packed results, `K = 8` matches the eight 32B block results written to the low lanes of one destination vreg. For row-local reductions where each logical group already occupies one full 256B vreg, `K = 1` keeps each group's @@ -99,10 +99,11 @@ physical slot block slot_block(g), lane slot_lane(g) All other lanes are undefined for ordinary VMI consumers. They may only be read by group-aware ops that define how to interpret group slots. -## 2. Recipe Selection Rules +## 2. Layout Support Selection Rules -VMI cast ops must not hard-code one physical `vcvt` recipe as their semantic -layout rule. +VMI cast ops must not hard-code one physical `vcvt` lowering as their semantic +layout rule. Layout assignment records the required value layout; target +support queries only answer whether that layout can be materialized or lowered. ```text dense cast: @@ -112,7 +113,7 @@ dense cast: group-slot cast: source/result are both group_slots(G,K). lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are - legal only when a slot-preserving VPTO recipe is registered, or when the cast + legal only when slot-preserving VPTO lowering support exists, or when the cast can be commuted through a later group-aware consumer such as group_broadcast. ``` @@ -171,7 +172,7 @@ the immediately following complete endpoints. 
3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization -3.19 S=16 reduce block_elems recipe selection complete/diagnostic +3.19 S=16 reduce block_elems support selection complete/diagnostic 3.20 group_slots control-flow join complete 3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete @@ -752,7 +753,7 @@ row-major store of this layout must be rejected with: VMI-LAYOUT-CONTRACT: pto.vmi.store requires materializing #pto.vmi.layout to contiguous, but no - VPTO block-interleave materialization/store plan is registered. + VPTO block-interleave materialization/store support exists. ``` #### 3.5.3 Reduce Result, Elementwise, Store @@ -1182,8 +1183,8 @@ slot_lane(r) = 0 Trying to canonicalize this result to `slots = 8` would require packing lane 0 from eight different physical vregs into lanes 0..7 of one vreg. This document -does not use that plan. `slots = 1` is the canonical layout for S=64 row-local -group reductions. +does not use that packing transform. `slots = 1` is the canonical layout for +S=64 row-local group reductions. #### 3.7.1 Reduce And Store Group Sums @@ -1464,8 +1465,8 @@ group_off + 0, group_off + 1, group_off + 2, ... Only the first address is necessarily 32B-aligned. The remaining f32 addresses are 4B apart and are not valid for this `vsts` lowering. The compiler must not -accept this as a clean lowering until either a pack-to-slots=8 plan or an -unaligned-store plan is selected. +accept this as a clean lowering until either pack-to-slots=8 materialization +support or unaligned-store support exists. VMI input: @@ -1523,7 +1524,7 @@ layout transition explicit: `group_broadcast` first produces a dense contiguous f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 view required by dense `f32 -> f16` truncation. 
A future direct `group_broadcast -> deinterleaved=2` lowering may remove that materialization, -but the `group_broadcast` result layout must make that recipe explicit rather +but the `group_broadcast` result layout must make that support path explicit rather than hiding it inside `truncf` lowering. VPTO lowering result for one full 8-row tile: @@ -1776,13 +1777,13 @@ contract: ```text VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf with group size 32 and num_groups tail 6 requires - materializing #pto.vmi.layout. The registered fast plan + materializing #pto.vmi.layout. The fast lowering support uses vldsx2 DINTLV_B32 over a full 8-row tile. This source is not marked - full-tile-readable, and the stable gather tail plan is not implemented. + full-tile-readable, and the stable gather tail fallback is not implemented. ``` -If a future option enables the stable gather tail plan, the same VMI input may -lower by gathering only the active lanes. Until that plan is registered, the +If a future option enables the stable gather tail fallback, the same VMI input +may lower by gathering only the active lanes. Until that support exists, the converter must not silently issue the full-tile `vldsx2` loads. ### 3.12 Control-Flow Join Before `group_reduce` @@ -1848,7 +1849,8 @@ VPTO lowering result for the join: } ``` -The consumer after the join is the same S=32 reduction plan as section 3.6: +The consumer after the join uses the same S=32 reduction lowering support as +section 3.6: ```text %all_b32 = pto.pge_b32 "PAT_ALL" @@ -1876,7 +1878,7 @@ for r = 0..7: ``` If the two branches cannot be assigned the same layout and no materialization -plan exists before `scf.yield`, the required diagnostic is: +support exists before `scf.yield`, the required diagnostic is: ```text VMI-LAYOUT-CONTRACT: @@ -1917,7 +1919,7 @@ Required diagnostic: VMI-LAYOUT-CONTRACT: pto.vmi.truncf cannot lower from #pto.vmi.layout f32 to f16 because no - slot-preserving width-changing VPTO plan is registered. 
f32->f16 vcvt writes + slot-preserving width-changing VPTO support exists. f32->f16 vcvt writes even/odd sub-lanes, not lanes 0..7. Use group_broadcast before truncf, or keep the group_store element type as f32. ``` @@ -1936,8 +1938,8 @@ VMI input: pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} ``` -Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based plans use -32B groups, i.e. 8 f32 elements per row fragment: +Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based lowering +support uses 32B groups, i.e. 8 f32 elements per row fragment: ```text S = 8 -> one VCGADD block per group @@ -1950,8 +1952,8 @@ Required diagnostic: ```text VMI-LAYOUT-CONTRACT: - pto.vmi.group_reduce_addf with f32 group size 12 has no registered VPTO - layout plan. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. + pto.vmi.group_reduce_addf with f32 group size 12 has no supported VPTO + layout/lowering path. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. A scalar/gather fallback or a rewrite to logical group size 16 with an explicit per-group mask is required. ``` @@ -2291,8 +2293,8 @@ for g = 0..7: out[group_off + g] = rhs_base[rhs_off + g] ``` -If `source_group_stride != 1`, this packed `slots = 8` plan requires a -strided/gather group-slot load materializer. Until that plan is registered, +If `source_group_stride != 1`, this packed `slots = 8` layout requires a +strided/gather group-slot load materializer. Until that support exists, `group_slot_load` with `slots = 8` and non-unit stride must diagnose instead of silently using full-group `group_load`. @@ -2461,8 +2463,9 @@ look through the defining `group_broadcast` and choose a hidden broadcast shape. This case forces layout assignment to handle a solvable use-site conflict. One consumer requires an S=32 group-reduce layout; another consumer requires dense row-major store. This is not semantically illegal. 
It must be solved by -use-site materialization or producer rematerialization when a registered plan -exists. +explicit use-site materialization. A later optimization pass may fold the +materialization into a store or rematerialize a cheap producer when the required +support exists. VMI input: @@ -2487,10 +2490,10 @@ Assigned layouts: requires #pto.vmi.layout ``` -If `%x` is cheap to rematerialize, layout assignment may clone the producer for -the dense store. Otherwise, if the registry has a `deinterleaved = 4 -> -contiguous` materialization plan, layout assignment may keep `%x` in -`deinterleaved = 4` and insert `ensure_layout` before the dense store. +Baseline layout assignment keeps `%x` in the group-reduce layout and inserts +`ensure_layout` before the dense store use. A later rematerialization pass may +clone the load for the dense store if that is profitable. A later fold-consumer +pass may also fold `ensure_layout + store` into a layout-aware store lowering. VPTO lowering result: @@ -2551,18 +2554,17 @@ for i = 0..255: copy_out[off + i] = base[off + i] ``` -If the `deinterleaved = 4 -> contiguous` plan is not registered, the required -diagnostic is: +If `deinterleaved = 4 -> contiguous` materialization support does not exist, the +required diagnostic is: ```text VMI-LAYOUT-CONTRACT: value %x is required as #pto.vmi.layout by pto.vmi.group_reduce_addf and as #pto.vmi.layout by - pto.vmi.store, but no registered materialization plan exists at the store - use site. + pto.vmi.store, but no materialization support exists at the store use site. ``` -### 3.19 S=16 Reduce `block_elems` Recipe Selection +### 3.19 S=16 Reduce `block_elems` Support Selection S=16 f32 group reduction has two legal dense input layouts: @@ -2576,10 +2578,11 @@ It is also a valid S=16 reduction layout: each physical part contains eight values per row, so `VCGADD` can reduce each part and `VADD` can combine the two partial sums. 
-`block_elems = 8` is still useful when the producer is a block load plan such -as `BDINTLV` or `vsldb` over 32B row fragments. Layout assignment must select -between these plans by producer/consumer cost. It must not hard-code S=16 -reduce to `block_elems = 8`. +`block_elems = 8` is still useful when the producer is a block load shape such +as `BDINTLV` or `vsldb` over 32B row fragments. Baseline layout assignment must +express any mismatch with an explicit `ensure_layout`; producer rematerialization +or consumer folding can choose the cheaper equivalent form later. Assignment +must not hard-code S=16 reduce to `block_elems = 8`. #### 3.19.1 Continuous S=16 Reduce And Truncf, `block_elems = 1` @@ -2662,7 +2665,7 @@ for i = 0..127: #### 3.19.2 Block-Load Producer Fixed To `block_elems = 8` This is the real conflict case. The value is fixed to `block_elems = 8` -because the producer is a registered block-load plan. A later `truncf` +because the producer uses block-load support. A later `truncf` requires element-parity `block_elems = 1`. VMI input: @@ -2691,7 +2694,8 @@ Assigned layouts before the conflicting `truncf` use: ``` The reduction path is legal and uses the same `vsldb` block-load shape as -section 3.15.2. The `truncf` path is legal only if one of these plans exists: +section 3.15.2. The `truncf` path is legal only if one of these transforms +exists: ```text 1. rematerialize the original memory producer as block_elems=1 @@ -2699,15 +2703,15 @@ section 3.15.2. The `truncf` path is legal only if one of these plans exists: 3. use an explicitly enabled scratch/reload fallback ``` -If no such plan is registered, the required diagnostic is: +If no such transform exists, the required diagnostic is: ```text VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is - fixed to #pto.vmi.layout by the selected - strided group_load plan. 
Register a rematerialization or preserving - materialization plan, or avoid consuming this block-loaded value with truncf. + fixed to #pto.vmi.layout by the strided + group_load. Add rematerialization or preserving materialization support, or + avoid consuming this block-loaded value with truncf. ``` ### 3.20 `group_slots` Control-Flow Join @@ -2987,8 +2991,9 @@ for r = 0..7: ### 3.23 `group_broadcast` With Multiple Dense Consumers One `group_slots` value may feed multiple `group_broadcast` uses with different -dense result layout requirements. Layout assignment should rematerialize the -broadcast per use instead of forcing one result layout onto all consumers. +dense result layout requirements. Each `group_broadcast` op has its own result +layout, so layout assignment should type each op at its use site instead of +forcing one result layout onto all consumers. VMI input: @@ -3046,7 +3051,7 @@ layout. It is that each use has an explicit layout boundary: %b_for_cast_split = pto.vmi.ensure_layout %b_for_cast ``` -If a future direct `group_broadcast -> deinterleaved` recipe is added, layout +If a future direct `group_broadcast -> deinterleaved` support path is added, layout assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but the choice must still be visible in the assigned IR. @@ -3498,21 +3503,21 @@ Required diagnostic when the stride is not block-aligned: ```text VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 32 with source_group_stride not divisible by - 8 f32 elements cannot use the registered vsldb strided-block plan. Enable a - stable gather plan or choose a block-aligned source_group_stride. + 8 f32 elements cannot use the vsldb strided-block lowering support. Enable a + stable gather fallback or choose a block-aligned source_group_stride. 
``` Required assignment rule: ```text -This producer selects the S=32 block-fragment plan: +This producer requires the S=32 block-fragment layout: #pto.vmi.layout It must not be unified with the contiguous-load S=32 plan from section 3.6: #pto.vmi.layout Both layouts are legal inputs to group_reduce_addf S=32, but they require -different producer materialization plans. +different producer materialization/lowering support. ``` ### 3.28 `group_slot_load` `slots = 1` With Aligned Non-Unit Stride @@ -3703,7 +3708,7 @@ Required assignment rule: the per-use typed mask materialization inserted by vmi-layout-assignment. For a rematerializable `create_mask`, assignment may clone it as b32/b16 masks. For a non-rematerializable mask producer, assignment must insert -`ensure_mask_granularity` or diagnose if no materialization plan is registered. +`ensure_mask_granularity` or diagnose if no materialization support exists. ``` ### 3.30 `masked_load` Tail Without Padding @@ -4002,10 +4007,10 @@ S=32 reduce over 8 groups: #pto.vmi.layout ``` -The program is semantically legal. Layout assignment must solve it by cloning -or rematerializing the cheap load for one use, or by inserting an explicit -registered materialization plan. `vmi-to-vpto` must not inspect both users and -choose one locally. +The program is semantically legal. Baseline layout assignment solves it by +inserting an explicit use-site `ensure_layout`. A later optimization pass may +clone or rematerialize the cheap load for one use. `vmi-to-vpto` must not +inspect both users and choose one locally. VMI input: @@ -4114,11 +4119,11 @@ for r = 0..7: Required assignment rule: ```text -If a cheap producer such as load can produce both requested layouts, clone or -rematerialize it at the use sites and assign each clone independently. 
If the -producer is not rematerializable and no deinterleaved=2 <-> deinterleaved=4 -materialization plan is registered, emit a layout-contract diagnostic naming -both consumers and both required layouts. +Baseline assignment inserts `ensure_layout` at the mismatched use. A later +rematerialization pass may clone a cheap producer such as load and assign each +clone independently. If no deinterleaved=2 <-> deinterleaved=4 materialization +support exists, emit a layout-contract diagnostic naming both consumers and +both required layouts. ``` ### 3.34 S=64 Group-Slot Result `f32 -> f16` Cast @@ -4380,7 +4385,7 @@ VPTO lowering result: pto.vsts %out16_block, %out16[%group_off16], %slot8 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -// Row-local S=64 RHS: rematerialize the same scalar stream into one lane-0 +// Row-local S=64 RHS: a separate group_slot_load op produces one lane-0 // value per physical row-local result. %rhs64_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> @@ -4406,9 +4411,12 @@ for r = 0..7: Required assignment rule: ```text -`group_slot_load` is cheaply rematerializable. If two use sites request -different `group_slots` layouts, clone/rematerialize the load per use. Do not -invent a common layout or make `vmi-to-vpto` inspect both users. +`group_slot_load` is a memory op, so the baseline rematerialization pass must +not clone it as a generic cheap producer. If two use sites need different +`group_slots` layouts, the legal first-stage shape is to write two explicit +`group_slot_load` ops, as above, or to introduce a future load-cloning +optimization with an explicit memory-safety proof. Do not invent a common +layout or make `vmi-to-vpto` inspect both users. 
``` ### 3.37 S=64 `group_store` With Non-Unit Output Stride @@ -4467,15 +4475,15 @@ Required assignment rule: ```text If `group_store` has non-unit row_stride and the source can legally use `slots = 1`, assignment may select `slots = 1` to keep the store legal. If the -source is fixed to `slots = 8`, the current target plan must diagnose unless a -strided packed store materializer is registered. +source is fixed to `slots = 8`, current target support must diagnose unless a +strided packed store materializer exists. ``` ### 3.38 Multi-Tile S=32 `group_reduce` The S=32 plan is not only a one-tile special case. For more than eight groups, layout assignment keeps the same layout and `vmi-to-vpto` emits the same -8-row tile recipe for each physical tile. +8-row tile lowering sequence for each physical tile. VMI input: @@ -4663,10 +4671,11 @@ though both have four physical parts. ### 3.40 Scalar Broadcast Feeding Dense And Grouped Users This case fixes the rule for ordinary scalar broadcasts. A scalar broadcast is -not born with a physical layout. Layout assignment may either rematerialize it -per use, or assign the transfer-equivalent producer chain to the non-contiguous -layout requested by the grouped consumer and insert an explicit materialization -at the dense store use. The latter is the concrete plan below. +not born with a physical layout. Baseline layout assignment assigns the +transfer-equivalent producer chain to the non-contiguous layout requested by the +grouped consumer and inserts an explicit materialization at the dense store use. +The later `vmi-layout-rematerialize` pass may replace that helper with a cloned +broadcast when profitable. VMI input: @@ -4785,20 +4794,21 @@ for r = 0..7: Required assignment rule: ```text -`broadcast` is layout-transparent and cheaply rematerializable, but assignment -does not have to force a separate contiguous broadcast just because a dense -store exists. 
It may choose a common deinterleaved compute layout for -transfer-equivalent elementwise ops and insert `ensure_layout` at the dense -store. The required invariant is that this choice is explicit in the assigned -IR; `vmi-to-vpto` must not infer it by inspecting both users. +`broadcast` is layout-transparent and cheaply rematerializable by the optional +`vmi-layout-rematerialize` pass, but baseline assignment does not have to force +a separate contiguous broadcast just because a dense store exists. It may +choose a common deinterleaved compute layout for transfer-equivalent elementwise +ops and insert `ensure_layout` at the dense store. The required invariant is +that this choice is explicit in the assigned IR; `vmi-to-vpto` must not infer it +by inspecting both users. ``` ### 3.41 Non-Rematerializable Value With Incompatible Users This is the non-cheap counterpart to section 3.18. A `masked_load` has explicit mask and passthrough semantics, so layout assignment should not clone it as a -normal cheap load unless the registry explicitly marks that clone legal. The -conflict is solved by inserting `ensure_layout` at one use site. +normal cheap load unless a dedicated rematerialization rule proves that clone +legal. The conflict is solved by inserting `ensure_layout` at one use site. VMI input: @@ -4898,10 +4908,11 @@ for r = 0..7: Required assignment rule: ```text -For non-rematerializable producers, assignment must insert a registered -use-site materialization plan, such as contiguous -> deinterleaved=4. If no -plan exists, it must diagnose at assignment time. `vmi-to-vpto` must not clone -the masked_load or choose a materialization after seeing both users. +For non-rematerializable producers, assignment must insert an explicit use-site +materialization helper, such as contiguous -> deinterleaved=4. If that helper +has no supported materialization, the layout gate must diagnose before +vmi-to-vpto. 
`vmi-to-vpto` must not clone the masked_load or choose a +materialization after seeing both users. ``` ### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator @@ -5267,9 +5278,9 @@ one contiguous value for `masked_load`, and one deinterleaved value for `create_group_mask` by materializing the contiguous grouped predicate chunks and then applying `pdintlv_b32` in the same tree shape as the data `vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to -choose or reject the recipe. +choose or reject the support path. -Assignment may select a deinterleaved S=32 load plan only when the rounded +Assignment may select a deinterleaved S=32 load layout only when the rounded physical reads are memory-safe; otherwise it must diagnose or use a future stable gather fallback. @@ -5437,7 +5448,7 @@ Optimization pass result: ```text // vmi-layout-fold-consumers may remove both ensure_layout ops if the target -// supports a store recipe that consumes deinterleaved=2 and writes contiguous +// supports store lowering that consumes deinterleaved=2 and writes contiguous // row-major memory. pto.vmi.store %t1, %out1[%off] pto.vmi.store %w, %out2[%off] @@ -6002,7 +6013,7 @@ pto.vmi.group_reduce_addi %x8, %mask -> verifier or layout-contract diagnostic ``` -An optimized row-local i8 full-chunk recipe may be added later for +An optimized row-local i8 full-chunk lowering path may be added later for `S = 256` by using widening `vcadd`, but that requires a widening `group_slots` result contract and must not change the baseline cast-to-accumulator semantics above. 
@@ -6016,6 +6027,6 @@ accumulator computation: pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} ``` -That packed group-slot `trunci` path is not a baseline recipe yet; the -implementation must either define a slot-wise VCVTII recipe or diagnose at +That packed group-slot `trunci` path is not baseline lowering support yet; the +implementation must either define slot-wise VCVTII lowering support or diagnose at layout assignment. diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 902d8a2499..061dcd7626 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -103,7 +103,8 @@ std::unique_ptr createPTOValidateVPTOEmissionIRPass(); LogicalResult validateVMIProducerBoundaryIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); LogicalResult validateVMILayoutAssignedIR(ModuleOp module, - llvm::raw_ostream *diagOS = nullptr); + llvm::raw_ostream *diagOS = nullptr, + bool verifyHelperSupport = true); std::unique_ptr createPTOValidateVMIIRPass(); std::unique_ptr createPTOValidateVMILayoutIRPass(); std::unique_ptr createVMILayoutAssignmentPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 91dc8bfc83..354d0b6d66 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -653,10 +653,12 @@ def PTOValidateVMILayoutIR a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, and VMI typed values must stay inside VMI semantic/helper or structural ops. - vmi-to-vpto chooses deterministic local recipes from the current op's attrs, - operand/result types, layouts, and operand values; non-local choices must - be represented as explicit attrs, helper ops, cloned producers, or - diagnostics before this stage. + vmi-to-vpto chooses deterministic lowerings from the current op's attrs, + operand/result types, layouts, and operand values. 
Non-local choices must + be represented as explicit attrs, helper ops, or diagnostics before this + stage. Later VMI layout optimization passes may replace helpers with + cloned/rematerialized producers, but the layout gate must not depend on + hidden producer/user context. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h new file mode 100644 index 0000000000..9a274a2a9b --- /dev/null +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -0,0 +1,287 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
+ +//===- VMILayoutSupport.h - VMI layout support queries ------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMILAYOUTSUPPORT_H +#define PTO_TRANSFORMS_VMILAYOUTSUPPORT_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir::pto { + +class VMITargetCapabilityRegistry; + +enum class VMIContiguousStoreSupportKind { + ContiguousVsts, + Deinterleaved2Vstsx2, + DeinterleavedMaterializeThenVsts, +}; + +struct VMIContiguousStoreSupport { + VMIContiguousStoreSupportKind kind = + VMIContiguousStoreSupportKind::ContiguousVsts; +}; + +enum class VMILayoutMaterializationSupportKind { + Identity, + ContiguousToDeinterleaved, + DeinterleavedToContiguous, +}; + +struct VMILayoutMaterializationSupport { + VMILayoutMaterializationSupportKind kind = + VMILayoutMaterializationSupportKind::Identity; +}; + +enum class VMIMaskGranularityMaterializationSupportKind { + Identity, + PredicateCast, +}; + +struct VMIMaskGranularityMaterializationSupport { + VMIMaskGranularityMaterializationSupportKind kind = + VMIMaskGranularityMaterializationSupportKind::Identity; +}; + +enum class VMICastLayoutKind { + Widen2x, + Widen4x, + Narrow2x, + Narrow4x, +}; + +struct VMICastLayoutFact { + VMICastLayoutKind kind = VMICastLayoutKind::Widen2x; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t sourceBits = 0; + int64_t resultBits = 0; + int64_t factor = 1; +}; + +enum class VMIGroupSlotLoadSupportKind { + Slots8UnitStrideVsldb, + Slots1AlignedLane0Vsldb, +}; + +struct VMIGroupSlotLoadSupport { + VMIGroupSlotLoadSupportKind kind = + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb; +}; + +enum class VMIGroupLoadSupportKind { + S16Block8Vsldb, + S32Block8Vsldb, +}; + +struct VMIGroupLoadSupport { + VMIGroupLoadSupportKind kind = VMIGroupLoadSupportKind::S16Block8Vsldb; +}; + +enum class VMIGroupSlotsStoreSupportKind { + Slots8UnitStrideVsts, + Slots1AlignedLane0Vsts, 
+}; + +struct VMIGroupSlotsStoreSupport { + VMIGroupSlotsStoreSupportKind kind = + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts; +}; + +enum class VMIGroupReduceLayoutKind { + OneVLane, + TwoVLane, + FourVLane, + RowLocal, +}; + +struct VMIGroupReduceLayoutFact { + VMIGroupReduceLayoutKind kind = VMIGroupReduceLayoutKind::OneVLane; + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize = 0; + int64_t lanesPerPart = 0; + int64_t vlaneElems = 0; +}; + +enum class VMIGroupReduceAddFSupportKind { + OneVLaneVcgadd, + TwoVLaneDeinterleaved2VcgaddVadd, + FourVLaneDeinterleaved4VcgaddTree, + ContiguousVcaddRows, +}; + +struct VMIGroupReduceAddFSupport { + VMIGroupReduceAddFSupportKind kind = + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd; +}; + +enum class VMIGroupBroadcastSupportKind { + GroupSlotsVselr, +}; + +struct VMIGroupBroadcastSupport { + VMIGroupBroadcastSupportKind kind = + VMIGroupBroadcastSupportKind::GroupSlotsVselr; +}; + +enum class VMITruncFSupportKind { + Deinterleaved2F32ToContiguousF16, + Deinterleaved4F32ToContiguousF8, + GroupSlots1F32ToF16, +}; + +struct VMITruncFSupport { + VMITruncFSupportKind kind = + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16; +}; + +enum class VMIExtFSupportKind { + ContiguousF16ToDeinterleaved2F32, + ContiguousF8ToDeinterleaved4F32, +}; + +struct VMIExtFSupport { + VMIExtFSupportKind kind = + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32; +}; + +enum class VMITruncISupportKind { + Deinterleaved2I32ToContiguousI16, + Deinterleaved4I32ToContiguousI8, + GroupSlots1I32ToI16, +}; + +struct VMITruncISupport { + VMITruncISupportKind kind = + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16; +}; + +enum class VMIExtISupportKind { + ContiguousI16ToDeinterleaved2I32, + ContiguousI8ToDeinterleaved4I32, +}; + +struct VMIExtISupport { + VMIExtISupportKind kind = + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32; +}; + +enum class VMIBitcastSupportKind 
{ + PerPartVbitcast, +}; + +struct VMIBitcastSupport { + VMIBitcastSupportKind kind = VMIBitcastSupportKind::PerPartVbitcast; +}; + +class VMILayoutSupport { +public: + FailureOr + getContiguousStoreSupport(VMIVRegType valueType, + std::string *reason = nullptr) const; + + LogicalResult canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getDataLayoutMaterializationSupport(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskLayoutMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeMaskLayout(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskGranularityMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeMaskGranularity( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotsStoreSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, + std::string *reason = nullptr) const; + + FailureOr + getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups, + std::string *reason = nullptr) const; + + FailureOr + 
getGroupReduceAddFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddISupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, + std::string *reason = nullptr) const; + + FailureOr + getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtFSupport(VMIExtFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtSISupport(VMIExtSIOp op, std::string *reason = nullptr) const; + + FailureOr + getExtUISupport(VMIExtUIOp op, std::string *reason = nullptr) const; + + FailureOr + getTruncISupport(VMITruncIOp op, std::string *reason = nullptr) const; + + FailureOr + getBitcastSupport(VMIBitcastOp op, std::string *reason = nullptr) const; +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMILAYOUTSUPPORT_H diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h deleted file mode 100644 index 8472a32c4c..0000000000 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under -// the terms and conditions of CANN Open Software License Agreement Version 2.0 -// (the "License"). Please refer to the License for details. You may not use -// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON -// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS -// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository -// for the full text of the License. 
- -//===- VMILocalRecipeRegistry.h - VMI local recipe queries ------*- C++ -*-===// -//===----------------------------------------------------------------------===// - -#ifndef PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H -#define PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H - -#include "PTO/IR/PTO.h" -#include "mlir/Support/LLVM.h" - -#include - -namespace mlir::pto { - -class VMITargetCapabilityRegistry; - -enum class VMIContiguousStoreRecipeKind { - ContiguousVsts, - Deinterleaved2Vstsx2, - DeinterleavedMaterializeThenVsts, -}; - -struct VMIContiguousStoreRecipe { - VMIContiguousStoreRecipeKind kind = - VMIContiguousStoreRecipeKind::ContiguousVsts; -}; - -enum class VMILayoutMaterializationRecipeKind { - Identity, - ContiguousToDeinterleaved, - DeinterleavedToContiguous, -}; - -struct VMILayoutMaterializationRecipe { - VMILayoutMaterializationRecipeKind kind = - VMILayoutMaterializationRecipeKind::Identity; -}; - -enum class VMIMaskGranularityMaterializationRecipeKind { - Identity, - PredicateCast, -}; - -struct VMIMaskGranularityMaterializationRecipe { - VMIMaskGranularityMaterializationRecipeKind kind = - VMIMaskGranularityMaterializationRecipeKind::Identity; -}; - -enum class VMIGroupSlotLoadRecipeKind { - Slots8UnitStrideVsldb, - Slots1AlignedLane0Vsldb, -}; - -struct VMIGroupSlotLoadRecipe { - VMIGroupSlotLoadRecipeKind kind = - VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb; -}; - -enum class VMIGroupLoadRecipeKind { - S16Block8Vsldb, - S32Block8Vsldb, -}; - -struct VMIGroupLoadRecipe { - VMIGroupLoadRecipeKind kind = VMIGroupLoadRecipeKind::S16Block8Vsldb; -}; - -enum class VMIGroupSlotsStoreRecipeKind { - Slots8UnitStrideVsts, - Slots1AlignedLane0Vsts, -}; - -struct VMIGroupSlotsStoreRecipe { - VMIGroupSlotsStoreRecipeKind kind = - VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts; -}; - -enum class VMIGroupReduceAddFRecipeKind { - OneVLaneVcgadd, - TwoVLaneDeinterleaved2VcgaddVadd, - FourVLaneDeinterleaved4VcgaddTree, - ContiguousVcaddRows, -}; - -struct 
VMIGroupReduceAddFRecipe { - VMIGroupReduceAddFRecipeKind kind = - VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd; -}; - -enum class VMIGroupBroadcastRecipeKind { - GroupSlotsVselr, -}; - -struct VMIGroupBroadcastRecipe { - VMIGroupBroadcastRecipeKind kind = - VMIGroupBroadcastRecipeKind::GroupSlotsVselr; -}; - -enum class VMITruncFRecipeKind { - Deinterleaved2F32ToContiguousF16, - Deinterleaved4F32ToContiguousF8, - GroupSlots1F32ToF16, -}; - -struct VMITruncFRecipe { - VMITruncFRecipeKind kind = - VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16; -}; - -enum class VMIExtFRecipeKind { - ContiguousF16ToDeinterleaved2F32, - ContiguousF8ToDeinterleaved4F32, -}; - -struct VMIExtFRecipe { - VMIExtFRecipeKind kind = - VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; -}; - -enum class VMITruncIRecipeKind { - Deinterleaved2I32ToContiguousI16, - Deinterleaved4I32ToContiguousI8, - GroupSlots1I32ToI16, -}; - -struct VMITruncIRecipe { - VMITruncIRecipeKind kind = - VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16; -}; - -enum class VMIExtIRecipeKind { - ContiguousI16ToDeinterleaved2I32, - ContiguousI8ToDeinterleaved4I32, -}; - -struct VMIExtIRecipe { - VMIExtIRecipeKind kind = - VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32; -}; - -enum class VMIBitcastRecipeKind { - PerPartVbitcast, -}; - -struct VMIBitcastRecipe { - VMIBitcastRecipeKind kind = VMIBitcastRecipeKind::PerPartVbitcast; -}; - -class VMILocalRecipeRegistry { -public: - FailureOr - getContiguousStoreRecipe(VMIVRegType valueType, - std::string *reason = nullptr) const; - - LogicalResult canFoldContiguousStoreMaterialization( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; - - FailureOr - getDataLayoutMaterializationRecipe(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; - - FailureOr - getMaskLayoutMaterializationRecipe(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; - - 
FailureOr - getMaskGranularityMaterializationRecipe(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; - - FailureOr - getGroupSlotLoadRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupSlotLoadOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupLoadRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupLoadOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupSlotsStoreRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupStoreOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupReduceAddFRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddFOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupReduceAddIRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddIOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupBroadcastOp op, - std::string *reason = nullptr) const; - - FailureOr - getTruncFRecipe(VMITruncFOp op, std::string *reason = nullptr) const; - - FailureOr - getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; - - FailureOr - getExtSIRecipe(VMIExtSIOp op, std::string *reason = nullptr) const; - - FailureOr - getExtUIRecipe(VMIExtUIOp op, std::string *reason = nullptr) const; - - FailureOr - getTruncIRecipe(VMITruncIOp op, std::string *reason = nullptr) const; - - FailureOr - getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; -}; - -} // namespace mlir::pto - -#endif // PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 9dbad686a4..69d8cd212a 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -38,7 +38,7 @@ add_mlir_dialect_library(PTOTransforms VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp VMILayoutFoldConsumers.cpp - 
VMILocalRecipeRegistry.cpp + VMILayoutSupport.cpp VMILayoutRematerialize.cpp VMILayoutSinkMaterialization.cpp VMIToVPTO.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 7234084c47..6fdf6acf07 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -12,7 +12,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -22,6 +22,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/raw_ostream.h" @@ -170,6 +171,33 @@ LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutSupportContract(Operation *op, + llvm::raw_ostream *diagOS, + Twine message, StringRef reason) { + std::string text; + llvm::raw_string_ostream os(text); + os << message << ": " << reason; + + bool printedAny = false; + auto printValueType = [&](StringRef kind, int64_t index, Type type) { + if (!isVMIType(type)) + return; + if (!printedAny) { + os << "; VMI types:"; + printedAny = true; + } + os << " " << kind << "#" << index << "=" << type; + }; + + for (auto [index, operand] : llvm::enumerate(op->getOperands())) + printValueType("operand", static_cast(index), operand.getType()); + for (auto [index, result] : llvm::enumerate(op->getResults())) + printValueType("result", static_cast(index), result.getType()); + + os.flush(); + return emitLayoutContract(op, diagOS, text); +} + LogicalResult emitHelperMaterializationContract(Operation *helper, Type sourceType, Type resultType, @@ -179,7 +207,7 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, auto 
emitFallback = [&]() { return emitLayoutContract( helper, diagOS, - Twine(helperName) + " has no registered materialization recipe: " + + Twine(helperName) + " has no registered materialization support: " + reason); }; @@ -192,7 +220,7 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, llvm::raw_string_ostream os(message); os << requester->getName() << " operand #" << use.getOperandNumber() << " has type " << sourceType << " but requires " << resultType << "; " - << helperName << " has no registered materialization recipe: " << reason; + << helperName << " has no registered materialization support: " << reason; os.flush(); InFlightDiagnostic diag = @@ -395,10 +423,10 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, return success(); } -LogicalResult verifyLayoutHelperRecipe(Operation *op, +LogicalResult verifyLayoutHelperSupport(Operation *op, llvm::raw_ostream *diagOS); -LogicalResult verifyLayoutSemanticRecipe(Operation *op, +LogicalResult verifyLayoutSemanticSupport(Operation *op, llvm::raw_ostream *diagOS); LogicalResult verifyOperationBoundary(Operation *op, @@ -422,7 +450,8 @@ LogicalResult verifyOperationBoundary(Operation *op, } LogicalResult verifyLayoutAssignedOperation(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS, + bool verifyHelperSupports = true) { if (failed(verifyLayoutAssignedOperationTypes(op, diagOS))) return failure(); @@ -431,14 +460,15 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) - return verifyLayoutHelperRecipe(op, diagOS); + return verifyHelperSupports ? 
verifyLayoutHelperSupport(op, diagOS) + : success(); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); } if (isVMISemanticOp(op)) - return verifyLayoutSemanticRecipe(op, diagOS); + return verifyLayoutSemanticSupport(op, diagOS); if (isStructuralOp(op)) return success(); @@ -446,17 +476,16 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, "VMI typed value is used by a non-VMI semantic op"); } -LogicalResult verifyLayoutHelperRecipe(Operation *op, +LogicalResult verifyLayoutHelperSupport(Operation *op, llvm::raw_ostream *diagOS) { - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; if (auto ensure = dyn_cast(op)) { auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getDataLayoutMaterializationRecipe(sourceType, - resultType, - &reason))) + if (failed(supports.canMaterializeDataLayout(sourceType, resultType, + &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); return success(); @@ -466,9 +495,8 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getMaskLayoutMaterializationRecipe(sourceType, - resultType, - &reason))) + if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, + &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, diagOS); @@ -479,12 +507,12 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getMaskGranularityMaterializationRecipe( - sourceType, resultType, &reason))) + if (failed(supports.canMaterializeMaskGranularity(sourceType, 
resultType, + &reason))) return emitLayoutContract( op, diagOS, Twine("pto.vmi.ensure_mask_granularity has no registered " - "materialization recipe: ") + + "materialization support: ") + reason); return success(); } @@ -492,9 +520,9 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, return success(); } -LogicalResult verifyLayoutSemanticRecipe(Operation *op, +LogicalResult verifyLayoutSemanticSupport(Operation *op, llvm::raw_ostream *diagOS) { - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; VMITargetCapabilityRegistry capabilities; if (auto store = dyn_cast(op)) { @@ -504,12 +532,11 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) - return emitLayoutContract( + if (failed(supports.getContiguousStoreSupport(valueType, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.store has no registered contiguous-memory local " - "recipe: ") + - reason); + "pto.vmi.store has no registered contiguous-memory layout support", + reason); return success(); } @@ -520,12 +547,12 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) - return emitLayoutContract( + if (failed(supports.getContiguousStoreSupport(valueType, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.tile_write has no registered contiguous-memory local " - "recipe: ") + - reason); + "pto.vmi.tile_write has no registered contiguous-memory layout " + "support", + reason); return success(); } @@ -537,21 +564,19 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupLoadRecipe(capabilities, load, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( op, diagOS, - 
Twine("pto.vmi.group_load has no registered block8 local recipe: ") + - reason); + "pto.vmi.group_load has no registered block8 layout support", reason); return success(); } if (auto load = dyn_cast(op)) { std::string reason; - if (failed(recipes.getGroupSlotLoadRecipe(capabilities, load, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupSlotLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_slot_load has no registered local recipe: ") + - reason); + "pto.vmi.group_slot_load has no registered layout support", reason); return success(); } @@ -562,12 +587,11 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, store, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_store has no registered group_slots local " - "recipe: ") + - reason); + "pto.vmi.group_store has no registered group_slots layout support", + reason); return success(); } @@ -578,13 +602,13 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupReduceAddFRecipe(capabilities, reduce, + if (failed(supports.getGroupReduceAddFSupport(capabilities, reduce, &reason))) - return emitLayoutContract( + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_reduce_addf has no registered group_slots " - "local recipe: ") + - reason); + "pto.vmi.group_reduce_addf has no registered group_slots layout " + "support", + reason); return success(); } @@ -595,39 +619,37 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupBroadcastRecipe(capabilities, broadcast, + if (failed(supports.getGroupBroadcastSupport(capabilities, broadcast, &reason))) - 
return emitLayoutContract( + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_broadcast has no registered local recipe: ") + - reason); + "pto.vmi.group_broadcast has no registered layout support", reason); return success(); } if (auto truncf = dyn_cast(op)) { std::string reason; - if (failed(recipes.getTruncFRecipe(truncf, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.truncf has no registered local recipe: ") + reason); + if (failed(supports.getTruncFSupport(truncf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.truncf has no registered layout support", + reason); return success(); } if (auto extf = dyn_cast(op)) { std::string reason; - if (failed(recipes.getExtFRecipe(extf, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.extf has no registered local recipe: ") + reason); + if (failed(supports.getExtFSupport(extf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.extf has no registered layout support", reason); return success(); } if (auto bitcast = dyn_cast(op)) { std::string reason; - if (failed(recipes.getBitcastRecipe(bitcast, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.bitcast has no registered local recipe: ") + reason); + if (failed(supports.getBitcastSupport(bitcast, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.bitcast has no registered layout support", + reason); return success(); } @@ -668,9 +690,10 @@ LogicalResult mlir::pto::validateVMIProducerBoundaryIR( } LogicalResult mlir::pto::validateVMILayoutAssignedIR( - ModuleOp module, llvm::raw_ostream *diagOS) { + ModuleOp module, llvm::raw_ostream *diagOS, bool verifyHelperSupports) { WalkResult result = module.walk([&](Operation *op) { - if (failed(verifyLayoutAssignedOperation(op, diagOS))) + if (failed(verifyLayoutAssignedOperation(op, diagOS, + verifyHelperSupports))) return WalkResult::interrupt(); return WalkResult::advance(); }); diff 
--git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 5f30ba82e0..eb3593c9ee 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -15,6 +15,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -248,19 +249,11 @@ struct LayoutSolver { if (VMILayoutAttr existing = type.getLayoutAttr()) if (existing.isGroupSlots() && existing.getSlots() > 0) return existing; - if (numGroups > 0 && type.getElementCount() % numGroups == 0) { - int64_t groupSize = type.getElementCount() / numGroups; - std::optional vlaneElems = getVLaneElems(type.getElementType()); - if (vlaneElems && (groupSize == *vlaneElems || - groupSize == 2 * *vlaneElems || - groupSize == 4 * *vlaneElems)) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (succeeded(lanesPerPart) && groupSize >= *lanesPerPart && - groupSize % *lanesPerPart == 0) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); - } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->resultLayout; return getGroupSlotsLayout(numGroups); } @@ -268,14 +261,11 @@ struct LayoutSolver { int64_t numGroups) { if (VMILayoutAttr existing = type.getLayoutAttr()) return existing; - if (numGroups > 0 && type.getElementCount() % numGroups == 0) { - int64_t groupSize = type.getElementCount() / numGroups; - std::optional vlaneElems = getVLaneElems(type.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems) - return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); - if (vlaneElems && groupSize == 4 * *vlaneElems) - return 
VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); - } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->sourceLayout; return getContiguousLayout(); } @@ -406,6 +396,27 @@ struct LayoutSolver { return false; } + bool isCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact, + VMILayoutAttr layout) { + if (!layout) + return false; + if (fact.kind == VMIGroupReduceLayoutKind::OneVLane || + fact.kind == VMIGroupReduceLayoutKind::RowLocal) + return layout.isContiguous(); + int64_t factor = fact.kind == VMIGroupReduceLayoutKind::TwoVLane ? 2 : 4; + return layout.isDeinterleaved() && layout.getFactor() == factor && + (layout.getBlockElems() == 1 || layout.getBlockElems() == 8); + } + + VMILayoutAttr getTruncFCompatibleGroupReduceSourceLayout( + VMIGroupReduceLayoutFact fact) { + if (fact.kind == VMIGroupReduceLayoutKind::TwoVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + if (fact.kind == VMIGroupReduceLayoutKind::FourVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + return {}; + } + LogicalResult requestMask(Value mask, VMILayoutAttr layout, StringRef granularity, Operation *op) { unsigned id = addMaskValue(mask); @@ -813,41 +824,23 @@ struct LayoutSolver { if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, reduce.getNumGroupsAttr().getInt()); + sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); - int64_t numGroups = reduce.getNumGroupsAttr().getInt(); - if (solvedSourceLayout && numGroups > 0 && - 
sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 2 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - if (vlaneElems && groupSize == 4 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 4 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - } else if (!sourceType.getLayoutAttr() && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), - groupSize)) { - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems) - sourceLayout = - VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); - if (vlaneElems && groupSize == 4 * *vlaneElems) - sourceLayout = - VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; } } requestDataUse(reduce.getSourceMutable(), sourceLayout); @@ -857,8 +850,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - getPreferredGroupSlotsLayout( - resultType, reduce.getNumGroupsAttr().getInt()), + succeeded(fact) ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, + numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -866,29 +860,17 @@ struct LayoutSolver { if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, reduce.getNumGroupsAttr().getInt()); + sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); - int64_t numGroups = reduce.getNumGroupsAttr().getInt(); - if (solvedSourceLayout && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 2 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - if (vlaneElems && groupSize == 4 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 4 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - } + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) + sourceLayout = solvedSourceLayout; requestDataUse(reduce.getSourceMutable(), sourceLayout); if (failed(requestMaskUse( reduce.getMaskMutable(), sourceLayout, @@ -896,8 +878,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - getPreferredGroupSlotsLayout( - resultType, 
reduce.getNumGroupsAttr().getInt()), + succeeded(fact) ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, + numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -912,19 +895,15 @@ struct LayoutSolver { if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extf.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extf.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extf.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extf.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extf.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extf.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -932,19 +911,15 @@ struct LayoutSolver { if (auto extsi = dyn_cast(op)) { auto sourceType = cast(extsi.getSource().getType()); auto resultType = cast(extsi.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extsi.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return 
WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extsi.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extsi.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extsi.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -952,19 +927,15 @@ struct LayoutSolver { if (auto extui = dyn_cast(op)) { auto sourceType = cast(extui.getSource().getType()); auto resultType = cast(extui.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extui.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extui.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extui.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extui.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extui.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extui.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -972,48 +943,50 @@ struct LayoutSolver { if (auto truncf = dyn_cast(op)) { auto 
sourceType = cast(truncf.getSource().getType()); auto resultType = cast(truncf.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); - if (sourceBits == 32 && resultBits == 16 && sourceLayout && + if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout && sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { requestDataUse(truncf.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (sourceBits == 32 && resultBits == 16) - requestDataUse(truncf.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 2)); - else if (sourceBits == 32 && resultBits == 8) - requestDataUse(truncf.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 4)); - if (failed(setNaturalLayout(truncf.getResult(), getContiguousLayout(), - op))) + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) + requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); + VMILayoutAttr resultLayout = + succeeded(fact) ? 
fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(truncf.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto trunci = dyn_cast(op)) { auto sourceType = cast(trunci.getSource().getType()); auto resultType = cast(trunci.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); - if (sourceBits == 32 && resultBits == 16 && sourceLayout && + if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout && sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { requestDataUse(trunci.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (sourceBits == 32 && resultBits == 16) - requestDataUse(trunci.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 2)); - else if (sourceBits == 32 && resultBits == 8) - requestDataUse(trunci.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 4)); - if (failed(setNaturalLayout(trunci.getResult(), getContiguousLayout(), - op))) + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) + requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); + VMILayoutAttr resultLayout = + succeeded(fact) ? 
fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -1447,27 +1420,6 @@ struct LayoutSolver { } } - std::optional rematerializeDataUse(Value value, VMIVRegType resultType, - Location loc, OpBuilder &builder) { - if (auto constant = value.getDefiningOp()) { - auto denseAttr = dyn_cast(constant.getValue()); - if (denseAttr && denseAttr.isSplat()) - return builder - .create(loc, resultType, constant.getValue()) - .getResult(); - } - if (auto broadcast = value.getDefiningOp()) - return builder - .create(loc, resultType, broadcast.getValue()) - .getResult(); - if (auto iota = value.getDefiningOp()) - return builder - .create(loc, resultType, iota.getBase(), - iota.getOrderAttr()) - .getResult(); - return std::nullopt; - } - LogicalResult insertDataUseMaterializations() { OpBuilder builder(ctx); for (DataUseRequest request : dataUseRequests) { @@ -1488,12 +1440,6 @@ struct LayoutSolver { VMIVRegType::get(ctx, sourceType.getElementCount(), sourceType.getElementType(), request.layout); builder.setInsertionPoint(request.operand->getOwner()); - std::optional rematerialized = rematerializeDataUse( - value, resultType, request.operand->getOwner()->getLoc(), builder); - if (rematerialized) { - request.operand->set(*rematerialized); - continue; - } auto ensure = builder.create( request.operand->getOwner()->getLoc(), resultType, value); request.operand->set(ensure.getResult()); @@ -1625,27 +1571,6 @@ struct LayoutSolver { } } - std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, - Location loc, OpBuilder &builder) { - if (auto createMask = value.getDefiningOp()) - return builder - .create(loc, resultType, createMask.getActiveLanes()) - .getResult(); - if (auto createGroupMask = value.getDefiningOp()) - return builder - .create( - loc, resultType, createGroupMask.getActiveElemsPerGroup(), - createGroupMask.getNumGroupsAttr(), - 
createGroupMask.getGroupSizeAttr()) - .getResult(); - if (auto constantMask = value.getDefiningOp()) - return builder - .create(loc, resultType, - constantMask.getValueAttr()) - .getResult(); - return std::nullopt; - } - LogicalResult insertMaskUseMaterializations() { OpBuilder builder(ctx); for (MaskUseRequest request : maskUseRequests) { @@ -1663,19 +1588,6 @@ struct LayoutSolver { builder.setInsertionPoint(request.operand->getOwner()); Value current = value; VMIMaskType currentType = sourceType; - auto requestedType = - VMIMaskType::get(ctx, sourceType.getElementCount(), - request.granularity, request.layout); - if (sourceType != requestedType) { - std::optional rematerialized = rematerializeMaskUse( - value, requestedType, request.operand->getOwner()->getLoc(), - builder); - if (rematerialized) { - request.operand->set(*rematerialized); - continue; - } - } - if (sourceLayout != request.layout) { auto layoutType = VMIMaskType::get(ctx, currentType.getElementCount(), @@ -1753,7 +1665,8 @@ struct LayoutSolver { if (failed(insertMaskUseMaterializations())) return failure(); rewriteFunctionType(); - return validateVMILayoutAssignedIR(module); + return validateVMILayoutAssignedIR(module, /*diagOS=*/nullptr, + /*verifyHelperSupports=*/false); } ModuleOp module; diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp index 26536f196d..fda374f661 100644 --- a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +++ b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp @@ -14,7 +14,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -43,9 +43,9 @@ static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { if (!sourceType || !resultType) return false; - VMILocalRecipeRegistry recipes; + 
VMILayoutSupport supports; return succeeded( - recipes.canFoldContiguousStoreMaterialization(sourceType, resultType)); + supports.canFoldContiguousStoreMaterialization(sourceType, resultType)); } static void tryFoldEnsureLayoutIntoOperand( diff --git a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp index c3bbf67731..3027d919f7 100644 --- a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp +++ b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -13,6 +13,7 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -39,6 +40,18 @@ struct BinaryVRegOperands { OpOperand *rhs = nullptr; }; +struct TernaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; + OpOperand *acc = nullptr; +}; + +struct SelectOperands { + OpOperand *mask = nullptr; + OpOperand *trueValue = nullptr; + OpOperand *falseValue = nullptr; +}; + struct UnaryVRegOperand { OpOperand *source = nullptr; }; @@ -84,6 +97,31 @@ static std::optional getSinkableBinaryOperands(Operation *op return std::nullopt; } +static std::optional +getSinkableCompareOperands(Operation *op) { + if (auto cmpf = dyn_cast(op)) + return BinaryVRegOperands{&cmpf.getLhsMutable(), &cmpf.getRhsMutable()}; + if (auto cmpi = dyn_cast(op)) + return BinaryVRegOperands{&cmpi.getLhsMutable(), &cmpi.getRhsMutable()}; + return std::nullopt; +} + +static std::optional getSinkableSelectOperands(Operation *op) { + if (auto select = dyn_cast(op)) + return SelectOperands{&select.getMaskMutable(), + &select.getTrueValueMutable(), + &select.getFalseValueMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableTernaryOperands(Operation *op) { + if (auto fma = dyn_cast(op)) + return TernaryVRegOperands{&fma.getLhsMutable(), &fma.getRhsMutable(), + &fma.getAccMutable()}; + return std::nullopt; +} + 
static std::optional getSinkableUnaryOperand(Operation *op) { if (auto negf = dyn_cast(op)) return UnaryVRegOperand{&negf.getSourceMutable()}; @@ -155,6 +193,34 @@ static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, lhsResultType == resultType && lhsSourceType != resultType; } +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIEnsureLayoutOp accEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !accEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto accSourceType = dyn_cast(accEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + auto accResultType = dyn_cast(accEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !accSourceType || !lhsResultType || + !rhsResultType || !accResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsSourceType == accSourceType && + lhsResultType == rhsResultType && lhsResultType == accResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeDataLayout(sourceType, resultType)); +} + template static bool isSameMaskMaterialization(EnsureOp ensure, VMIMaskType resultType) { if (!ensure || !resultType) @@ -185,6 +251,20 @@ static bool isSameMaskMaterialization(EnsureOp lhsEnsure, EnsureOp rhsEnsure, lhsResultType == resultType && lhsSourceType != resultType; } +static bool canMaterializeMask(VMIEnsureMaskLayoutOp, VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeMaskLayout(sourceType, resultType)); +} + +static bool 
canMaterializeMask(VMIEnsureMaskGranularityOp, + VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded( + supports.canMaterializeMaskGranularity(sourceType, resultType)); +} + static bool trySinkBinaryMaterialization(Operation *op) { std::optional operands = getSinkableBinaryOperands(op); if (!operands || op->getNumResults() != 1) @@ -200,9 +280,175 @@ static bool trySinkBinaryMaterialization(Operation *op) { return false; auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkSelectMaterialization(Operation *op) { + std::optional operands = getSinkableSelectOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto maskEnsure = + operands->mask->get().getDefiningOp(); + auto trueEnsure = + operands->trueValue->get().getDefiningOp(); + auto falseEnsure = + operands->falseValue->get().getDefiningOp(); + if (!maskEnsure || !trueEnsure || !falseEnsure) + return false; + + auto trueSourceType = dyn_cast(trueEnsure.getSource().getType()); + auto falseSourceType = + dyn_cast(falseEnsure.getSource().getType()); + auto trueResultType = dyn_cast(trueEnsure.getResult().getType()); + auto falseResultType = + 
dyn_cast(falseEnsure.getResult().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = dyn_cast(maskEnsure.getResult().getType()); + if (!trueSourceType || !falseSourceType || !trueResultType || + !falseResultType || !maskSourceType || !maskResultType) + return false; + + if (trueSourceType != falseSourceType || trueResultType != falseResultType || + trueResultType != resultType || trueSourceType == resultType) + return false; + if (maskResultType != operands->mask->get().getType()) + return false; + if (maskResultType.getLayoutAttr() != resultType.getLayoutAttr() || + maskSourceType.getLayoutAttr() != trueSourceType.getLayoutAttr()) + return false; + if (maskSourceType.getElementCount() != trueSourceType.getElementCount() || + maskResultType.getElementCount() != resultType.getElementCount() || + maskSourceType.getGranularity() != maskResultType.getGranularity()) + return false; + if (!canMaterializeDataLayout(trueSourceType, resultType) || + !canMaterializeMask(maskEnsure, maskSourceType, maskResultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({maskEnsure.getSource(), trueEnsure.getSource(), + falseEnsure.getSource()}); + state.addTypes(trueSourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (maskEnsure->use_empty()) + maskEnsure.erase(); + if (trueEnsure->use_empty()) + trueEnsure.erase(); + if (falseEnsure != trueEnsure && falseEnsure->use_empty()) + falseEnsure.erase(); + return true; +} + +static bool trySinkCompareMaterialization(Operation *op) { + std::optional operands = getSinkableCompareOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto 
resultMaskType = dyn_cast(op->getResult(0).getType()); + if (!resultMaskType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!lhsEnsure || !rhsEnsure) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + if (lhsSourceType != rhsSourceType || lhsResultType != rhsResultType || + lhsSourceType == lhsResultType) + return false; + if (lhsResultType.getElementCount() != resultMaskType.getElementCount() || + lhsResultType.getLayoutAttr() != resultMaskType.getLayoutAttr()) + return false; + + auto sourceMaskType = VMIMaskType::get( + op->getContext(), resultMaskType.getElementCount(), + resultMaskType.getGranularity(), lhsSourceType.getLayoutAttr()); + VMILayoutSupport supports; + if (failed(supports.canMaterializeMaskLayout(sourceMaskType, resultMaskType))) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceMaskType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultMaskType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkTernaryMaterialization(Operation *op) { + std::optional operands = getSinkableTernaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + 
auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + auto accEnsure = operands->acc->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, accEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands( + {lhsEnsure.getSource(), rhsEnsure.getSource(), accEnsure.getSource()}); state.addTypes(sourceType); state.addAttributes(op->getAttrs()); Operation *newOp = builder.create(state); @@ -217,6 +463,9 @@ static bool trySinkBinaryMaterialization(Operation *op) { lhsEnsure.erase(); if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) rhsEnsure.erase(); + if (accEnsure != lhsEnsure && accEnsure != rhsEnsure && + accEnsure->use_empty()) + accEnsure.erase(); return true; } @@ -236,6 +485,9 @@ static bool trySinkBinaryMaskMaterialization(Operation *op) { return false; auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeMask(lhsEnsure, sourceType, resultType)) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); @@ -271,6 +523,9 @@ static bool trySinkUnaryMaterialization(Operation *op) { return false; auto sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands(sourceEnsure.getSource()); @@ -305,6 +560,9 @@ static bool trySinkUnaryMaskMaterialization(Operation *op) { return false; auto sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeMask(sourceEnsure, sourceType, resultType)) 
+ return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands(sourceEnsure.getSource()); @@ -340,8 +598,10 @@ struct VMILayoutSinkMaterializationPass ModuleOp module = getOperation(); SmallVector candidates; module.walk([&](Operation *op) { - if (getSinkableBinaryOperands(op) || getSinkableUnaryOperand(op) || - getSinkableBinaryMaskOperands(op) || getSinkableUnaryMaskOperand(op)) + if (getSinkableBinaryOperands(op) || getSinkableCompareOperands(op) || + getSinkableSelectOperands(op) || getSinkableTernaryOperands(op) || + getSinkableUnaryOperand(op) || getSinkableBinaryMaskOperands(op) || + getSinkableUnaryMaskOperand(op)) candidates.push_back(op); }); @@ -349,8 +609,14 @@ struct VMILayoutSinkMaterializationPass if (op->getBlock() == nullptr) continue; if (!trySinkBinaryMaterialization(op)) { - if (!trySinkUnaryMaterialization(op)) - trySinkMaskMaterialization(op); + if (!trySinkCompareMaterialization(op)) { + if (!trySinkSelectMaterialization(op)) { + if (!trySinkTernaryMaterialization(op)) { + if (!trySinkUnaryMaterialization(op)) + trySinkMaskMaterialization(op); + } + } + } } } } diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp similarity index 73% rename from lib/PTO/Transforms/VMILocalRecipeRegistry.cpp rename to lib/PTO/Transforms/VMILayoutSupport.cpp index 34b843737c..27a994ba55 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -8,10 +8,10 @@ // FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository // for the full text of the License. 
-//===- VMILocalRecipeRegistry.cpp - VMI local recipe queries --------------===// +//===- VMILayoutSupport.cpp - VMI layout support queries --------------===// //===----------------------------------------------------------------------===// -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" @@ -311,12 +311,12 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return bits; } -static FailureOr -getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, +static FailureOr +getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, std::string *reason) { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -325,26 +325,156 @@ getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); if (sourceLayout == resultLayout) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::Identity}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::Identity}; if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::ContiguousToDeinterleaved}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::ContiguousToDeinterleaved}; if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::DeinterleavedToContiguous}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::DeinterleavedToContiguous}; return fail("unsupported source/result layout pair"); } } // namespace -FailureOr 
-VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, +FailureOr +VMILayoutSupport::getPreferredGroupReduceLayoutFact( + VMIVRegType sourceType, int64_t numGroups, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + + MLIRContext *ctx = sourceType.getContext(); + int64_t vlaneElems = *lanesPerPart / 8; + VMIGroupReduceLayoutFact fact; + fact.groupSize = *groupSize; + fact.lanesPerPart = *lanesPerPart; + fact.vlaneElems = vlaneElems; + + if (*groupSize == vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::OneVLane; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 2 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::TwoVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 4 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::FourVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize >= *lanesPerPart && *groupSize % *lanesPerPart == 0) { + fact.kind = VMIGroupReduceLayoutKind::RowLocal; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + 
fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + return fact; + } + + return fail("group_reduce layout supports group sizes VLaneElems, " + "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); +} + +FailureOr +VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceBits == 0 || resultBits == 0) + return fail("requires source/result element types with known storage width"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + + MLIRContext *ctx = sourceType.getContext(); + VMICastLayoutFact fact; + fact.sourceBits = sourceBits; + fact.resultBits = resultBits; + + if (resultBits == 32 && sourceBits == 16) { + fact.kind = VMICastLayoutKind::Widen2x; + fact.factor = 2; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + } + + if (resultBits == 32 && sourceBits == 8) { + fact.kind = VMICastLayoutKind::Widen4x; + fact.factor = 4; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + } + + if (sourceBits == 32 && resultBits == 16) { + fact.kind = VMICastLayoutKind::Narrow2x; + fact.factor = 2; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + if (sourceBits == 32 && resultBits == 8) { + fact.kind 
= VMICastLayoutKind::Narrow4x; + fact.factor = 4; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + return fail("supports only 8/16-bit <-> 32-bit dense cast layout facts"); +} + +FailureOr +VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -354,8 +484,8 @@ VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, if (!layout) return fail("requires assigned value layout"); if (layout.isContiguous()) - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::ContiguousVsts}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::ContiguousVsts}; if (!layout.isDeinterleaved()) return fail("requires contiguous or deinterleaved value layout"); if (layout.getBlockElems() != 1) @@ -366,18 +496,18 @@ VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, if (layout.getFactor() == 2) { if (!hasX2MemoryDistToken(valueType.getElementType())) return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2}; } if (layout.getFactor() == 4) - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::DeinterleavedMaterializeThenVsts}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::DeinterleavedMaterializeThenVsts}; return fail("requires deinterleaved factor 2 or 4"); } -LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( +LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { if (sourceType.getElementType() != 
resultType.getElementType()) return failWithReason("source/result element types must match", reason); @@ -388,22 +518,22 @@ LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( if (!resultLayout || !resultLayout.isContiguous()) return failWithReason("result layout must be contiguous", reason); - FailureOr recipe = - getContiguousStoreRecipe(sourceType, reason); - if (failed(recipe)) + FailureOr support = + getContiguousStoreSupport(sourceType, reason); + if (failed(support)) return failure(); - if (recipe->kind == VMIContiguousStoreRecipeKind::ContiguousVsts) + if (support->kind == VMIContiguousStoreSupportKind::ContiguousVsts) return failWithReason("source layout is already contiguous", reason); return success(); } -FailureOr -VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( +FailureOr +VMILayoutSupport::getDataLayoutMaterializationSupport( VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -416,23 +546,33 @@ VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr recipe = - getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); - if (failed(recipe)) + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (failed(support)) return failure(); if (failed(checkLayoutMaterializationShape(sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); - return recipe; + return support; } -FailureOr -VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( +LogicalResult +VMILayoutSupport::canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason) const { + if (failed(getDataLayoutMaterializationSupport(sourceType, resultType, + reason))) + 
return failure(); + return success(); +} + +FailureOr +VMILayoutSupport::getMaskLayoutMaterializationSupport( VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -445,23 +585,33 @@ VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr recipe = - getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); - if (failed(recipe)) + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (failed(support)) return failure(); if (failed(checkLayoutMaterializationShape(sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); - return recipe; + return support; +} + +LogicalResult +VMILayoutSupport::canMaterializeMaskLayout(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason) const { + if (failed(getMaskLayoutMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); } -FailureOr -VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( +FailureOr +VMILayoutSupport::getMaskGranularityMaterializationSupport( VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -475,18 +625,27 @@ VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( !VMIMaskType::isConcreteGranularity(resultType.getGranularity())) return fail("requires concrete b8/b16/b32 source and result granularities"); if (sourceType.getGranularity() == resultType.getGranularity()) - return VMIMaskGranularityMaterializationRecipe{ - VMIMaskGranularityMaterializationRecipeKind::Identity}; + return VMIMaskGranularityMaterializationSupport{ + 
VMIMaskGranularityMaterializationSupportKind::Identity}; - return VMIMaskGranularityMaterializationRecipe{ - VMIMaskGranularityMaterializationRecipeKind::PredicateCast}; + return VMIMaskGranularityMaterializationSupport{ + VMIMaskGranularityMaterializationSupportKind::PredicateCast}; +} + +LogicalResult VMILayoutSupport::canMaterializeMaskGranularity( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + if (failed(getMaskGranularityMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); } -FailureOr -VMILocalRecipeRegistry::getGroupSlotLoadRecipe( +FailureOr +VMILayoutSupport::getGroupSlotLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -515,8 +674,8 @@ VMILocalRecipeRegistry::getGroupSlotLoadRecipe( if (!stride || *stride != 1) return fail("slots=8 group_slot_load requires constant unit " "source_group_stride"); - return VMIGroupSlotLoadRecipe{ - VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb}; + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb}; } unsigned elementBits = @@ -533,14 +692,14 @@ VMILocalRecipeRegistry::getGroupSlotLoadRecipe( " elements for 32B load alignment; packed or unaligned " "scalar load lowering is not implemented"); - return VMIGroupSlotLoadRecipe{ - VMIGroupSlotLoadRecipeKind::Slots1AlignedLane0Vsldb}; + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots1AlignedLane0Vsldb}; } -FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( +FailureOr VMILayoutSupport::getGroupLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if 
(reason) *reason = message.str(); return failure(); @@ -584,16 +743,16 @@ FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( fullChunkReason); if (*groupSize == 16) - return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S16Block8Vsldb}; - return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S32Block8Vsldb}; + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S16Block8Vsldb}; + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S32Block8Vsldb}; } -FailureOr -VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( +FailureOr +VMILayoutSupport::getGroupSlotsStoreSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, std::string *reason) const { auto fail = - [&](const Twine &message) -> FailureOr { + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -636,8 +795,8 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( Twine(alignedStrideElems) + " elements for 32B store alignment; packed or unaligned " "contiguous store lowering is not implemented"); - return VMIGroupSlotsStoreRecipe{ - VMIGroupSlotsStoreRecipeKind::Slots1AlignedLane0Vsts}; + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots1AlignedLane0Vsts}; } if (layout.getSlots() == 8) { @@ -648,23 +807,23 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( if (*arity != ceilDivNonNegative(numGroups, 8)) return fail("slots=8 group_store arity must equal ceil(num_groups / " "8)"); - return VMIGroupSlotsStoreRecipe{ - VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts}; + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts}; } return fail("group_slots group_store currently supports only slots=1 or " "unit-stride slots=8"); } -FailureOr -getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, +FailureOr +getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, Operation *op, VMIVRegType sourceType, VMIMaskType maskType, VMIVRegType resultType, int64_t 
numGroups, bool requiresReassoc, VMIReductionKind reductionKind, std::string *reason) { auto fail = - [&](const Twine &message) -> FailureOr { + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -692,9 +851,9 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && *groupSize != 4 * vlaneElems)) - return fail("stable group_reduce_add slots=8 recipes support group " + return fail("stable group_reduce_add slots=8 support group " "sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems"); - return fail("stable group_reduce_add local recipes currently require " + return fail("stable group_reduce_add layout support currently requires " "result layout slots=8 or slots=1"); } @@ -704,7 +863,7 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (!elementCapability.isSupported()) return fail(elementCapability.reason); if (sourceType.getElementType() != resultType.getElementType()) - return fail("stable group_reduce_add local recipes require matching " + return fail("stable group_reduce_add layout support requires matching " "source/result element types"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -730,7 +889,7 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.getSlots() == 1) { if (failed(lanesPerPart) || *groupSize < *lanesPerPart || *groupSize % *lanesPerPart != 0) - return fail("stable group_reduce_add slots=1 recipes support group " + return fail("stable group_reduce_add slots=1 support group " "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) return fail("slots=1 group_reduce_add requires contiguous source/mask " @@ -743,8 +902,8 @@ 
getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, return fail(Twine("slots=1 group_reduce_add requires full source " "chunks; ") + sourceFullReason); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::ContiguousVcaddRows}; } if (*groupSize == vlaneElems) { @@ -759,8 +918,8 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (*resultArity != *sourceArity) return fail("one-vlane group_reduce_add requires source/result physical " "arity to match"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd}; } if (*groupSize == 2 * vlaneElems) { @@ -778,8 +937,8 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, *sourceArity != *resultArity * 2) return fail("two-vlane group_reduce_add requires two source/mask parts per " "result part"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::TwoVLaneDeinterleaved2VcgaddVadd}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd}; } if (*groupSize == 4 * vlaneElems) { @@ -797,19 +956,19 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, *sourceArity != *resultArity * 4) return fail("four-vlane group_reduce_add requires four source/mask parts per " "result part"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::FourVLaneDeinterleaved4VcgaddTree}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree}; } - return fail("stable group_reduce_add slots=8 recipes support group sizes " + return fail("stable group_reduce_add slots=8 support group sizes " "VLaneElems, 2*VLaneElems, or 4*VLaneElems"); } -FailureOr -VMILocalRecipeRegistry::getGroupReduceAddFRecipe( +FailureOr 
+VMILayoutSupport::getGroupReduceAddFSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, std::string *reason) const { - return getGroupReduceAddRecipeImpl( + return getGroupReduceAddSupportImpl( capabilities, op.getOperation(), cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), @@ -817,11 +976,11 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( VMIReductionKind::GroupAddF, reason); } -FailureOr -VMILocalRecipeRegistry::getGroupReduceAddIRecipe( +FailureOr +VMILayoutSupport::getGroupReduceAddISupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, std::string *reason) const { - return getGroupReduceAddRecipeImpl( + return getGroupReduceAddSupportImpl( capabilities, op.getOperation(), cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), @@ -829,12 +988,12 @@ VMILocalRecipeRegistry::getGroupReduceAddIRecipe( VMIReductionKind::GroupAddI, reason); } -FailureOr -VMILocalRecipeRegistry::getGroupBroadcastRecipe( +FailureOr +VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, std::string *reason) const { (void)capabilities; - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -888,30 +1047,30 @@ VMILocalRecipeRegistry::getGroupBroadcastRecipe( if (failed(resultFactor)) return fail("requires known result layout factor"); if (*resultFactor == 1) - return VMIGroupBroadcastRecipe{ - VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; bool blockFragmentSmallGroup = resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && *groupSize < *lanesPerPart && *lanesPerPart % resultLayout.getBlockElems() == 0; if (blockFragmentSmallGroup) - return VMIGroupBroadcastRecipe{ - 
VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) return fail("deinterleaved result requires every physical result chunk to " "stay within one logical group"); - return VMIGroupBroadcastRecipe{ - VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; } -FailureOr -VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, +FailureOr +VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -928,10 +1087,9 @@ VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, return fail("requires assigned source/result layouts and computable " "physical arity"); - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || sourceLayout.getNumGroups() != resultLayout.getNumGroups() || sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || @@ -940,28 +1098,37 @@ VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, return fail("group-slot truncf requires matching " "group_slots(num_groups=G, slots=1) source/result layouts, " "f32 source, f16 result, and matching physical arity"); - return VMITruncFRecipe{VMITruncFRecipeKind::GroupSlots1F32ToF16}; + return VMITruncFSupport{VMITruncFSupportKind::GroupSlots1F32ToF16}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || *resultArity != 
1) return fail("requires f32 deinterleaved source and contiguous result"); - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return VMITruncFRecipe{ - VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16}; - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) - return VMITruncFRecipe{ - VMITruncFRecipeKind::Deinterleaved4F32ToContiguousF8}; + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); + + if (fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16}; + if (fact->kind == VMICastLayoutKind::Narrow4x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8}; return fail("unsupported deinterleaved truncf factor, arity, or result " "element width"); } -FailureOr -VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, +FailureOr +VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -982,25 +1149,32 @@ VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, return fail("requires contiguous source layout and deinterleaved f32 " "result layout"); - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - if (sourceBits == 16 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return VMIExtFRecipe{ - VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32}; - if (sourceBits == 8 && resultLayout.getFactor() 
== 4 && - *resultArity == 4 * *sourceArity) - return VMIExtFRecipe{ - VMIExtFRecipeKind::ContiguousF8ToDeinterleaved4F32}; + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported extf source element width, result factor, or " + "physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + if (fact->kind == VMICastLayoutKind::Widen4x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; return fail("unsupported extf source element width, result factor, or " "physical arity"); } template -static FailureOr getExtIRecipeImpl(OpT op, +static FailureOr getExtISupportImpl(OpT op, std::string *reason) { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -1022,39 +1196,45 @@ static FailureOr getExtIRecipeImpl(OpT op, return fail("requires contiguous integer source layout and deinterleaved " "integer result layout"); - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return VMIExtIRecipe{ - VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32}; - if (sourceBits == 8 && resultBits == 32 && resultLayout.getFactor() == 4 && - *resultArity == 4 * *sourceArity) - return VMIExtIRecipe{ - VMIExtIRecipeKind::ContiguousI8ToDeinterleaved4I32}; + FailureOr fact = + 
VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, + reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported integer extension source/result element width, " + "result factor, or physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + if (fact->kind == VMICastLayoutKind::Widen4x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; return fail("unsupported integer extension source/result element width, " "result factor, or physical arity"); } -FailureOr -VMILocalRecipeRegistry::getExtSIRecipe(VMIExtSIOp op, +FailureOr +VMILayoutSupport::getExtSISupport(VMIExtSIOp op, std::string *reason) const { - return getExtIRecipeImpl(op, reason); + return getExtISupportImpl(op, reason); } -FailureOr -VMILocalRecipeRegistry::getExtUIRecipe(VMIExtUIOp op, +FailureOr +VMILayoutSupport::getExtUISupport(VMIExtUIOp op, std::string *reason) const { - return getExtIRecipeImpl(op, reason); + return getExtISupportImpl(op, reason); } -FailureOr -VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, +FailureOr +VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -1088,31 +1268,42 @@ VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, "group_slots(num_groups=G, slots=1) source/result layouts, " "32-bit integer source, 16-bit integer result, and matching " "physical arity"); - return VMITruncIRecipe{VMITruncIRecipeKind::GroupSlots1I32ToI16}; + return 
VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToI16}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - sourceBits != 32 || *resultArity != 1) - return fail("requires 32-bit integer deinterleaved source and contiguous " + *resultArity != 1) + return fail("requires integer deinterleaved source and contiguous " "integer result"); - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return VMITruncIRecipe{ - VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16}; - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8 && + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved trunci factor, arity, result " + "element width, or result signedness; 32-bit to 8-bit integer " + "narrowing requires unsigned i8 result"); + + if (fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16}; + if (fact->kind == VMICastLayoutKind::Narrow4x && + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor && cast(resultType.getElementType()).isUnsigned()) - return VMITruncIRecipe{ - VMITruncIRecipeKind::Deinterleaved4I32ToContiguousI8}; + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved4I32ToContiguousI8}; return fail("unsupported deinterleaved trunci factor, arity, result element " "width, or result signedness; 32-bit to 8-bit integer narrowing " "requires unsigned i8 result"); } -FailureOr -VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, +FailureOr +VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if 
(reason) *reason = message.str(); return failure(); @@ -1151,5 +1342,5 @@ VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, "chunk"); } - return VMIBitcastRecipe{VMIBitcastRecipeKind::PerPartVbitcast}; + return VMIBitcastSupport{VMIBitcastSupportKind::PerPartVbitcast}; } diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index c44fc114ec..7f10e39ea6 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -18,7 +18,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1202,8 +1202,8 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getGroupLoadRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupLoadSupport(capabilities, op, reason))) return failure(); return success(); } @@ -1214,8 +1214,8 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, LogicalResult checkSupportedGroupSlotLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getGroupSlotLoadRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupSlotLoadSupport(capabilities, op, reason))) return failure(); return success(); } @@ -1243,8 +1243,8 @@ checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); - VMILocalRecipeRegistry recipes; - if 
(failed(recipes.getGroupSlotsStoreRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupSlotsStoreSupport(capabilities, op, reason))) return failure(); return success(); } @@ -2424,171 +2424,6 @@ FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, return result; } -LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, - VMIMaskType maskType, - VMIVRegType resultType, - int64_t groupSize, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - if (sourceType.getElementType() != resultType.getElementType()) - return fail("vcgadd group_reduce_add path requires matching " - "source/result element types"); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) - return fail("vcgadd group_reduce_add path requires known VLane width"); - int64_t vlaneElems = *lanesPerPart / 8; - if (groupSize != vlaneElems) - return fail("vcgadd group_reduce_add path requires group size equal to " - "one 32-byte VLane"); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - int64_t numGroups = sourceType.getElementCount() / groupSize; - if (!sourceLayout || !resultLayout || !maskLayout || - !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) - return fail("vcgadd group_reduce_add path requires contiguous source/mask " - "layouts and matching num_groups result layout"); - std::string sourceFullReason; - if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("vcgadd group_reduce_add path requires full source " - "chunks; ") + - sourceFullReason); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - 
FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("vcgadd group_reduce_add path requires computable physical " - "arity"); - if (*sourceArity < 1 || *sourceArity != *maskArity || - *sourceArity != *resultArity) - return fail("vcgadd group_reduce_add path requires matching non-empty " - "source/mask/result physical arity"); - return success(); -} - -template -LogicalResult checkS16Block8GroupReduceShape(OpTy op, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto resultType = cast(op.getResult().getType()); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("two-vlane group_reduce_add requires matching source/result " - "element types"); - - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || - *groupSize != 2 * (*lanesPerPart / 8)) - return fail("two-vlane group_reduce_add requires group size equal to two " - "32-byte VLanes"); - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (!sourceLayout || !sourceLayout.isDeinterleaved() || - sourceLayout.getFactor() != 2 || - (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("two-vlane group_reduce_add requires source layout " - "deinterleaved=2 with block_elems=1 or block_elems=8"); - if (!maskLayout || !maskLayout.isDeinterleaved() || 
- maskLayout.getFactor() != 2 || - maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("two-vlane group_reduce_add requires matching mask layout " - "deinterleaved=2 with the same block_elems"); - if (!resultLayout || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("two-vlane group_reduce_add requires " - "group_slots(num_groups, slots=8) result layout"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("two-vlane group_reduce_add requires computable physical " - "arity"); - int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || - *maskArity != *sourceArity) - return fail("two-vlane group_reduce_add requires two source/mask " - "parts per result part"); - - return success(); -} - -template -LogicalResult checkS32Block8GroupReduceShape(OpTy op, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto resultType = cast(op.getResult().getType()); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("four-vlane group_reduce_add requires matching source/result " - "element types"); - - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || - *groupSize != 4 * (*lanesPerPart / 8)) - return fail("four-vlane group_reduce_add requires group size equal to four " - 
"32-byte VLanes"); - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (!sourceLayout || !sourceLayout.isDeinterleaved() || - sourceLayout.getFactor() != 4 || - (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("four-vlane group_reduce_add requires source layout " - "deinterleaved=4 with block_elems=1 or block_elems=8"); - if (!maskLayout || !maskLayout.isDeinterleaved() || - maskLayout.getFactor() != 4 || - maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("four-vlane group_reduce_add requires matching mask layout " - "deinterleaved=4 with the same block_elems"); - if (!resultLayout || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("four-vlane group_reduce_add requires " - "group_slots(num_groups, slots=8) result layout"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("four-vlane group_reduce_add requires computable physical " - "arity"); - int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || - *maskArity != *sourceArity) - return fail("four-vlane group_reduce_add requires four source/mask " - "parts per result part"); - - return success(); -} - std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -3230,13 +3065,13 @@ struct OneToNVMIEnsureLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto 
resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getDataLayoutMaterializationRecipe( - sourceType, resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeDataLayout(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( - op, Twine("ensure_layout has no registered materialization recipe: ") + - recipeReason); + op, Twine("ensure_layout has no registered materialization support: ") + + supportReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3264,14 +3099,14 @@ struct OneToNVMIEnsureMaskLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getMaskLayoutMaterializationRecipe( - sourceType, resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( op, - Twine("ensure_mask_layout has no registered materialization recipe: ") + - recipeReason); + Twine("ensure_mask_layout has no registered materialization support: ") + + supportReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3303,14 +3138,14 @@ struct OneToNVMIEnsureMaskGranularityOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getMaskGranularityMaterializationRecipe( - sourceType, 
resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( op, Twine("ensure_mask_granularity has no registered materialization " - "recipe: ") + - recipeReason); + "support: ") + + supportReason); if (sourceType.getLayout() != resultType.getLayout()) return rewriter.notifyMatchFailure( op, "mask granularity helper cannot also change layout"); @@ -4503,12 +4338,12 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { return failure(); ValueRange valueParts = adaptor.getValue(); - VMILocalRecipeRegistry localRecipes; - FailureOr storeRecipe = - localRecipes.getContiguousStoreRecipe(valueVMIType); - if (succeeded(storeRecipe) && - storeRecipe->kind == - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { + VMILayoutSupport localSupports; + FailureOr storeSupport = + localSupports.getContiguousStoreSupport(valueVMIType); + if (succeeded(storeSupport) && + storeSupport->kind == + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -4961,12 +4796,12 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { ValueRange valueParts = adaptor.getValue(); Value zero = rewriter.create(op.getLoc(), 0); - VMILocalRecipeRegistry localRecipes; - FailureOr storeRecipe = - localRecipes.getContiguousStoreRecipe(valueVMIType); - if (succeeded(storeRecipe) && - storeRecipe->kind == - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { + VMILayoutSupport localSupports; + FailureOr storeSupport = + localSupports.getContiguousStoreSupport(valueVMIType); + if (succeeded(storeSupport) && + storeSupport->kind == + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist 
&& !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -5569,25 +5404,38 @@ struct OneToNVMIReduceAddFOpPattern template struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + OneToNVMIGroupReduceAddOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, context), + capabilities(capabilities) {} LogicalResult matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); - auto maskVMIType = cast(op.getMask().getType()); auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); ValueRange maskParts = adaptor.getMask(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutSupport supports; + std::string supportReason; + FailureOr support = + getSupport(supports, op, &supportReason); + if (failed(support)) + return rewriter.notifyMatchFailure( + op, Twine("group_reduce_add has no layout support: ") + + supportReason); + FailureOr groupSize = getGroupSizeFromNumGroups( sourceVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_reduce_addf requires num_groups to evenly divide lane count"); - if (succeeded(checkVcgaddGroupReduceShape( - sourceVMIType, maskVMIType, resultVMIType, *groupSize, nullptr))) { + + if (support->kind == VMIGroupReduceAddFSupportKind::OneVLaneVcgadd) { if (sourceParts.size() != maskParts.size() || sourceParts.size() != resultTypes.size() || sourceParts.empty()) return rewriter.notifyMatchFailure( @@ -5621,7 +5469,8 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } - if (succeeded(checkS16Block8GroupReduceShape(op, nullptr))) { + if (support->kind == + 
VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd) { int64_t resultPartCount = resultTypes.size(); if (static_cast(sourceParts.size()) != resultPartCount * 2 || maskParts.size() != sourceParts.size()) @@ -5674,7 +5523,8 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } - if (succeeded(checkS32Block8GroupReduceShape(op, nullptr))) { + if (support->kind == + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree) { int64_t resultPartCount = resultTypes.size(); if (static_cast(sourceParts.size()) != resultPartCount * 4 || maskParts.size() != sourceParts.size()) @@ -5733,6 +5583,10 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } + if (support->kind != VMIGroupReduceAddFSupportKind::ContiguousVcaddRows) + return rewriter.notifyMatchFailure(op, + "unknown group_reduce_add support"); + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; @@ -5815,6 +5669,21 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } + +private: + FailureOr + getSupport(VMILayoutSupport &supports, VMIGroupReduceAddFOp op, + std::string *reason) const { + return supports.getGroupReduceAddFSupport(capabilities, op, reason); + } + + FailureOr + getSupport(VMILayoutSupport &supports, VMIGroupReduceAddIOp op, + std::string *reason) const { + return supports.getGroupReduceAddISupport(capabilities, op, reason); + } + + const VMITargetCapabilityRegistry &capabilities; }; struct OneToNVMIGroupBroadcastOpPattern @@ -7006,8 +6875,6 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddOpPattern, - OneToNVMIGroupReduceAddOpPattern, OneToNVMIGroupBroadcastOpPattern, 
OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, @@ -7017,6 +6884,10 @@ void populateVMIOneToNConversionPatterns( OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( typeConverter, patterns.getContext()); + patterns + .add, + OneToNVMIGroupReduceAddOpPattern>( + typeConverter, patterns.getContext(), capabilities); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -7059,47 +6930,47 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { LogicalResult checkSupportedExtFShape(VMIExtFOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtFRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtFSupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedTruncFShape(VMITruncFOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getTruncFRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getTruncFSupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedExtSIShape(VMIExtSIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtSIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtSISupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedExtUIShape(VMIExtUIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtUIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtUISupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedTruncIShape(VMITruncIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getTruncIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getTruncISupport(op, reason))) return failure(); return success(); 
} LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getBitcastRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getBitcastSupport(op, reason))) return failure(); return success(); } @@ -7399,83 +7270,15 @@ template LogicalResult checkSupportedGroupReduceAddShape( const VMITargetCapabilityRegistry &capabilities, OpTy op, std::string *reason = nullptr) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - if constexpr (std::is_same_v) { - if (!op->hasAttr("reassoc")) - return fail("requires reassoc attr for pair-wise floating-point reduction"); - } - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - auto maskType = cast(op.getMask().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || !maskLayout) - return fail("requires assigned source, mask, and result layouts"); - - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; if constexpr (std::is_same_v) { - if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + if (succeeded(supports.getGroupReduceAddFSupport(capabilities, op, reason))) return success(); } else { - if (succeeded(recipes.getGroupReduceAddIRecipe(capabilities, op, nullptr))) + if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) return success(); } - - FailureOr groupSize = getGroupSizeFromNumGroups( - sourceType, op.getNumGroupsAttr().getInt(), reason); - if (failed(groupSize)) - return failure(); - if (succeeded(checkS16Block8GroupReduceShape(op, reason))) - return success(); - if (succeeded(checkS32Block8GroupReduceShape(op, reason))) - return success(); - if (!sourceLayout.isContiguous() || 
!resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || - !maskLayout.isContiguous()) - return fail("requires contiguous source/mask layouts and matching " - "num_groups result layout"); - VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType( - std::is_same_v ? VMIReductionKind::GroupAddF - : VMIReductionKind::GroupAddI, - sourceType.getElementType()); - if (!elementCapability.isSupported()) - return fail(elementCapability.reason); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("requires source/result element type to match"); - if (sourceType.getElementCount() != resultType.getElementCount()) - return fail("requires source/result lane count to match"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - if (failed(sourceArity) || failed(resultArity) || failed(maskArity)) - return fail("requires computable source/result/mask physical arity"); - if (*sourceArity != *resultArity || *sourceArity != *maskArity) - return fail("requires source/result/mask physical arity to match"); - if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, - *groupSize, nullptr))) - return success(); - if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) - return failure(); - if (resultLayout.getSlots() <= 0) - return success(); - - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(lanesPerPart)) - return fail("requires known physical chunk lane count"); - if (!sourceLayout.isContiguous() || *groupSize != *lanesPerPart || - resultLayout.getSlots() != 1) - return fail("explicit group_slots group_reduce_add chunk path requires " - "contiguous full-physical-chunk group size source and slots=1 " - "result layout"); - return success(); + return failure(); } LogicalResult 
checkSupportedGroupBroadcastShape( @@ -7500,8 +7303,8 @@ LogicalResult checkSupportedGroupBroadcastShape( VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); - VMILocalRecipeRegistry recipes; - if (succeeded(recipes.getGroupBroadcastRecipe(capabilities, op, nullptr))) + VMILayoutSupport supports; + if (succeeded(supports.getGroupBroadcastSupport(capabilities, op, nullptr))) return success(); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) @@ -7708,7 +7511,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, broadcast.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_broadcast requires full source chunks with " - "#pto.vmi.layout, a dense full result layout, " + "#pto.vmi.layout, a dense full result layout, " "and num_groups deriving a group size that divides or is a " "multiple of physical chunk lanes (" << reason << ")"; @@ -8117,7 +7920,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << kVMIDiagUnsupportedPrefix << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " "VLane groups or through pto.vcadd with reassoc, contiguous full " - "source/mask chunks, #pto.vmi.layout result " + "source/mask chunks, #pto.vmi.layout result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" << reason << ")"; diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto index 51cd09053f..1cbdeea1d8 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -53,12 +53,12 @@ module { // ASSIGN: %[[COPY_DENSE:.*]] = pto.vmi.ensure_layout %[[COPY]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[COPY_DENSE]] -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = 
pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[PROD:.*]] = pto.vmi.mulf %[[X]], %[[SCALE]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto index 6e165de8a0..eefe95d973 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. 
// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_broadcast_remat( @@ -37,9 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[BCAST_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[BCAST_DEINT]] -// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.broadcast %[[SCALAR]] -// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.ensure_layout %[[BCAST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[BCAST_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto index e387aa077d..a426621c15 100644 --- a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. 
// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_constant_remat( @@ -37,10 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[CONST_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[CONST_DEINT]] -// ASSIGN: %[[CONST_CONTIG:.*]] = "pto.vmi.constant"() -// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[CONST_CONTIG:.*]] = pto.vmi.ensure_layout %[[CONST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[CONST_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto index 2bc648261f..5999ace148 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -31,10 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( // ASSIGN: %[[X:.*]] = pto.vmi.group_load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = 
pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto index cb0e15864e..fe5920c07b 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( @@ -39,8 +39,8 @@ module { // ASSIGN-SAME: %[[ACTIVE:arg[0-9]+]]: index) // ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] -// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK1:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto index 8e8a86450d..6ffab1471d 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto +++ 
b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -33,10 +33,12 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto index ec29b4387a..c8ded49a2f 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -18,7 +18,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store operand #0 has type !pto.vmi.vreg<64xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<64xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store operand #0 has type '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<64xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion '!pto.vmi.vreg<64xf32, 
#pto.vmi.layout>' -> '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' (unsupported source/result layout pair) pto.vmi.store %sum, %dst[%off] : !pto.vmi.vreg<64xf32>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index 27e304ae27..b976ab518d 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -36,10 +36,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( // ASSIGN: %[[X32:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto index 6ed4e7f9e7..187a79d42b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -26,7 +26,7 @@ module { -> !pto.vmi.vreg<128xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result 
layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index b322e5700e..b63d134392 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -15,8 +15,11 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK-SAME: VMI types: operand#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> + // CHECK-SAME: operand#1=!pto.vmi.mask<96xb32, #pto.vmi.layout> + // CHECK-SAME: result#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index d5fa902c56..602ac579ad 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ 
-7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( @@ -46,10 +46,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<192xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] @@ -73,7 +73,9 @@ module { // ASSIGN: %[[PX:.*]] = pto.vmi.load // ASSIGN-SAME: {full_read_elems = 256 : i64} // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( diff --git 
a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto index cface43bab..af78715f95 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -15,11 +15,11 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf operand #0 has type + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.group_reduce_addf operand #0 has type // CHECK-SAME: #pto.vmi.layout // CHECK-SAME: requires // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion // CHECK: requires source and result to have the same physical arity %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 6, reassoc} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto index b8cd439d23..01dab5b003 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented 
diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto index b432d7c68c..1589e531dc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( %src: !pto.ptr, %off: index) { %c2 = arith.constant 2 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto index c30502a252..95fa93474d 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -51,10 +51,10 @@ module { // ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) // ASSIGN: %[[X:.*]] = pto.vmi.group_load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // 
ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto index 996760ed66..35959585de 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<512xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group // CHECK-SAME: requires constant positive row_stride divisible by 8 elements // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto index 773fd4187c..d79cdfddba 100644 --- a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. 
// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_iota_remat( @@ -37,9 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[IOTA_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[IOTA_DEINT]] -// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.iota %[[BASE]] -// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.ensure_layout %[[IOTA_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[IOTA_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto index 8a74de4097..1d3a2f3d0b 100644 --- a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -38,8 +38,8 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[X_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: %[[M16:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> // ASSIGN: pto.vmi.masked_store %[[H]] // 
ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> @@ -54,7 +54,9 @@ module { // LOWER: pto.vcvt // LOWER: pto.vcvt // LOWER: pto.vor -// LOWER: pto.plt_b16 +// LOWER: pto.ppack +// LOWER: pto.ppack +// LOWER: pto.por // LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto index b114643836..8e799c0704 100644 --- a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto @@ -6,7 +6,8 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize | FileCheck %s --check-prefix=REMAT module { func.func @vmi_layout_assignment_create_mask_remat( @@ -47,27 +48,41 @@ module { } } -// CHECK-LABEL: func.func @vmi_layout_assignment_create_mask_remat( -// CHECK-SAME: %[[ACTIVE:.*]]: index -// CHECK: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] -// CHECK-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] -// CHECK-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[M16]] -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[M32]] -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-NOT: pto.vmi.ensure_mask_layout -// CHECK-NOT: pto.vmi.ensure_mask_granularity +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// ASSIGN-SAME: %[[ACTIVE:.*]]: index +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// ASSIGN-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = 
pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( -// CHECK: %[[CM32:.*]] = "pto.vmi.constant_mask"() -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"() -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[CM16]] -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[CM32]] -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-NOT: pto.vmi.ensure_mask_layout -// CHECK-NOT: pto.vmi.ensure_mask_granularity +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// ASSIGN: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[CM16:.*]] = pto.vmi.ensure_mask_granularity %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// REMAT-SAME: %[[ACTIVE:.*]]: index +// REMAT: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[M16]] +// REMAT: pto.vmi.select %[[M32]] +// REMAT-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// REMAT: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: 
!pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[CM16:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[CM16]] +// REMAT: pto.vmi.select %[[CM32]] +// REMAT-NOT: pto.vmi.ensure_mask_layout +// REMAT-NOT: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto index 6c0b2d2ece..796f446b60 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -48,8 +48,8 @@ module { // ASSIGN: pto.vmi.store %[[X]] // ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index 968e8d03c2..9d3147aaea 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -40,8 +40,8 @@ module { // ASSIGN: pto.vmi.ensure_layout // ASSIGN-SAME: #pto.vmi.layout // ASSIGN-SAME: #pto.vmi.layout -// ASSIGN: pto.vmi.create_group_mask -// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: 
pto.vmi.group_reduce_addf // LOWER: pto.pdintlv_b32 // LOWER: pto.pdintlv_b32 diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto index 46f7ff71f2..dd8f2910ab 100644 --- a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -42,8 +42,8 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[X:.*]] = pto.vmi.addf %[[A]], %[[BIASV]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index e57954b16e..71d282577a 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<128xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed 
helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto index 63fc33cfe6..e9553c2c9d 100644 --- a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -36,10 +36,10 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto index e63567e48d..f946686b6f 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_gate_bitcast_group_slots_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support // CHECK-SAME: does not support group_slots layouts // CHECK: note: see current 
operation: %{{.*}} = "pto.vmi.bitcast" %out = pto.vmi.bitcast %source diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto similarity index 94% rename from test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto index 806aaa26dd..2acec47cd2 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_bitcast_recipe_invalid( + func.func @vmi_layout_gate_bitcast_support_invalid( %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support // CHECK-SAME: requires matching logical bit footprint in every physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" %out = pto.vmi.bitcast %source diff --git a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto similarity index 94% rename from test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto index 7bda214fed..4e14381743 100644 --- a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_extf_recipe_invalid( + func.func @vmi_layout_gate_extf_support_invalid( %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered layout support // CHECK-SAME: requires contiguous source layout 
and deinterleaved f32 result layout // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" %out = pto.vmi.extf %source diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto index 224858064c..64681b5dd3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_broadcast_recipe_invalid( + func.func @vmi_layout_gate_group_broadcast_support_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" %out = pto.vmi.group_broadcast %source {num_groups = 8} diff --git a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto index 8f9fb2c809..a14ff20a0b 100644 --- a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_load_recipe_invalid( + func.func @vmi_layout_gate_group_load_support_invalid( %src: !pto.ptr, %off: index, %stride: 
index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 layout support // CHECK-SAME: block8 strided group_load requires constant positive row_stride divisible by 8 f32 elements // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_load" %out = pto.vmi.group_load %src[%off], %stride diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto similarity index 85% rename from test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto index d33315f88d..0c792693f3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto @@ -9,11 +9,11 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_reduce_slots1_recipe_invalid( + func.func @vmi_layout_gate_group_reduce_slots1_support_invalid( %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=1 recipes support group sizes that are multiples of one physical chunk + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=1 support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto similarity 
index 85% rename from test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto index 33a7bc0fae..734c9dd497 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto @@ -9,11 +9,11 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_reduce_recipe_invalid( + func.func @vmi_layout_gate_group_reduce_support_invalid( %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto index 31e7f13c3e..334be3d744 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_slot_load_recipe_invalid( + func.func @vmi_layout_gate_group_slot_load_support_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // 
CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=8 group_slot_load requires constant unit source_group_stride // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" %out = pto.vmi.group_slot_load %src[%off], %stride diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto index c787f57fea..f3263148b3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_gate_group_store_slots2_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} @@ -28,8 +28,8 @@ module { func.func @vmi_layout_gate_group_reduce_slots2_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add local recipes currently require result layout slots=8 or slots=1 + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, diff --git 
a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto index c7003a887d..db0794748d 100644 --- a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto @@ -9,10 +9,10 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_store_recipe_invalid( + func.func @vmi_layout_gate_group_store_support_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride // CHECK: note: see current operation: "pto.vmi.group_store" pto.vmi.group_store %value, %dst[%off], %row_stride diff --git a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto index 53cc5c2a12..4aa1f30cbb 100644 --- a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_gate_ensure_layout_shape_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support // CHECK-SAME: requires source and result to have the same physical arity %dense = pto.vmi.ensure_layout %value : !pto.vmi.vreg<128xf32, #pto.vmi.layout> @@ -25,7 
+25,7 @@ module { module { func.func @vmi_layout_gate_ensure_mask_layout_shape_invalid( %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization support // CHECK-SAME: requires source and result to have the same physical arity %dense = pto.vmi.ensure_mask_layout %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto similarity index 92% rename from test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto index 871e14eb5b..90e49c52dd 100644 --- a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto @@ -9,7 +9,7 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_helper_recipe_invalid( + func.func @vmi_layout_gate_helper_support_invalid( %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { %bad = pto.vmi.ensure_layout %value : !pto.vmi.vreg<64xf32, #pto.vmi.layout> @@ -18,5 +18,5 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support // CHECK-SAME: unsupported source/result layout pair diff --git a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto similarity index 95% rename from test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_store_support_invalid.pto index 3877eb1a3a..7c62871865 100644 --- a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto @@ -12,7 +12,7 @@ 
module { func.func @vmi_layout_gate_store_deint_tail_invalid( %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, %dst: !pto.ptr, %offset: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory layout support // CHECK-SAME: requires arity divisible by layout factor pto.vmi.store %value, %dst[%offset] : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, @@ -27,7 +27,7 @@ module { func.func @vmi_layout_gate_tile_write_deint_tail_invalid( %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, %dst: memref<129xf32>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory layout support // CHECK-SAME: requires arity divisible by layout factor pto.vmi.tile_write %value, %dst : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_layout_gate_local_recipe.pto b/test/lit/vmi/vmi_layout_gate_support.pto similarity index 92% rename from test/lit/vmi/vmi_layout_gate_local_recipe.pto rename to test/lit/vmi/vmi_layout_gate_support.pto index 7644fae1c6..629b85c208 100644 --- a/test/lit/vmi/vmi_layout_gate_local_recipe.pto +++ b/test/lit/vmi/vmi_layout_gate_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s module { - func.func @vmi_layout_gate_local_recipe( + func.func @vmi_layout_gate_support( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} @@ -20,5 +20,5 @@ module { } } -// CHECK-LABEL: func.func @vmi_layout_gate_local_recipe( +// CHECK-LABEL: func.func @vmi_layout_gate_support( // CHECK: pto.vmi.group_reduce_addf diff --git a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto similarity 
index 94% rename from test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto index 68e7963b1b..3021b88a7d 100644 --- a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_truncf_recipe_invalid( + func.func @vmi_layout_gate_truncf_support_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered layout support // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" %out = pto.vmi.truncf %source diff --git a/test/lit/vmi/vmi_layout_rematerialize_data.pto b/test/lit/vmi/vmi_layout_rematerialize_data.pto index 29faa34fb1..22a03d88a5 100644 --- a/test/lit/vmi/vmi_layout_rematerialize_data.pto +++ b/test/lit/vmi/vmi_layout_rematerialize_data.pto @@ -39,6 +39,18 @@ module { !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> } + + func.func @vmi_layout_rematerialize_keeps_load_helper( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %load = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %load_deint = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %load_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } } // CHECK-LABEL: func.func @vmi_layout_rematerialize_data( @@ -47,3 +59,8 @@ module { // CHECK: %[[CONST:.*]] = "pto.vmi.constant"(){{.*}}dense<1.000000e+00> : tensor<128xf32>{{.*}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-NOT: pto.vmi.ensure_layout // CHECK: return %[[BCAST]], %[[IOTA]], 
%[[CONST]] + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_keeps_load_helper( +// CHECK: %[[LOAD:.*]] = pto.vmi.load +// CHECK: %[[LOAD_DEINT:.*]] = pto.vmi.ensure_layout %[[LOAD]] +// CHECK: return %[[LOAD_DEINT]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto index 9db3fcb22b..eb21fae758 100644 --- a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto +++ b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto @@ -57,6 +57,85 @@ module { return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> } + func.func @vmi_layout_sink_materialization_fma( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %acc_deint = pto.vmi.ensure_layout %acc + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %out = pto.vmi.fma %lhs_deint, %rhs_deint, %acc_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask = 
pto.vmi.cmpf "olt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpi( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %mask = pto.vmi.cmpi "slt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_select( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %true_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %false_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %true_deint = pto.vmi.ensure_layout %true_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %false_deint = pto.vmi.ensure_layout %false_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %selected = pto.vmi.select %mask_deint, %true_deint, %false_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + func.func @vmi_layout_sink_materialization_unary( %src: 
!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { @@ -154,6 +233,49 @@ module { // CHECK: %[[SUM2:.*]] = pto.vmi.addf %[[LHS_DEINT]], %arg1 // CHECK: return %[[SUM2]] +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_fma( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[FMA:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[FMA_DEINT:.*]] = pto.vmi.ensure_layout %[[FMA]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[FMA_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpf( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPF:.*]] = pto.vmi.cmpf "olt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPF_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPF]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPF_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpi( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPI:.*]] = pto.vmi.cmpi "slt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPI_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPI]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPI_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_select( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[SELECT:.*]] = pto.vmi.select %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, 
#pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SELECT_DEINT:.*]] = pto.vmi.ensure_layout %[[SELECT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SELECT_DEINT]] + // CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary( // CHECK-NOT: pto.vmi.ensure_layout %arg0 // CHECK: %[[NEG:.*]] = pto.vmi.negf %arg0 diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto index 3b2fc0d080..55e1308b4b 100644 --- a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s module { func.func @vmi_to_vpto_constant_mask_rematerialize( diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto index 74ef8194d5..03add9ada4 100644 --- a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto +++ b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s module { func.func @vmi_to_vpto_create_mask_rematerialize( diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto index dc1b938924..55ed864da1 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( + func.func @vmi_to_vpto_group_broadcast_slots8_support( %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -34,7 +34,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_support( // CHECK-COUNT-16: pto.vselr // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto index 01d9711ef0..3c40457460 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_vselr( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) diff --git a/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_load_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_load_support.pto index a1c5959f98..1af77958af 100644 --- a/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_load_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_load_local_recipe( + func.func @vmi_to_vpto_group_load_support( %source: !pto.ptr, %row_stride: index) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -30,7 +30,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_load_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_load_support( // CHECK-COUNT-8: pto.vlds // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index 380a090a71..019b45f7c5 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -21,9 +21,9 @@ module { %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %b = pto.vmi.group_broadcast %r {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> pto.vmi.group_store %b, %dst[%c0], %row_stride {num_groups = 2} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto new file mode 100644 index 0000000000..b3e48c56b4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @legacy_group_slots_without_explicit_slots( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + // CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto index 4b706dc08d..99359b1a8e 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_s64_local_recipe( + func.func @vmi_to_vpto_group_reduce_s64_support( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -31,7 +31,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_support( // CHECK-COUNT-8: pto.vcadd // CHECK: pto.vsel // CHECK-NOT: pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto similarity index 91% rename from test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto index a6737eae1f..9e6a9faf00 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( + func.func @vmi_to_vpto_group_reduce_slots8_support( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) -> !pto.vreg<64xf32> { @@ -24,7 +24,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_support( // CHECK: pto.vcgadd // CHECK-NOT: pto.vcadd // CHECK-NOT: pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto index 27d246e6d2..d6b52468b4 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -16,9 +16,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto index d3da9416b6..d6265bd490 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -19,10 +19,10 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, !pto.vmi.mask<1024xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto similarity index 91% rename from test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto rename to 
test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto index 3a9aa117b5..e806b28b92 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_load_local_recipe( + func.func @vmi_to_vpto_group_slot_load_support( %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} @@ -22,7 +22,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_support( // CHECK: pto.vsldb // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto similarity index 96% rename from test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto index eec3c06d2a..4874117e69 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( + func.func @vmi_to_vpto_group_slot_truncf_slots1_support( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, @@ -29,7 +29,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_support( // CHECK-COUNT-8: pto.vcvt // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index a0cc8215cb..dd69bcfaa2 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -258,7 +258,9 @@ module { // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> // CHECK: scf.if -// CHECK: pto.plt_b16 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask // CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask // CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( @@ -301,7 +303,12 @@ module { // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask // CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index f78e4ef5f2..5297123e5a 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -17,9 +17,9 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: requires source and result to have the same physical arity From 85a98cb2d709781c05a9575849ec20738987d39c Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:17:33 +0800 Subject: [PATCH 21/31] Support partial packed VMI group slots --- docs/designs/vmi-layout-lowering-cases.md | 155 ++++++++++++++++++ lib/PTO/IR/VMI.cpp | 5 +- ...assignment_group_reduce_partial_slots8.pto | 61 +++++++ .../vmi/vmi_layout_group_slots_invalid.pto | 4 +- ...mi_to_vpto_group_reduce_partial_slots8.pto | 94 +++++++++++ test/lit/vmi/vmi_type_attr_parse.pto | 7 +- 6 files changed, 320 insertions(+), 6 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index d2d7b3835d..5a007987a7 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ 
b/docs/designs/vmi-layout-lowering-cases.md @@ -5494,6 +5494,12 @@ S = logical_lane_count / num_groups The canonical grouped-reduce layouts are: ```text +Packed group-slot rule: + K is the physical slot capacity of one packed group-result chunk. + For VCG-style packed reductions, K = 8. + G does not have to be divisible by K; the final chunk may be partial. + active_groups(chunk c) = min(K, G - c * K). + S == VLaneElems: source/mask layout = contiguous result layout = group_slots(num_groups=G, slots=8) @@ -5696,6 +5702,155 @@ for r = 0..7: out[group_off + r] = reduce_T16(base[off + r * 64 + 0 .. 63]) ``` +#### 3.50.1 Partial Packed `S = 64` Reductions + +This is the same `S = 4 * VLaneElems` lowering family as section 3.50, but it +covers `G` values that do not fill every packed group-result chunk. The key +point is that `slots = 8` is a physical capacity, not a promise that every +chunk contains eight valid group results. + +The result layout remains: + +```text +!pto.vmi.vreg<(G * 64)xf16, #pto.vmi.layout> +``` + +The lowering computes per result chunk: + +```text +K = 8 +chunk c active groups A(c) = min(K, G - c * K) + +source active lanes per deinterleaved part for chunk c: + A(c) * VLaneElems = A(c) * 16 f16 lanes + +reduce input mask: + PAT_VL(A(c) * 16) + +combine/store mask: + PAT_VL(A(c)) +``` + +For full chunks, `A(c) = 8`, so the reduce input mask is `PAT_ALL` for f16 +and the combine/store mask is `PAT_VL8`. For partial chunks, masks are +required for correctness. The semantic source mask produced by +`pto.vmi.create_group_mask` must also materialize only the valid source lanes; +the reduce lowering should not treat padding lanes as active data. 
+ +##### `G = 4`: `256xf16, num_groups = 4` + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf16> -> !pto.vmi.vreg<256xf16> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 4, group_size = 64} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 4, reassoc} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 4} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf16, #pto.vmi.layout> +``` + +VPTO lowering shape for the only result chunk: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4, block_elems=8 input + : four !pto.vreg<128xf16> + +%lane64_b16 = pto.pge_b16 "PAT_VL64" // A * 16 = 4 * 16 +%slot4_b16 = pto.pge_b16 "PAT_VL4" + +%s0 = pto.vcgadd %x_p0, %lane64_b16 : !pto.vreg<128xf16> +%s1 = pto.vcgadd %x_p1, %lane64_b16 : !pto.vreg<128xf16> +%s2 = pto.vcgadd %x_p2, %lane64_b16 : !pto.vreg<128xf16> +%s3 = pto.vcgadd %x_p3, %lane64_b16 : !pto.vreg<128xf16> + +%s01 = pto.vadd %s0, %s1, %slot4_b16 : !pto.vreg<128xf16> +%s23 = pto.vadd %s2, %s3, %slot4_b16 : !pto.vreg<128xf16> +%sum0 = pto.vadd %s01, %s23, %slot4_b16 : !pto.vreg<128xf16> + +pto.vsts %sum0, %out[%group_off], %slot4_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..3: + out[group_off + r] = reduce_f16(base[off + r * 64 + 0 .. 63]) + +sum0 lanes 4..127 are not semantic for this VMI result. +``` + +##### `G = 8`: full packed chunk + +This is section 3.50. 
There is one result chunk with `A = 8`: + +```text +source mask = PAT_ALL // 8 * 16 = 128 f16 lanes +combine/store = PAT_VL8 +result layout = group_slots(num_groups=8, slots=8) +``` + +##### `G = 12`: full chunk plus partial chunk + +This case needs two packed result chunks: + +```text +result layout = group_slots(num_groups=12, slots=8) +result arity = ceil(12 / 8) = 2 +``` + +Chunk 0 handles groups `0..7`: + +```text +A(0) = 8 +source mask = PAT_ALL +combine/store = PAT_VL8 +``` + +Chunk 1 handles groups `8..11`: + +```text +A(1) = 4 +source mask = PAT_VL64 +combine/store = PAT_VL4 +``` + +Implementation checklist for this family: + +```text +layout attr: + slots=8 should be legal even when num_groups is not divisible by 8. + slot_block(g) = g / 8 and slot_lane(g) = g % 8 are still well-defined. + +layout assignment: + packed VCG-style group_reduce results keep slots=8. + +mask materialization: + create_group_mask must not activate padding lanes in partial chunks. + For chunk c, source active lanes are A(c) * VLaneElems. + +vmi-to-vpto group_reduce: + use A(c) from result layout slots and num_groups. + combine masks use PAT_VL(A(c)). + input vcgadd consumes the physical mask parts, which must already encode + PAT_VL(A(c) * VLaneElems) for all-true grouped masks. + +vmi-to-vpto group_store: + use A(c) to build the store predicate. + output group offset for chunk c is c * slots. +``` + ### 3.51 16-bit Typed Group Reduce, `S = L = 128` This is the first row-local full-physical-chunk case for both `f16` and `i16`. 
diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index b504de67f5..d3d2dc6b14 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -531,12 +531,11 @@ VMILayoutAttr::verify(function_ref emitError, if (blockElems != 1) return emitError() << "#pto.vmi.layout requires block_elems to be 1"; - if (slots < 0 || (slots != 0 && factor % slots != 0)) + if (slots < 0) return emitError() << "#pto.vmi.layout requires slots to be positive and divide num_groups when " - "specified"; + << "> requires slots to be omitted or positive"; return success(); } diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..e828ba6b2d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16>, + %mask: !pto.vmi.mask<256xpred>) + -> !pto.vmi.vreg<256xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16>, + %mask: !pto.vmi.mask<768xpred>) + -> !pto.vmi.vreg<768xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16>, !pto.vmi.mask<768xpred> + -> !pto.vmi.vreg<768xf16> + return %out : !pto.vmi.vreg<768xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( +// CHECK-SAME: %arg0: !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[SRC4:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MASK4_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MASK4:.*]] = pto.vmi.ensure_mask_granularity %[[MASK4_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: %[[OUT4:.*]] = pto.vmi.group_reduce_addf %[[SRC4]], %[[MASK4]] +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: return %[[OUT4]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( +// CHECK-SAME: %arg0: !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: %[[SRC12:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: 
%[[MASK12_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK: %[[MASK12:.*]] = pto.vmi.ensure_mask_granularity %[[MASK12_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<768xb16, #pto.vmi.layout> +// CHECK: %[[OUT12:.*]] = pto.vmi.group_reduce_addf %[[SRC12]], %[[MASK12]] +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: return %[[OUT12]] diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto index f354adb6e8..1f4ccd2856 100644 --- a/test/lit/vmi/vmi_layout_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -10,9 +10,9 @@ module { func.func @vmi_layout_group_slots_invalid( - %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { return } } -// CHECK: #pto.vmi.layout requires slots to be positive and divide num_groups when specified +// CHECK: #pto.vmi.layout requires slots to be omitted or positive diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..8efe26cf22 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 4} + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g8( + %source: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + !pto.vmi.mask<512xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<768xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + !pto.vmi.mask<768xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 12} + : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g4( +// CHECK-DAG: %[[SLOT4:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4]] +// CHECK: %[[STORE4:.*]] = 
pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g8( +// CHECK-DAG: %[[SLOT8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8]] +// CHECK: %[[STORE8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g12( +// CHECK: %[[SLOT8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8_12]] +// CHECK: %[[SLOT4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4_12]] +// CHECK: %[[STORE8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8_12]] +// CHECK: %[[STORE4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4_12]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. 
+// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto index 5798114cc7..b2001c29f0 100644 --- a/test/lit/vmi/vmi_type_attr_parse.pto +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -14,7 +14,9 @@ module attributes { pto.vmi_deinterleaved4 = #pto.vmi.layout, pto.vmi_deinterleaved4_block8 = #pto.vmi.layout, - pto.vmi_group_slots8 = #pto.vmi.layout + pto.vmi_group_slots8 = #pto.vmi.layout, + pto.vmi_group_slots_partial = + #pto.vmi.layout } { func.func @vmi_type_attr_parse( %surface: !pto.vmi.vreg<128xf32>, @@ -23,6 +25,7 @@ module attributes { %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %group_slots8: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %group_slots_partial: !pto.vmi.vreg<640xf32, #pto.vmi.layout>, %surface_mask: !pto.vmi.mask<128xpred>, %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, @@ -37,6 +40,7 @@ module attributes { // CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved4_block8 = #pto.vmi.layout // CHECK: pto.vmi_group_slots8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots_partial = #pto.vmi.layout // CHECK-LABEL: func.func @vmi_type_attr_parse( // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> @@ -44,6 +48,7 @@ module attributes { // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<640xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> From c9604adf9ffc06670869b55a091f6984fc278f2a Mon Sep 17 00:00:00 2001 From: mouliangyu 
<21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:05:06 +0800 Subject: [PATCH 22/31] Support arith select in VPTO LLVM lowering --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 36 +++++++++++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 36 +++++++++++++ test/lit/vpto/arith_select_vpto_llvm.pto | 54 +++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 test/lit/vpto/arith_select_vpto_llvm.pto diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 2b881c6f6d..71db79529b 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -9289,6 +9289,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return 
success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10164,6 +10199,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index cd501fc420..a495293aa6 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -9300,6 +9300,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public 
OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10177,6 +10212,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/test/lit/vpto/arith_select_vpto_llvm.pto b/test/lit/vpto/arith_select_vpto_llvm.pto new file mode 100644 index 0000000000..b32a7fe0de --- /dev/null +++ b/test/lit/vpto/arith_select_vpto_llvm.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @arith_select_vreg(%cond: i1, %lhs_scalar: f32, %rhs_scalar: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vdup %lhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vdup %rhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %chosen = arith.select %cond, %lhs, %rhs : !pto.vreg<64xf32> + pto.vsts %chosen, %dst[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } + + func.func @arith_select_mask(%cond: i1, %value: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %tail = pto.pge_b32 "PAT_VL4" : !pto.mask + %chosen_mask = arith.select %cond, %all, %tail : !pto.mask + %vec = pto.vdup %value, %all + : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %dst[%c0], %chosen_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @arith_select_vreg_mix_aiv +// CHECK: %[[LHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[RHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[CHOSEN:.*]] = llvm.select %arg0, %[[LHS]], %[[RHS]] : i1, vector<64xf32> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32(%[[CHOSEN]] + +// CHECK-LABEL: llvm.func @arith_select_mask_mix_aiv +// CHECK: %[[ALL:.*]] = llvm.call @llvm.hivm.pset.b32 +// CHECK: %[[TAIL:.*]] = llvm.call @llvm.hivm.pge.b32 +// CHECK: %[[CHOSEN_MASK:.*]] = llvm.select %arg0, %[[ALL]], %[[TAIL]] : i1, vector<256xi1> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CHOSEN_MASK]]) From fd5fc1132728a57e9b9dac306ac6bc8928847325 Mon Sep 17 00:00:00 2001 From: 
mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:09:44 +0800 Subject: [PATCH 23/31] Add VMI introduction design doc --- docs/designs/vmi-introduction.md | 658 +++++++++++++++++++++++++++++++ 1 file changed, 658 insertions(+) create mode 100644 docs/designs/vmi-introduction.md diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md new file mode 100644 index 0000000000..94ca638cb2 --- /dev/null +++ b/docs/designs/vmi-introduction.md @@ -0,0 +1,658 @@ +# VMI 介绍 + +本文介绍 VMI 的设计入口:VMI 解决什么问题,layout 有哪些,pass pipeline +如何分工,以及这些机制分别应对哪些典型场景。更完整的逐 case lowering 结果见 +`docs/designs/vmi-layout-lowering-cases.md`。 + +示例是设计级 IR,保留关键 type、layout、helper op 和 VPTO op 形状, +省略 module wrapper、完整 operand list 和不影响讨论的 SSA 细节。 + +## 1. VMI 表达什么 + +VMI 是 VPTO 之前的逻辑向量层。它让前端先表达“我要对 `NxT` 的逻辑向量做什么”, +再由 layout assignment 决定这个逻辑向量如何拆到 256B 物理 vector register 上。 + +Surface VMI 类型不携带布局: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.mask<128xpred> +``` + +Layout-assigned VMI 类型携带具体布局和 mask granularity: + +```mlir +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +VMI 的核心约束是:`vmi-to-vpto` 只从当前 op 的 attrs、operands、types、 +layouts 和显式 helper ops 做 lowering,不读取隐藏 plan/recipe,也不通过 +defining op 或 sibling user 恢复上下文。 + +## 2. Layout 类型 + +### 2.1 `contiguous` + +```mlir +#pto.vmi.layout +``` + +含义:logical lane 按顺序落入物理 register list。 + +```text +logical lanes: 0 1 2 ... 63 | 64 65 ... 
127 +physical part: p0 | p1 +``` + +典型场景: + +```text +dense load/store +普通 elementwise compute +一个 group 天然适配当前 reduce op 时的 reduction input +caller/callee 约定 dense order 时的 control-flow/function boundary +``` + +### 2.2 `deinterleaved = F, block_elems = B` + +```mlir +#pto.vmi.layout +#pto.vmi.layout +``` + +`block_elems` 缺省为 `1`。逻辑 lane 到物理 part 的映射是: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +`deinterleaved=2` 的直观例子: + +```text +logical lanes: 0 1 2 3 4 5 ... +physical part0: 0 2 4 ... +physical part1: 1 3 5 ... +``` + +`deinterleaved=4, block_elems=8` 的直观例子: + +```text +logical group S=32: + lanes 0.. 7 -> part0 lanes 0..7 + lanes 8..15 -> part1 lanes 0..7 + lanes 16..23 -> part2 lanes 0..7 + lanes 24..31 -> part3 lanes 0..7 +``` + +典型场景: + +```text +f16 -> f32: + vcvt 天然产生 even/odd 两个 f32 part,所以结果使用 deinterleaved=2。 + +f32 -> f16: + vcvt 需要 f32 source 先拆成 even/odd 两个 part,所以 source 使用 + deinterleaved=2。 + +S=32 group_reduce f32: + 一个 group 有 32 个 f32 element。高效 reduce path 消费四个 8-lane block, + 所以 source/mask 使用 deinterleaved=4, block_elems=8。 +``` + +### 2.3 `num_groups = G, slots = K` + +```mlir +#pto.vmi.layout +#pto.vmi.layout +``` + +这是 sparse group-result layout。它不表示全部 `N` 个 logical lane 都有语义值。 +只有 `G` 个 group 结果 slot 有语义值。 + +```text +slot_block(g) = g / K +slot_lane(g) = g % K + +physical part slot_block(g) 的 lane slot_lane(g) 保存 group g 的结果 +``` + +`num_groups=16, slots=8` 的例子: + +```text +part0 lane0..7 = group result 0..7 +part1 lane0..7 = group result 8..15 +other lanes = 对普通 dense consumer 来说未定义 +``` + +为什么 group 信息也要放进 layout: + +```text +group_reduce 自身有 num_groups,但它的结果可能继续跨过 truncf、 +group_broadcast、group_store、scf.if、scf.for、function call 或多个 consumer。 + +这些后续 op 不应该回看 producer attr。value layout 因此需要记录有多少个 +group result,以及这些 result 如何 packed 到 physical slot。 +``` + +典型场景: + +```text +group_reduce result +group_slot_load result 
+group_store input +group_broadcast input +group-slot control-flow/function boundary +部分 row-local cast 路径,通常使用 slots=1 +``` + +## 3. Pass Pipeline + +```text +pto-validate-vmi-ir + -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold-consumers + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto +``` + +### 3.1 `pto-validate-vmi-ir` + +检查 surface VMI 边界。 + +合法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> +``` + +非法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +原因:具体 layout 由 `vmi-layout-assignment` 产生,不应该由 surface frontend +提前写入。 + +### 3.2 `vmi-layout-assignment` + +这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, +并在 layout 不匹配的 use site 插入显式 helper op。 + +例子:`f16 -> f32 -> store`。 + +Surface VMI: + +```mlir +%x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr +``` + +Assignment 之后: + +```mlir +%x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x32_dense = pto.vmi.ensure_layout %x32 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +pto.vmi.store %x32_dense, %dst[%off] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +``` + +即使不跑任何优化 pass,这个 assignment 后的 IR 也已经是正确可降的。 + +### 3.3 `vmi-layout-fold-consumers` + +当 consumer 可以直接保持同样的外部效果时,把显式 materialization 折进 +consumer。 + +变换前: + +```mlir +%dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +pto.vmi.store 
%dense, %dst[%off] +``` + +变换后: + +```mlir +pto.vmi.store %x, %dst[%off] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +``` + +可能的 VPTO 形状: + +```text +fold 前:vintlv + vsts + vsts +fold 后:vstsx2,使用交错 store mode +``` + +### 3.4 `vmi-layout-rematerialize` + +通过 clone 低成本、layout-polymorphic 的 producer 来替换 `ensure_*`。 + +变换前: + +```mlir +%s = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%s_split = pto.vmi.ensure_layout %s + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +变换后: + +```mlir +%s_split = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +预期可 rematerialize 的 producer: + +```text +splat constant +broadcast +iota +create_mask +create_group_mask +constant_mask +``` + +这个 pass 不 rematerialize: + +```text +load / masked_load / group_load / group_slot_load +reduce / group_reduce +control-flow results +``` + +### 3.5 `vmi-layout-sink-materialization` + +把匹配的 layout 转换跨过 layout-transparent elementwise op。 + +变换前: + +```mlir +%a_dense = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous +%b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous +%y_dense = pto.vmi.addf %a_dense, %b_dense : contiguous +``` + +变换后: + +```mlir +%y_split = pto.vmi.addf %a, %b : deinterleaved=2 +%y_dense = pto.vmi.ensure_layout %y_split : deinterleaved=2 -> contiguous +``` + +效果: + +```text +两个 input materialization -> 一个 result materialization +``` + +这个 pass 不会 sink 穿过 cast、load、store、reduce、group_broadcast 或 +control-flow op。 + +### 3.6 `vmi-legalize-arith-select` + +Canonicalization 可能把简单的 `scf.if` 折成 `arith.select`。VMI 希望把 +control-flow lowering 保持在结构化控制流里,所以这个 pass 会把 VMI value 上的 +`arith.select` 改回 `scf.if`。 + +```mlir +%r = arith.select %cond, %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +改成: + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + scf.yield %a : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} else { + 
scf.yield %b : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} +``` + +### 3.7 `pto-validate-vmi-layout-ir` + +检查 post-assignment gate: + +```text +每个 VMI 数据值都有 concrete layout +每个 VMI mask 都有 concrete granularity 和 layout +helper op 有支持的 materialization path +semantic op/layout 组合有支持的 local lowering +vmi-to-vpto 之前没有物理 VPTO value 泄漏到 VMI IR 中 +``` + +非法例子: + +```mlir +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : ... -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr +``` + +原因: + +```text +dense store 不能把 sparse group_slots 当 dense vector 读取。 +应使用 group_store、group_broadcast 或显式支持的 group-to-dense op。 +``` + +### 3.8 `vmi-to-vpto` + +把 layout-assigned VMI value 转换成有序物理 VPTO value 列表,并对每个 +VMI op 做 local lowering。 + +例子: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 两个 physical !pto.vreg<64xf32> part + +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 两个 physical !pto.vreg<64xf32> part + part0 携带 even lanes,part1 携带 odd lanes + +!pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> 四个 physical part + part0 携带 group 0..7,part1 携带 group 8..15,... +``` + +`VMILayoutSupport` 不是 pass。它是 assignment、validation、optimization 和 +lowering 共享的查询库,用来避免重复实现 layout fact 和 supported +materialization 检查。 + +## 4. 
典型场景 + +### 4.1 Dense Cast 与 Store + +```text +surface: + load f16,语义上连续 + extf 到 f32 + dense store f32 + +assignment: + load result = contiguous + extf result = deinterleaved=2 + store use = ensure_layout(deinterleaved=2 -> contiguous) + +baseline VPTO: + vlds + vcvt even / vcvt odd + vintlv + vsts + vsts + +fold-consumers 后的优化 VPTO: + vlds + vcvt even / vcvt odd + vstsx2,使用 interleaving store +``` + +这个场景说明为什么需要 `deinterleaved=2`,以及为什么 store-consumer folding +有价值。 + +### 4.2 Narrow Cast 与 Store + +```text +surface: + load f32 + truncf 到 f16 + dense store f16 + +assignment: + load result = deinterleaved=2 + truncf result = contiguous + +VPTO: + vldsx2 deinterleaving load + vcvt even / vcvt odd + vor + vsts +``` + +这个场景说明 memory op 可以直接产生 consumer 需要的 layout,但不需要保存隐藏 +plan。 + +### 4.3 一个 Producer 同时服务 Dense 和 Group Consumer + +```mlir +%x32 = pto.vmi.extf %x16 +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assignment 形状: + +```text +%x32 layout = deinterleaved=2 +group_reduce 直接消费 %x32 +dense store 获得 ensure_layout(%x32 -> contiguous) +``` + +VPTO 形状: + +```text +vcvt even/odd +vcgadd + vcgadd + vadd -> group_store result +vintlv + dense stores -> 产生 dense store 结果 +``` + +这个场景说明为什么需要 use-site materialization。producer 不需要选择一个能同时 +满足所有 consumer 的唯一 layout。 + +### 4.4 按 Group Size 区分的 Group Reduce + +对于 `N` 个 f32 lane 和 `G = num_groups`,group size 是 `S = N / G`。 + +```text +S=8: + input layout 可以是 contiguous。 + group_reduce result 通常使用 layout。 + +S=16: + 如果 input 来自 f16->f32 vcvt,layout 可以是 deinterleaved=2。 + 如果 input 从 dense 拆出,layout 可以是 deinterleaved=2, block_elems=8。 + result 通常使用 layout。 + +S=32: + input layout 使用 deinterleaved=4, block_elems=8。 + VPTO 形状是四个部分 group reduction 后接 add tree。 + result 通常使用 layout。 + +S=64: + row-local path 在可行时让每个 group 使用一条 physical row。 + result 可以使用 layout,避免 unsupported packing。 +``` + +S=32 例子: 
+ +```text +assignment: + source/mask = deinterleaved=4, block_elems=8 + result = group_slots(num_groups=8, slots=8) + +VPTO: + vdintlv / pdintlv_b32 + vcgadd x4 + 使用 PAT_VL8 做 vadd tree + 通过一次 PAT_VL8 store 完成 group_store +``` + +这个场景说明为什么需要 `block_elems`。 + +### 4.5 Group Result 继续作为 Dense Rows 使用 + +Surface 意图: + +```mlir +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} +%sum16 = pto.vmi.truncf %sum32 +%rows16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +pto.vmi.store %rows16, %dst[%off] +``` + +支持的 assignment 形状: + +```mlir +%sum32 = pto.vmi.group_reduce_addf ... + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rows32_split = pto.vmi.ensure_layout %rows32 + : contiguous -> deinterleaved=2 + +%rows16 = pto.vmi.truncf %rows32_split + : deinterleaved=2 -> contiguous + +pto.vmi.store %rows16, %dst[%off] +``` + +VPTO 形状: + +```text +group_reduce: + vcgadd partials + vadd tree + +group_broadcast: + vselr 风格 selection,把 group slots 展开到 dense row lanes + +truncf: + vcvt even/odd + merge + +store: + vsts +``` + +这个场景说明为什么 group 结果 layout 必须挂在 value 上:reduce 之后, +cast 和 broadcast 必须知道 group 结果在哪里,而不能回看 producer。 + +### 4.6 通过 Mask 表达 Tail + +VMI 通过 mask 表达 tail,不通过 padding 表达 tail。 + +```mlir +%mask = pto.vmi.create_mask %active_lanes +%x = pto.vmi.masked_load %src[%off], %mask +%y = pto.vmi.mulf %x, %scale +pto.vmi.masked_store %y, %dst[%off], %mask +``` + +Grouped tail: + +```mlir +%gmask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = 8, group_size = 32} +%sum = pto.vmi.group_reduce_addf %x, %gmask {num_groups = 8, reassoc} +``` + +同一个 semantic mask 面对 f8/f16/f32 user 时,可能需要不同 concrete +granularity。Assignment 会通过 mask helper op 显式表达这些转换。 + +### 4.7 控制流和函数边界 + +Concrete layout 必须显式跨过 CFG 和内部 function boundary。 + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %a_dense = 
pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + scf.yield %a_dense +} else { + %b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous + scf.yield %b_dense +} +``` + +`vmi-to-vpto` 之后,region result 会变成多个物理 VPTO value: + +```text +scf.if -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这个场景说明为什么 layout 应该是 type 的一部分,而不是依赖 defining op。 + +## 5. 当前边界 + +当前设计方向: + +```text +surface VMI: + 描述不带 layout 的逻辑向量语义。 + +layout assignment: + 选择 layout、mask granularity 和显式 materialization helper。 + +optimization: + 只在结果 IR 仍然可以 local lowering 时改写显式 helper。 + +vmi-to-vpto: + 严格 lower 它看到的 assigned/optimized IR。 +``` + +暂不支持或有意收紧的范围: + +```text +group_slots value 的普通 dense store: + 非法,除非先经过 group_broadcast 或其他显式 group-to-dense op。 + +packed group_slots f32->f16 cast: + 非法,除非 assignment 能把它 commute 到 group_broadcast 之后,或者使用 + 支持的 row-local slots=1 path。 + +extract: + 暂不作为支持的 VMI surface。 + +padding transfer_read: + 当前 tail 设计不需要;tail 使用 mask。 + +scan / contract / gather / scatter / compress / active_prefix_index: + dialect surface 中可以存在,但除非补充具体 case,否则不属于第一阶段聚焦的 + layout/lowering 实现集合。 +``` + +设计目标是优先保证语义完整:只要 VMI 接受某个 case,所需的 layout 沟通就必须 +在 IR 中显式表达,并且能被 `vmi-to-vpto` local lowering。 From 46942f09ed5c6d8a34f6bdf820cf95f140ed3332 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:59:52 +0800 Subject: [PATCH 24/31] Fold deinterleaved VMI loads through vldsx2 --- docs/designs/vmi-introduction.md | 59 +++++++++------- docs/designs/vmi-layout-lowering-cases.md | 70 ++++++++++++++++--- lib/PTO/Transforms/VMIToVPTO.cpp | 62 ++++++++++++++++ ..._layout_assignment_f32_f8_store_reduce.pto | 4 +- test/lit/vmi/vmi_to_vpto_load_deint.pto | 12 ++-- .../vmi/vmi_to_vpto_load_deint_multichunk.pto | 36 ++++++++++ test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 10 +-- test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 8 +-- 8 files changed, 204 insertions(+), 57 deletions(-) diff --git a/docs/designs/vmi-introduction.md 
b/docs/designs/vmi-introduction.md index 94ca638cb2..fb1f1b7135 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -106,6 +106,16 @@ S=32 group_reduce f32: 所以 source/mask 使用 deinterleaved=4, block_elems=8。 ``` +`block_elems=8` 表示一种按 32B row fragment 组织的输入形态,不表示 +S=32 reduce 只能接受这一种形态。如果同一个 value 还要服务 narrow cast 等 +element-parity consumer,assignment 可以选择 `deinterleaved=4, block_elems=1` +作为共同 layout,再由 lowering 生成对应的物理指令序列。 + +`deinterleaved` 只描述最终物理 part 中有哪些 logical lane,不描述这个 layout +由哪条指令生成。不同 producer 可以用不同方式直接产生同一个 layout;如果不能 +直接产生,后续 lowering 再通过显式 materialization helper 把 source layout +转换成 consumer 需要的 layout。具体 lowering 形状见 case catalog。 + ### 2.3 `num_groups = G, slots = K` ```mlir @@ -196,39 +206,40 @@ pto-validate-vmi-ir 这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, 并在 layout 不匹配的 use site 插入显式 helper op。 -例子:`f16 -> f32 -> store`。 +实现上它维护 data 和 mask 两套求解状态: -Surface VMI: +```text +data value: + 每个 !pto.vmi.vreg 是一个节点,节点记录最终选择的布局。 -```mlir -%x16 = pto.vmi.load %src[%off] - : !pto.ptr -> !pto.vmi.vreg<128xf16> -%x32 = pto.vmi.extf %x16 - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> -pto.vmi.store %x32, %dst[%off] - : !pto.vmi.vreg<128xf32>, !pto.ptr +mask value: + 每个 !pto.vmi.mask 是一个节点,节点记录最终选择的布局和 predicate 粒度。 ``` -Assignment 之后: - -```mlir -%x16 = pto.vmi.load %src[%off] - : !pto.ptr - -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +data value 使用 union-find 表示“这些 value 必须共用 layout”。函数参数、 +call operand/result、return/yield、block argument、bitcast 等边界会把相关 +value 合并到同一个等价类里。等价类只能有一个最终 data layout。 -%x32 = pto.vmi.extf %x16 - : !pto.vmi.vreg<128xf16, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +assignment 遍历 IR 时,每类 op 向求解器贡献两种信息: -%x32_dense = pto.vmi.ensure_layout %x32 - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +```text +result 自然布局: + 这个 op 自己产生的 result 适合用什么 layout 表达。 -pto.vmi.store %x32_dense, %dst[%off] - : !pto.vmi.vreg<128xf32, 
#pto.vmi.layout>, !pto.ptr +operand 使用请求: + 这个 op 消费某个 operand 时希望 operand 是什么 layout。 ``` -即使不跑任何优化 pass,这个 assignment 后的 IR 也已经是正确可降的。 +有些 producer 生成的是同一个逻辑向量,但可以用多种物理 layout 表达。若它的 +所有 consumer 给出的使用请求一致,assignment 会把这个请求反推为 producer +result 的最终布局。否则,producer 保持自己的布局,assignment 在不匹配的 use +site 插入 `pto.vmi.ensure_layout`。mask 使用同样思路,但还会同时求解 predicate +粒度,必要时插入 `ensure_mask_layout` 或 `ensure_mask_granularity`。 + +最后,pass 会把所有 VMI data/mask type 改写成带 layout 的 type,并同步更新 +function type、call site、block argument 和 terminator operand。这个阶段之后, +IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type 上的 layout +和显式 `ensure_*` helper。 ### 3.3 `vmi-layout-fold-consumers` diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 5a007987a7..855a7a486f 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -131,6 +131,52 @@ group_store future explicit group-pack op ``` +Contiguous memory loads may produce a non-contiguous physical value directly +when the requested result layout is a dense deinterleaved layout. This is a +lowering choice, not a separate layout family. + +```text +pto.vmi.load -> #pto.vmi.layout + lower as: + vlds NORM for each physical chunk + +pto.vmi.load -> #pto.vmi.layout + lower as: + vldsx2 DINTLV_B* for each pair of physical chunks + +pto.vmi.load -> #pto.vmi.layout + lower as: + two vldsx2 DINTLV_B* operations for each four-chunk group + followed by two vdintlv operations to split mod4 parts + +pto.vmi.load -> #pto.vmi.layout + lower using the producer-specific path or fall back to explicit + materialization. Do not treat DINTLV_B* as a block-fragment layout. +``` + +The `deinterleaved = 4` result order remains the normal VMI physical part +order: + +```text +results = [part0 chunks..., part1 chunks..., part2 chunks..., part3 chunks...] 
+``` + +For one full `256xf32` tile: + +```text +%even0, %odd0 = pto.vldsx2 %base[%off0], "DINTLV_B32" +%even1, %odd1 = pto.vldsx2 %base[%off128], "DINTLV_B32" + +%part0, %part2 = pto.vdintlv %even0, %even1 +%part1, %part3 = pto.vdintlv %odd0, %odd1 + +replace pto.vmi.load with [%part0, %part1, %part2, %part3] +``` + +This optimization is legal only for full physical chunks and supported +`DINTLV_B8/B16/B32` element widths. Tail and masked loads keep their explicit +safe lowering until a masked or guarded `vldsx2` strategy is designed. + ## 3. Lowering Results The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true @@ -3927,18 +3973,14 @@ VPTO lowering result: %all_b32 = pto.pge_b32 "PAT_ALL" %sum_mask = pto.pge_b32 "PAT_VL8" -%x0 = pto.vlds %base[%off] : memref<256xf32> -> !pto.vreg<64xf32> -%x1 = pto.vlds %base[%off_plus_64] : memref<256xf32> -> !pto.vreg<64xf32> -%x2 = pto.vlds %base[%off_plus_128] : memref<256xf32> -> !pto.vreg<64xf32> -%x3 = pto.vlds %base[%off_plus_192] : memref<256xf32> -> !pto.vreg<64xf32> +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%off], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%off_plus_128], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 - : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 - : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> %s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> @@ -3991,6 +4033,14 @@ The common layout 
selected for `%x32` is `truncf f32 -> f8` and S=32 `group_reduce_addf`. A later strided block-load producer may introduce `block_elems = 8`, but that is a different case and requires an explicit materialization/rematerialization decision. + +When `%x32` is produced by a full contiguous `pto.vmi.load`, `vmi-to-vpto` +should not first materialize four contiguous f32 chunks and then run a full +four-op `vdintlv` tree. The load lowering should fold the first deinterleave +level into two `vldsx2 DINTLV_B32` operations and then run only the second +`vdintlv` level, as shown above. The layout remains just +`deinterleaved = 4, block_elems = 1`; it does not encode the fact that `vldsx2` +was used. ``` ### 3.33 One Dense Value Feeding S=16 And S=32 Reduces diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 7f10e39ea6..ea286520bf 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -3794,6 +3794,68 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } } + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 1) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 4 == 0) { + int64_t groups = resultTypes.size() / 4; + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type part0Type = resultTypes[group]; + Type part1Type = resultTypes[groups + group]; + Type part2Type = resultTypes[2 * groups + group]; + Type part3Type = resultTypes[3 * groups + group]; + if (part0Type != part1Type || part0Type != part2Type || + part0Type != part3Type) + return rewriter.notifyMatchFailure( + op, "vldsx2 deinterleaved=4 load requires matching part " + "types"); + + Value 
firstOffset = createChunkOffset( + op.getLoc(), *offset, group * 4 * *lanesPerPart, rewriter); + Value secondOffset = createChunkOffset( + op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, + rewriter); + auto first = + rewriter.create(op.getLoc(), part0Type, part1Type, + *source, firstOffset, + rewriter.getStringAttr(*dist)); + auto second = + rewriter.create(op.getLoc(), part2Type, part3Type, + *source, secondOffset, + rewriter.getStringAttr(*dist)); + + auto even = rewriter.create( + op.getLoc(), part0Type, part2Type, first.getLow(), + second.getLow()); + auto odd = rewriter.create( + op.getLoc(), part1Type, part3Type, first.getHigh(), + second.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + + SmallVector results; + results.reserve(resultTypes.size()); + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + SmallVector contiguousParts; contiguousParts.reserve(resultTypes.size()); for (auto [index, resultType] : llvm::enumerate(resultTypes)) { diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index b976ab518d..8dfe2292cf 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -48,8 +48,8 @@ module { // ASSIGN: pto.vmi.store %[[X8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( -// LOWER-COUNT-4: pto.vlds -// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-2: pto.vldsx2 +// LOWER-COUNT-2: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd // LOWER-COUNT-3: pto.vadd // LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_to_vpto_load_deint.pto b/test/lit/vmi/vmi_to_vpto_load_deint.pto index 715dacdfa6..0f3c3f825a 100644 --- 
a/test/lit/vmi/vmi_to_vpto_load_deint.pto +++ b/test/lit/vmi/vmi_to_vpto_load_deint.pto @@ -39,14 +39,10 @@ module { // CHECK-NOT: unrealized_conversion_cast // CHECK-LABEL: func.func @vmi_to_vpto_load_deint4( -// CHECK: %[[D0:.*]] = pto.vlds %arg0[%arg1] -// CHECK: %[[D1:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[D2:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[D3:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %[[D0]], %[[D1]] -// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %[[D2]], %[[D3]] -// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] -// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: %[[E0:.*]], %[[O0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1:.*]], %[[O1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[E0]], %[[E1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[O0]], %[[O1]] // CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
diff --git a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto index 433f222af3..200a1af04e 100644 --- a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto +++ b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto @@ -20,6 +20,28 @@ module { return %p0, %p1, %p2, %p3 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> } + + func.func @vmi_to_vpto_load_deint4_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 = + "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } } // CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_multichunk( @@ -29,3 +51,17 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. 
// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4_multichunk( +// CHECK: %[[E0_0:.*]], %[[O0_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1_0:.*]], %[[O1_0:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_0:.*]], %[[P2_0:.*]] = pto.vdintlv %[[E0_0]], %[[E1_0]] +// CHECK: %[[P1_0:.*]], %[[P3_0:.*]] = pto.vdintlv %[[O0_0]], %[[O1_0]] +// CHECK: %[[E0_1:.*]], %[[O0_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[E1_1:.*]], %[[O1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P2_1:.*]] = pto.vdintlv %[[E0_1]], %[[E1_1]] +// CHECK: %[[P1_1:.*]], %[[P3_1:.*]] = pto.vdintlv %[[O0_1]], %[[O1_1]] +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]], %[[P2_0]], %[[P2_1]], %[[P3_0]], %[[P3_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index dd69bcfaa2..c3a1a0fede 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -287,15 +287,11 @@ module { // CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, 
!pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto index c44de2ec84..01e92013cc 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -30,12 +30,8 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> 
!pto.vreg<256xf8E4M3FN> From 910a2a98cd6b313474147f7fe1e4ba51b6b0f7a1 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 17:31:33 +0800 Subject: [PATCH 25/31] Document VMI layout assignment mechanism --- docs/designs/vmi-introduction.md | 225 ++++++++++++++++++++++++++++--- 1 file changed, 203 insertions(+), 22 deletions(-) diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index fb1f1b7135..33852a69ce 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -206,40 +206,221 @@ pto-validate-vmi-ir 这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, 并在 layout 不匹配的 use site 插入显式 helper op。 -实现上它维护 data 和 mask 两套求解状态: +这个 pass 的工作顺序是固定的: ```text -data value: - 每个 !pto.vmi.vreg 是一个节点,节点记录最终选择的布局。 +1. 做少量 VMI 内部规整,让后续 layout 规则面对稳定形态。 +2. 为 data value 建 union-find 求解器,并收集 data 约束和 data use request。 +3. 把可采纳的 consumer request 提升为 producer/result 的最终 layout。 +4. 改写所有 data value type,让 !pto.vmi.vreg 携带具体 layout。 +5. 对仍不匹配的 data use 插入 pto.vmi.ensure_layout。 +6. 基于已经确定的 data layout 推导 mask layout 和 predicate granularity。 +7. 改写所有 mask type,并对不匹配的 mask use 插入 ensure_mask_*。 +8. 同步更新 function type、call boundary 和 block argument type。 +9. 
校验 layout-assigned VMI IR。 +``` + +Data 和 mask 分两轮求解。原因是 mask layout 通常依赖对应 data operand 或 result +的 layout;例如 `cmpf` 产生的 mask 跟比较输入的 data layout 对齐, +`select`/`reduce`/`masked_load` 消费的 mask 也要跟对应 data value 的 lane +layout 和元素 bitwidth 对齐。 + +Data 求解器为每个 `!pto.vmi.vreg` 建一个节点: -mask value: - 每个 !pto.vmi.mask 是一个节点,节点记录最终选择的布局和 predicate 粒度。 +```text +DataNode: + value = 对应 SSA value + original type = surface VMI type + parent = union-find parent + naturalLayout = 当前等价类选择的自然 layout,可能为空 ``` -data value 使用 union-find 表示“这些 value 必须共用 layout”。函数参数、 -call operand/result、return/yield、block argument、bitcast 等边界会把相关 -value 合并到同一个等价类里。等价类只能有一个最终 data layout。 +遍历 IR 时,每个 op 向 data 求解器贡献三类信息。 -assignment 遍历 IR 时,每类 op 向求解器贡献两种信息: +第一类是 layout 等价约束。它表示几个 value 必须使用同一个 physical layout, +也就是 union-find 中的同一个等价类。典型来源: ```text -result 自然布局: - 这个 op 自己产生的 result 适合用什么 layout 表达。 +layout-transparent elementwise: + addf/addi/subf/subi/mulf/muli/fma/divf/minf/maxf/... + L(operands...) = L(result) + +unary elementwise: + negf/absf/absi/sqrt/exp/ln/relu/not + L(source) = L(result) + +select: + L(true_value) = L(false_value) = L(result) + +bitcast: + L(source) = L(result) + +structured control flow: + scf.if result = then/else yield operand + scf.for result = init operand = iter_arg = yield operand + scf.while result = init/before/condition/after/yield carried value + +cf branch: + branch operand = destination block argument + +function boundary: + call operand = callee argument + call result = callee return operand + multiple returns of the same function agree per result index +``` + +这一步只说明“这些 value 如果存在布局,就必须一致”。它不等价于把某个 +consumer 的 request 无条件推过所有 producer 或控制流。 + +第二类是 result 自然布局。某些 op 的结果本身有目标相关的自然布局: + +```text +普通 reduce / compress / shuffle: + result 通常是 contiguous。 + +group_reduce: + source 需要适配 group reduce 指令形态; + result 使用 group_slots(num_groups, slots) 描述 sparse group result。 + +cast: + widening/narrowing 根据 cast support 决定 source request 和 result layout。 + +group_load / 
group_slot_load: + result 根据 group size、row stride 和目标能力选择 contiguous、deinterleaved + 或 group_slots。 + +active_prefix_index: + result 使用 contiguous。 +``` + +若同一个等价类已经有自然布局,再设置不同自然布局会报 layout contract 冲突。 + +第三类是 operand 使用请求。consumer 不直接修改 operand 的 type,而是记录 +“这个 use site 希望 operand 是什么 layout”: + +```text +store / tile_write / masked_store value: + wants contiguous + +ordinary reduce source/init: + wants contiguous + +group_reduce source: + wants preferred group-reduce source layout + +group_store value: + wants preferred group result layout + +truncf/trunci/extf/extsi/extui source: + wants cast support 给出的 source layout + +channel_split / channel_merge / shuffle: + wants 各自 lowering 需要的 source/input layout +``` + +收集完这些信息后,assignment 才尝试做 consumer-driven adoption。它逐个查看 +use request:如果 operand 的 producer 可以直接用 consumer 需要的 layout 产生 +同一个逻辑向量,并且多 use 时所有 use 都请求同一个 layout,那么这个 request +会被提升为该 value 所在 data 等价类的最终 layout。 + +可采纳 producer 是受限集合: + +```text +load / tile_read +broadcast / constant / iota +layout-transparent elementwise +select +bitcast +``` + +这就是 request 看起来能穿过 elemwise 的原因: + +```mlir +%x = pto.vmi.load ... +%k = pto.vmi.broadcast ... 
+%y = pto.vmi.mulf %x, %k +%q = pto.vmi.truncf %y +``` + +`mulf` 先把 `%x`、`%k`、`%y` 合成同一个 data 等价类。`truncf` 对 `%y` +的 source use 请求 `deinterleaved=4` 时,这个 request 作用到 `%y` 所在等价类; +因为 `mulf` 是可采纳 producer,assignment 可以把整个等价类选成 +`deinterleaved=4`,从而让 load/broadcast/mulf 直接在这个 layout 下产生数据。 + +控制流边界也会形成等价类,但它不是任意 request 的自动传播通道: + +```mlir +%y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a +} else { + scf.yield %b +} +%q = pto.vmi.truncf %y +``` + +`%y`、`%a`、`%b` 的 layout 必须一致;但 `scf.if` result 本身不是 +consumer-driven adoption 的可采纳 producer。若 `%q` 需要的 layout 无法成为 +这个等价类的最终布局,assignment 会在 `%q` 的 use site 插 +`pto.vmi.ensure_layout`,而不是隐式重写两个 branch 的内部计算。 + +Data layout 确定后,pass 会把每个 `!pto.vmi.vreg` 改写成 +`!pto.vmi.vreg`。如果某个记录过的 use request 仍然和 operand +当前 layout 不一致,pass 在该 consumer 前插显式 materialization: + +```mlir +%x_req = pto.vmi.ensure_layout %x + : !pto.vmi.vreg + -> !pto.vmi.vreg +consumer %x_req +``` + +这个规则也处理多 consumer 冲突: + +```mlir +%y = pto.vmi.mulf %x, %k +pto.vmi.store %y, %out0 // wants contiguous +%q = pto.vmi.truncf %y // wants deinterleaved=4 source +``` + +一个 SSA value 只能属于一个 data layout 等价类。若两个 use 不能共同满足, +baseline assignment 保留一个等价类 layout,并在不匹配 use 前插 +`ensure_layout`。后续 `vmi-layout-fold-consumers`、`vmi-layout-rematerialize` +和 `vmi-layout-sink-materialization` 可以在显式 helper op 上做优化,但 +`vmi-to-vpto` 不读取隐藏 plan 或 sibling user。 + +Mask 求解发生在 data type 改写之后。它同样维护 union-find 等价类,但节点记录 +两件事: + +```text +mask layout +predicate granularity: b8 / b16 / b32 +``` + +mask request 从已经带 layout 的 data value 推导: + +```text +cmpf/cmpi result: + mask layout = lhs data layout + granularity = lhs element bitwidth 对应的 predicate 粒度 + +select mask: + mask layout = result data layout + granularity = result element bitwidth 对应的 predicate 粒度 -operand 使用请求: - 这个 op 消费某个 operand 时希望 operand 是什么 layout。 +reduce / group_reduce / masked_load / expand_load mask: + mask layout = source/result data layout + granularity = 对应 data element bitwidth 的 predicate 粒度 ``` -有些 
producer 生成的是同一个逻辑向量,但可以用多种物理 layout 表达。若它的 -所有 consumer 给出的使用请求一致,assignment 会把这个请求反推为 producer -result 的最终布局。否则,producer 保持自己的布局,assignment 在不匹配的 use -site 插入 `pto.vmi.ensure_layout`。mask 使用同样思路,但还会同时求解 predicate -粒度,必要时插入 `ensure_mask_layout` 或 `ensure_mask_granularity`。 +若 mask use 的 layout 或 granularity 不匹配,pass 显式插 +`pto.vmi.ensure_mask_layout` 或 `pto.vmi.ensure_mask_granularity`。 -最后,pass 会把所有 VMI data/mask type 改写成带 layout 的 type,并同步更新 -function type、call site、block argument 和 terminator operand。这个阶段之后, -IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type 上的 layout -和显式 `ensure_*` helper。 +完成 data/mask 改写和 helper 插入后,pass 会同步更新 function type。直接 +internal call 会把 call operand/result 与 callee argument/return operand 合成 +同一布局约束;带 VMI type 的 external declaration 或 indirect call 没有可见 +body,当前需要显式 ABI materialization 设计,因此 layout assignment 会拒绝。 +这个阶段之后,IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type +上的 layout 和显式 `ensure_*` helper。 ### 3.3 `vmi-layout-fold-consumers` From 31abc145f8a815add7dbda65c3a3571b938c3cb9 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 18:23:51 +0800 Subject: [PATCH 26/31] Illustrate VMI layout equivalence classes --- docs/designs/vmi-introduction.md | 85 ++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 33852a69ce..dbf09230f1 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -272,6 +272,91 @@ function boundary: 这一步只说明“这些 value 如果存在布局,就必须一致”。它不等价于把某个 consumer 的 request 无条件推过所有 producer 或控制流。 +等价类可以画成“同一个框里的 value 共用一个 layout 变量”。例如普通 +elementwise 链: + +```text +surface VMI: + + %x = pto.vmi.load ... + %k = pto.vmi.broadcast ... 
+ %y = pto.vmi.mulf %x, %k + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C0 + +--------------------------------------+ + | %x %k %y | + | load broadcast mulf result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +若 %y 的 producer chain 可采纳该 request,assignment 可以选择: + + L(C0) = deinterleaved=4 +``` + +控制流 join 也是等价类,但 request adoption 的含义不同: + +```text +surface VMI: + + %y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a + } else { + scf.yield %b + } + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C1 + +--------------------------------------+ + | %a %b %y | + | then yield else yield if result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +scf.if result 不是 consumer-driven adoption 的可采纳 producer。 +若 C1 不能直接选择 deinterleaved=4,assignment 保持 C1 的布局, +并在 use site materialize: + + %y_for_q = pto.vmi.ensure_layout %y : L(C1) -> deinterleaved=4 + %q = pto.vmi.truncf %y_for_q +``` + +多 consumer 冲突时,等价类仍然只有一个 layout: + +```text +surface VMI: + + %y = pto.vmi.mulf %x, %k + pto.vmi.store %y, %out0 + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C2 + +-----------------------------+ + | %x %k %y | + +-----------------------------+ + |\ + | \ use request from truncf: deinterleaved=4 + | + +--- use request from store: contiguous + +两个 use request 不一致时,不能让 %y 同时拥有两个 layout。 +baseline assignment 保留 C2 已有的 natural layout;若没有 natural layout, +则使用默认 contiguous。与该 layout 不匹配的 edge 会插 ensure_layout。 +``` + 第二类是 result 自然布局。某些 op 的结果本身有目标相关的自然布局: ```text From f5c27d4dc6183d9a20d0a6972f375263dd29f28b Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 12:53:10 +0800 Subject: [PATCH 27/31] Add VMI histogram lowering support --- docs/designs/vmi-implementation-manual.md | 117 +++++++++++ docs/designs/vmi-introduction.md | 53 +++++ .../vmi-layout-assignment-implementation.md | 
97 +++++++-- .../vmi-layout-assignment-lowering-design.md | 46 +++++ docs/designs/vmi-layout-lowering-cases.md | 194 ++++++++++++++++++ include/PTO/IR/VMIOps.td | 17 ++ include/PTO/Transforms/VMILayoutSupport.h | 14 ++ lib/PTO/IR/VMI.cpp | 51 +++++ lib/PTO/Transforms/PTOValidateVMIIR.cpp | 18 ++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 22 ++ lib/PTO/Transforms/VMILayoutSupport.cpp | 67 ++++++ lib/PTO/Transforms/VMIToVPTO.cpp | 110 +++++++++- test/lit/vmi/vmi_layout_assignment_dhist.pto | 37 ++++ .../vmi_to_vpto_chist_semantics_invalid.pto | 27 +++ test/lit/vmi/vmi_to_vpto_dhist.pto | 41 ++++ test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto | 49 +++++ .../vmi/dhist-tail-mask-store/compare.py | 36 ++++ .../cases/vmi/dhist-tail-mask-store/golden.py | 44 ++++ .../vmi/dhist-tail-mask-store/kernel.pto | 56 +++++ .../vmi/dhist-tail-mask-store/launch.cpp | 33 +++ .../cases/vmi/dhist-tail-mask-store/main.cpp | 94 +++++++++ .../vmi/dhist-tail-mask-store/ptoas.flags | 1 + 22 files changed, 1211 insertions(+), 13 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_dhist.pto create mode 100644 test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_dhist.pto create mode 100644 test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/compare.py create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/golden.py create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 6bb7a7e0fe..2cd72208a6 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -1905,6 +1905,9 @@ layout-producing conversion: 
externally ordered memory: load, store, tile_read, tile_write + +value-indexed accumulation: + dhist, chist ``` Per-part elementwise ops are straightforward only when all operands/results already share the same assigned layout: @@ -2219,6 +2222,17 @@ pto.vmi.tile_read pto.vmi.tile_write ``` +Value-indexed accumulation: + +```text +pto.vmi.dhist +pto.vmi.chist +``` + +`pto.vmi.dhist` is a first-stage semantic op when histogram support is enabled. +`pto.vmi.chist` may share the surface verifier, but its final lowering must be +gated until the target CHISTv2 high-range cumulative semantics are verified. + Current implementation scope note: ```text @@ -2299,6 +2313,18 @@ load/tile_read memory element type must match result VMI data element type when store/tile_write memory element type must match stored VMI data element type when the destination is PtrType or MemRefType ``` +Histogram op verifier: + +```text +dhist/chist acc type must be !pto.vmi.vreg<256xui16> +dhist/chist result type must match acc type +source type must be !pto.vmi.vreg<Nxui8> +mask logical lane count must match source logical lane count +surface mask may be pred; after layout assignment it must be b8 contiguous +source/result/acc must not carry layout before vmi-layout-assignment +layout-assigned dhist/chist requires contiguous source, mask, acc, and result +``` + `shuffle` verifier: ```text @@ -3833,6 +3859,87 @@ vmi.tile_read / vmi.tile_write, current direct full-footprint path: any path that would expose padding lanes or reorder externally visible memory ``` +Histogram lowering: + +```text +vmi.dhist semantics: + source lanes are ui8 samples + mask selects active source lanes + acc/result are complete logical 256-bin ui16 histograms + result[b] = acc[b] + count(active source lanes whose value equals b) + +layout assignment: + source layout = contiguous + mask layout = contiguous, granularity b8 + acc/result layout = contiguous !pto.vmi.vreg<256xui16> + +physicalization: + acc/result physical arity is 2
because 256xui16 is 512B + part0 represents logical bins 0..127 + part1 represents logical bins 128..255 +``` + +`vmi-to-vpto` lowering for `pto.vmi.dhist` is local and deterministic from the +op and assigned types: + +```text +lo = converted acc part0 +hi = converted acc part1 + +for each converted source physical chunk c in logical order: + chunk_mask = converted b8 mask chunk c + + if source chunk c contains padding lanes because N is not a multiple of 256: + valid = pto.pge/plt_b8 prefix mask for the valid logical lanes in this chunk + chunk_mask = pto.pand chunk_mask, valid + + lo = pto.dhistv2 lo, src_c, chunk_mask, #bin=0 + hi = pto.dhistv2 hi, src_c, chunk_mask, #bin=1 + +return physical result parts [lo, hi] +``` + +Required preflight: + +```text +acc/result element type is ui16 and logical lane count is exactly 256 +source element type is ui8 +source and mask logical lane counts match +source/mask are contiguous +mask granularity is b8 +source physical chunks are 256-lane ui8 chunks; final partial chunk is allowed +only when the lowering can construct the valid-lane prefix mask +``` + +Diagnostics: + +```text +VMI-UNSUPPORTED: pto.vmi.dhist requires contiguous ui8 source, b8 mask, and +contiguous 256xui16 accumulator/result + +VMI-UNSUPPORTED: pto.vmi.dhist final partial source chunk requires valid-lane +b8 mask materialization +``` + +`pto.vmi.chist` has the same verifier and assignment requirements, but final +lowering is capability-gated: + +```text +if CHISTv2 high-range semantics are verified as global cumulative: + replace the two pto.dhistv2 calls above with pto.chistv2 calls + +elif CHISTv2 high-range semantics are verified as range-local cumulative: + lower low/high pto.chistv2 and add the low-half total count to every high-half bin, + but only after low-total materialization and broadcast support is explicit + +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +Do not classify histogram as 
`group_reduce`. Its result location is selected +by source values, not by lane/group position, and its low/high split is caused +by the physical `128xui16` VPTO result width. + Final hard gate: ```text @@ -3895,6 +4002,12 @@ Slice 4 完成条件: 11. Same-family mask logic ops lower through the physical mask granularity instead of assuming b32 masks. Covered by vmi_to_vpto_mask_logic.pto for mask_and/mask_or/mask_xor/mask_not on b32 masks produced by cmpf and on direct b8/b16 mask operands. +12. `pto.vmi.dhist` lowers one logical 256-bin histogram into two VPTO low/high + bin-range histogram accumulator chains, and tail source chunks are masked + with a valid-lane b8 prefix. `pto.vmi.chist` is rejected until the target + CHISTv2 cumulative range semantics are classified. + Covered by vmi_to_vpto_dhist.pto, vmi_to_vpto_dhist_tail_mask.pto, and + vmi_to_vpto_chist_semantics_invalid.pto. ``` ## 7. Slice 5: Tile Memory And Padding @@ -4075,6 +4188,7 @@ currently routed through the registry: supported source/result layout conversion pairs supported b8/b16/b32 mask granularity conversion pairs pto.vmi.channel_split/channel_merge supported channel count + pto.vmi.dhist direct target support and pto.vmi.chist cumulative range semantics classification still legacy helper-based and should migrate into the registry as follow-up: full layout materialization plans and padding-safety checks @@ -4132,6 +4246,9 @@ vmi_to_vpto_deinterleaved2.mlir vmi_to_vpto_deinterleaved4.mlir vmi_to_vpto_compaction_deint_invalid.mlir vmi_to_vpto_non_full_tile.mlir +vmi_to_vpto_dhist.mlir +vmi_to_vpto_dhist_tail_mask.mlir +vmi_to_vpto_chist_semantics_invalid.mlir vmi_tile_read_padding.mlir vmi_tile_write_mask.mlir vmi_pipeline_hard_gates.mlir diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index dbf09230f1..e7161dc4a0 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -11,6 +11,10 @@ VMI 是 VPTO 之前的逻辑向量层。它让前端先表达“我要对 `NxT` 的逻辑向量做什么”, 再由 
layout assignment 决定这个逻辑向量如何拆到 256B 物理 vector register 上。 +当 VPTO 指令因为物理 register 宽度只能暴露半宽接口时,VMI 也负责提供完整的 +逻辑语义。例如 `ui8` histogram 的完整结果是 `256xui16`,物理 VPTO histogram +一次只能返回 `128xui16`;VMI surface 应该表达完整 histogram,low/high bin +range 拆分属于 lowering 细节。 Surface VMI 类型不携带布局: @@ -892,6 +896,55 @@ scf.if -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) 这个场景说明为什么 layout 应该是 type 的一部分,而不是依赖 defining op。 +### 4.8 完整 Histogram 语义 + +VPTO 的 histogram 指令一次读取 `256xui8` source,但结果只能写 +`128xui16` accumulator。完整 `ui8` histogram 有 256 个 bin,因此物理 VPTO +接口需要通过 `#bin = 0/1` 分两次统计低半区和高半区。 + +VMI surface 不暴露这个物理 split: + +```mlir +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, + !pto.vmi.vreg<Nxui8>, + !pto.vmi.mask<Nxpred> + -> !pto.vmi.vreg<256xui16> +``` + +语义是完整 256-bin distribution histogram: + +```text +for b = 0..255: + hist[b] = acc[b] + count(i where mask[i] && src[i] == b) +``` + +Assignment 形状: + +```text +src/mask = contiguous, b8 mask granularity +acc/result = contiguous 256xui16 logical value +``` + +VPTO 形状: + +```text +acc/result part0 = bins 0..127 +acc/result part1 = bins 128..255 + +for each 256-lane source chunk: + part0 = dhistv2(part0, src_chunk, mask_chunk, #bin=0) + part1 = dhistv2(part1, src_chunk, mask_chunk, #bin=1) +``` + +这说明 VMI 的易用性不只来自 layout assignment。对于这种 value-indexed +accumulation,VMI 还应该隐藏 VPTO 为了物理 vreg 宽度暴露出来的 range +selector、lo/hi accumulator 和多条物理指令。 + +`pto.vmi.chist` 可以使用相同 surface 形状,但当前必须先验证 VPTO `CHISTv2` +在 high range 上返回的是全局累计还是 range-local 累计。这个差异会影响是否需要 +额外给 high half 加上 low half 的总计数,因此不能只按 op 名字猜 lowering。 + ## 5.
当前边界 当前设计方向: diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index e1fa19cc7e..1162818712 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -104,7 +104,7 @@ pto-validate-vmi-layout-ir: `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, - `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. + `pto.vmi.extf`, `pto.vmi.bitcast`, and histogram family ops at the layout gate. vmi-to-vpto: use OneToN type conversion @@ -302,6 +302,7 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths +dhist/chist acc/source/mask/result layouts and target capability ensure_layout always carries source/result layouts ensure_mask_layout always carries source/result layouts ensure_mask_granularity always carries source/result granularities @@ -373,6 +374,8 @@ group_reduce_addf group_reduce_addi group_broadcast group_store +dhist +chist ensure_layout // internal ensure_mask_layout // internal @@ -444,6 +447,13 @@ group_reduce layout fact: Example: S=2*VLaneElems means deinterleaved=2 source/mask and group_slots(G, slots=8) result in every stage. +histogram layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: dhist requires contiguous Nxui8 source, contiguous b8 mask, and + contiguous 256xui16 acc/result. chist uses the same layout fact but also + requires a target capability that classifies CHISTv2 cumulative range + semantics. + layout materialization support: shared by layout validation, vmi-to-vpto, and helper-based optimizations. 
Example: ensure_layout from deinterleaved=2 f32 to contiguous f32 is the same @@ -670,6 +680,15 @@ ordinary store: group_store: source request group_slots(G,K) + +dhist: + acc/result request contiguous 256xui16 + source request contiguous Nxui8 + mask request contiguous b8 + +chist: + same layout requests as dhist + diagnostic unless CHISTv2 cumulative range semantics are classified ``` Baseline assignment does not perform consumer-driven adoption for performance. @@ -679,8 +698,8 @@ use. ```text natural layout producer: - extf/truncf, group_reduce, group_slot_load, group_load when the op itself - carries a layout-producing contract + extf/truncf, group_reduce, group_slot_load, group_load, dhist/chist when the + op itself carries a layout-producing contract layout equality producer: dense add/mul/select and CFG-carried values tie operands/results but do not @@ -754,6 +773,14 @@ buildMaskRequests: masked_store requests source layout, mask layout, and store predicate granularity explicitly +buildHistogramRequests: + dhist -> acc/result contiguous 256xui16, source contiguous Nxui8, + mask contiguous b8 + chist -> same layout requests, plus target capability diagnostic until + CHISTv2 high-range semantics are classified + do not create group_slots or group_reduce requests; histogram result bins are + selected by source values, not by lane/group position + buildControlFlowRequests: region yields, branch operands, loop iter_args, call operands, and returns create equality requests on the carried VMI layout variable @@ -797,6 +824,7 @@ fixed-layout producers: extf/truncf physical conversion layouts group_load block-fragment layouts group_reduce result group_slots + dhist/chist result contiguous 256xui16 and source/mask contiguous b8 contract masked_load when the physical memory-safety proof fixes a full-read lowering ``` @@ -955,6 +983,20 @@ vmi-to-vpto contract: dynamic mask generation. 
``` +```text +case family builder / owner assignment artifact +3.56 full distribution hist buildHistogramRequests contiguous src/mask/acc/result +3.57 cumulative hist boundary buildHistogramRequests capability diagnostic or classified path + +vmi-to-vpto contract: + lower dhist from the current op and assigned layouts by carrying two physical + accumulator parts for bins 0..127 and 128..255. It must not expose the VPTO + #bin range selector on the VMI surface and must not model histogram as + group_reduce. chist remains rejected until the target records whether the + high-range cumulative result is global or range-local and, for range-local + behavior, until low-total materialization is explicit. +``` + ```text case family builder / owner assignment artifact 3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load layout @@ -1170,6 +1212,19 @@ group_reduce_add{f|i}, lowering=full_chunk_reduce_row_local: the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract +dhist, lowering=full_256bin_histogram: + consumes contiguous Nxui8 source and contiguous b8 mask + consumes/produces contiguous 256xui16 accumulator/result + physical result parts are [bins 0..127, bins 128..255] + emits one low-range and one high-range histogram update for each 256-lane + source chunk + final partial source chunks require an explicit valid-lane b8 mask + +chist, lowering=capability_gated_cumulative_histogram: + uses the same layout shape as dhist + rejects until target capability classifies CHISTv2 high-range cumulative + semantics and any required low-total correction materialization is explicit + group_slot_load, lowering=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 @@ -1558,6 +1613,11 @@ strided/group-slot memory: function/control-flow: 3.12, 3.20, 3.22, 3.25.1, 3.42, 3.43 + +histogram: + 3.56 positive dhist layout/lowering and simulator case when backend 
support + is enabled + 3.57 diagnostic chist case until CHISTv2 range semantics are classified ``` Aggregate catalog headings are covered through their endpoint subcases: @@ -1607,6 +1667,8 @@ repository evidence: golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=47 FAIL=0 latest full VMI lit sweep passed: 350/350 + this historical sweep predates 3.56/3.57; histogram endpoints require new + lit/SIM or diagnostic tests before they can be counted as implemented ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -2208,23 +2270,34 @@ internal function argument boundary materialization public ABI diagnostic ``` +### Slice 7: Histogram + +```text +3.56 full 256-bin dhist logical op +3.57 chist semantic capability diagnostic +``` + ## 13. Completion Checklist Current evidence for the case-catalog objective: ```text -1. every catalog endpoint is mapped in section 6.6 to an assignment owner, - assignment artifact, and vmi-to-vpto contract -2. every SIM-backed positive endpoint is listed in section 11.3 and has a - checked-in runtime case directory -3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, - golden.py, and compare.py -4. the latest broad VMI runtime sweep passed: PASS=47 FAIL=0 -5. the latest full VMI lit sweep passed: 350/350 -6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test +1. every pre-histogram catalog endpoint is mapped in section 6.6 to an + assignment owner, assignment artifact, and vmi-to-vpto contract +2. every pre-histogram SIM-backed positive endpoint is listed in section 11.3 + and has a checked-in runtime case directory +3. every existing runtime case directory contains kernel.pto, launch.cpp, + main.cpp, golden.py, and compare.py +4. the latest historical broad VMI runtime sweep passed: PASS=47 FAIL=0 +5. the latest historical full VMI lit sweep passed: 350/350 +6. every pre-histogram unsupported endpoint listed in section 11.3 has a + diagnostic lit test 7. 
vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics 8. no separate lowering-plan string attr is emitted or consumed 9. release docs remain untouched; this is still a design/implementation plan under docs/designs +10. new histogram endpoints 3.56/3.57 are mapped in section 6.6, but their + implementation evidence is intentionally pending new lit/SIM or diagnostic + tests ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 42e62e8b3a..00f69aae05 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -148,6 +148,11 @@ strided memory: strided group_load feeding broadcast and a second group_reduce group_slot_load slots=1 with non-unit source stride group_store slots=1 with non-unit output stride + +value-indexed accumulation: + full 256-bin distribution histogram over Nxui8 source lanes + VPTO low/high bin range split hidden behind one logical 256xui16 VMI result + cumulative histogram is a semantic boundary until CHISTv2 range semantics are verified ``` ### 1.1 Case-Set Sufficiency @@ -184,6 +189,10 @@ control-flow propagation: memory legality: full_tile_readable proof, grouped masks, predicate granularity, aligned strided group memory, stable gather diagnostic + +value-indexed accumulation: + histogram source/result shape, b8 source mask, and fixed low/high VPTO bin + split for a logical 256-bin result ``` No extra layout kind should be added unless a new case proves that the existing @@ -235,6 +244,13 @@ compute boundary: storage must be widened first because integer reduction instructions widen narrow inputs. f8/i8 are not baseline accumulator/compute element types. + +value-indexed accumulation boundary: + pto.vmi.dhist consumes ui8 source lanes and produces a logical 256xui16 + accumulator/result. 
It is not a group_reduce family member because result + bins are selected by source values rather than by source lane/group position. + pto.vmi.chist uses the same surface shape only after the target CHISTv2 + range semantics are verified. ``` ### 2.1 Dense Layouts @@ -298,6 +314,22 @@ S=8/16/32 packed VCG result -> slots=8 S=64 row-local result -> slots=1 ``` +Histogram does not add a layout family. A full logical histogram result uses: + +```text +!pto.vmi.vreg<256xui16, #pto.vmi.layout<contiguous>> +``` + +and physicalizes to two ordered VPTO parts: + +```text +part0 = logical bins 0..127 +part1 = logical bins 128..255 +``` + +The VPTO `#bin` selector is therefore an op-local lowering detail, not a VMI +layout attribute and not a user-visible operand on `pto.vmi.dhist`. + ## 3. Lowering Context Must Become Explicit IR Output `vmi-to-vpto` may inspect only: @@ -670,6 +702,20 @@ create_mask/create_group_mask: incompatible mask consumers are represented by ensure_mask_layout or ensure_mask_granularity; optimization may clone/rematerialize the mask op +dhist: + requests acc/result contiguous !pto.vmi.vreg<256xui16> + requests source contiguous !pto.vmi.vreg<Nxui8> + requests mask contiguous with b8 granularity + lowers each 256-lane source chunk by carrying two accumulator parts: + bins 0..127 use VPTO histogram #bin=0, bins 128..255 use #bin=1 + final partial source chunks are represented by AND-ing the user mask with a + valid-lane prefix mask before the VPTO histogram op + +chist: + same layout requests as dhist + baseline lowering is disabled until target capability records whether the + high-range VPTO cumulative result is global or range-local + scf.if/scf.for/call/return: requests equality across carried VMI values, yielded values, call operands, callee arguments, and function results diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 855a7a486f..efb2a7c502 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++
b/docs/designs/vmi-layout-lowering-cases.md @@ -246,6 +246,9 @@ the immediately following complete endpoints. 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete 3.46 extf value and derived elemwise value both stored complete/optimization +3.47-3.55 typed group-reduce generalization complete/diagnostic +3.56 full 256-bin distribution histogram complete +3.57 full 256-bin cumulative histogram design boundary ``` ### 3.1 `f16 -> f32 -> store` @@ -6235,3 +6238,194 @@ pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} That packed group-slot `trunci` path is not baseline lowering support yet; the implementation must either define slot-wise VCVTII lowering support or diagnose at layout assignment. + +### 3.56 Full 256-Bin Distribution Histogram + +Histogram is not modeled as `group_reduce`. A group reduce maps source lanes to +result slots by lane/group position. A histogram maps each active source lane +to a result bin by the source value itself. 
+ +VMI-shaped input: + +```text +%src = pto.vmi.load %src_base[%src_off] + : memref -> !pto.vmi.vreg +%mask = pto.vmi.create_mask %active_lanes + : index -> !pto.vmi.mask +%acc = pto.vmi.load %acc_base[%acc_off] + : memref<256xui16> -> !pto.vmi.vreg<256xui16> +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg, + !pto.vmi.mask -> !pto.vmi.vreg<256xui16> +pto.vmi.store %hist, %out[%out_off] +``` + +Logical semantics: + +```text +for b = 0..255: + hist[b] = acc[b] + +for i = 0..N-1: + if mask[i]: + hist[src[i]] += 1 +``` + +Assigned layouts: + +```text +%src: + !pto.vmi.vreg> + +%mask: + !pto.vmi.mask> + +%acc, %hist: + !pto.vmi.vreg<256xui16, #pto.vmi.layout> +``` + +The `256xui16` accumulator/result is one logical VMI value but two physical +VPTO vector registers: + +```text +physical result part0 = logical bins 0..127 +physical result part1 = logical bins 128..255 +``` + +For `N = 256`, VPTO lowering shape: + +```text +%src0 = pto.vlds %src_base[%src_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xui8> + +%acc_lo = pto.vlds %acc_base[%acc_off + 0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> +%acc_hi = pto.vlds %acc_base[%acc_off + 128] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> + +%hist_lo = pto.dhistv2 %acc_lo, %src0, %mask0, %bin0 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> +%hist_hi = pto.dhistv2 %acc_hi, %src0, %mask0, %bin1 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> + +pto.vsts %hist_lo, %out[%out_off + 0], %all_b16 {dist = "NORM_B16"} +pto.vsts %hist_hi, %out[%out_off + 128], %all_b16 {dist = "NORM_B16"} +``` + +Memory result: + +```text +for b = 0..127: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) + +for b = 128..255: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) +``` + +For `N > 256`, the source is processed in 
contiguous 256-lane chunks. The two +histogram accumulator parts are carried through all chunks: + +```text +%lo = %acc_lo +%hi = %acc_hi + +for source chunk c in logical order: + %chunk_mask = mask chunk c + if c is the final partial chunk: + %chunk_mask = %chunk_mask & valid-lane-prefix-for-this-chunk + + %lo = pto.dhistv2 %lo, %src_c, %chunk_mask, %bin0 + %hi = pto.dhistv2 %hi, %src_c, %chunk_mask, %bin1 + +result physical parts = [%lo, %hi] +``` + +Tail source lanes are expressed only through the b8 mask. Padding lanes in the +last physical source chunk must be masked off before `pto.dhistv2`: an unmasked +padding lane would still be counted into the bin selected by its value. + +The VMI op does not expose `#bin`. `#bin` is a VPTO range selector forced by +the physical result width: + +```text +ui8 value domain = 256 bins +complete histogram = 256 x ui16 = 512B +one VPTO vreg result = 128 x ui16 = 256B +``` + +Therefore VMI represents one logical `256xui16` result and `vmi-to-vpto` +locally emits the low-range and high-range VPTO histogram updates. + +### 3.57 Full 256-Bin Cumulative Histogram + +The desired VMI surface shape mirrors `dhist`: + +```text +%hist = pto.vmi.chist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<Nxui8>, + !pto.vmi.mask<Nxpred> -> !pto.vmi.vreg<256xui16> +``` + +The intended logical semantics is a full cumulative histogram: + +```text +dist[b] = count(i where mask[i] && src[i] == b) + +hist[0] = acc[0] + dist[0] +for b = 1..255: + hist[b] = acc[b] + dist[0] + dist[1] + ... + dist[b] +``` + +The current VPTO/VISA documentation only states that `CHISTv2` computes a +`uint16 Cumulative histogram` over the selected bin range. It does not state +whether the high-range call with `#bin = 1` returns: + +```text +global cumulative: + result[j] = count(src <= 128 + j) + +or range-local cumulative: + result[j] = count(128 <= src <= 128 + j) +``` + +These two interpretations have different VMI lowerings.
If the hardware result +is global cumulative, the full VMI lowering is the same low/high split as +`dhist`, replacing `pto.dhistv2` with `pto.chistv2`. If the hardware result is +range-local cumulative, the high half also needs the total low-half count added +to every high-half bin: + +```text +%lo = pto.chistv2 %acc_lo, %src0, %mask0, %bin0 +%hi_local = pto.chistv2 %acc_hi, %src0, %mask0, %bin1 + +%low_total = materialize count(src <= 127) from the low-half result +%low_total_vec = broadcast %low_total to every high-half bin +%hi = pto.vadd %hi_local, %low_total_vec, %all_b16 +``` + +That correction path also requires a designed way to materialize and broadcast +the low-half total. Since baseline VMI does not support arbitrary vector +extract, the range-local CHISTv2 interpretation remains unsupported until that +materialization path is explicit. + +The baseline design therefore treats `pto.vmi.chist` as a semantic op whose +exact lowering is gated by a target semantic capability: + +```text +if target documents or validation proves CHISTv2 high range is global: + lower as two pto.chistv2 calls +elif target documents or validation proves CHISTv2 high range is range-local: + lower as pto.chistv2 low/high plus explicit high-half correction only after + low-total materialization support is designed +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +This boundary is deliberate. `pto.vmi.dhist` is fully defined because +distribution bins are independent across the low/high split. `pto.vmi.chist` +has cross-range prefix semantics, so VMI must not guess the high-half behavior +from the VPTO op name alone. 
diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index d14b6fe8ee..98083fb687 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -437,6 +437,23 @@ def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } +class VMIHistogramOp + : VMI_Op { + let summary = summaryText; + let arguments = (ins VMI_VRegTypeConstraint:$acc, + VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$acc `,` $source `,` $mask attr-dict `:` type($acc) `,` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIDhistOp : VMIHistogramOp<"dhist", + "VMI full 256-bin distribution histogram over unsigned 8-bit source lanes">; + +def VMIChistOp : VMIHistogramOp<"chist", + "VMI full 256-bin cumulative histogram over unsigned 8-bit source lanes">; + def VMIExtFOp : VMI_Op<"extf"> { let summary = "VMI floating-point elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 9a274a2a9b..41b686b322 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -188,6 +188,14 @@ struct VMIBitcastSupport { VMIBitcastSupportKind kind = VMIBitcastSupportKind::PerPartVbitcast; }; +enum class VMIHistogramSupportKind { + Full256BinDhist, +}; + +struct VMIHistogramSupport { + VMIHistogramSupportKind kind = VMIHistogramSupportKind::Full256BinDhist; +}; + class VMILayoutSupport { public: FailureOr @@ -280,6 +288,12 @@ class VMILayoutSupport { FailureOr getBitcastSupport(VMIBitcastOp op, std::string *reason = nullptr) const; + + FailureOr + getDhistSupport(VMIDhistOp op, std::string *reason = nullptr) const; + + FailureOr + getChistSupport(VMIChistOp op, std::string *reason = nullptr) const; }; } // namespace mlir::pto 
diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index d3d2dc6b14..edd55a2bb3 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1231,6 +1231,57 @@ LogicalResult VMIGroupBroadcastOp::verify() { getNumGroupsAttr().getInt()); } +template +static LogicalResult verifyVMIHistogramOp(OpTy op) { + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + auto accElemType = dyn_cast(accType.getElementType()); + auto sourceElemType = dyn_cast(sourceType.getElementType()); + if (!accElemType || !accElemType.isUnsigned() || + accElemType.getWidth() != 16 || accType.getElementCount() != 256) + return op.emitOpError("requires acc type to be " + "!pto.vmi.vreg<256xui16>"); + if (resultType != accType) + return op.emitOpError("requires result type to match acc type"); + if (!sourceElemType || !sourceElemType.isUnsigned() || + sourceElemType.getWidth() != 8) + return op.emitOpError("requires source type to be " + "!pto.vmi.vreg"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return op.emitOpError("requires mask logical lane count to match source"); + + if (auto accLayout = accType.getLayoutAttr()) { + if (!accLayout.isContiguous()) + return op.emitOpError("requires layout-assigned acc to use contiguous " + "layout"); + } + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isContiguous()) + return op.emitOpError("requires layout-assigned source to use contiguous " + "layout"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isContiguous()) + return op.emitOpError("requires layout-assigned result to use " + "contiguous layout"); + } + if (auto maskLayout = maskType.getLayoutAttr()) { + if (!maskLayout.isContiguous()) + return op.emitOpError("requires layout-assigned mask to use contiguous " + "layout"); + if (maskType.getGranularity() != "b8") + return 
op.emitOpError("requires layout-assigned mask granularity b8"); + } + return success(); +} + +LogicalResult VMIDhistOp::verify() { return verifyVMIHistogramOp(*this); } + +LogicalResult VMIChistOp::verify() { return verifyVMIHistogramOp(*this); } + LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6fdf6acf07..1186a25e26 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -627,6 +627,24 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getDhistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.dhist has no registered histogram support", + reason); + return success(); + } + + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getChistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.chist has no registered histogram support", + reason); + return success(); + } + if (auto truncf = dyn_cast(op)) { std::string reason; if (failed(supports.getTruncFSupport(truncf, &reason))) diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index eb3593c9ee..99e4314cf9 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -892,6 +892,28 @@ struct LayoutSolver { sourceType, broadcast.getNumGroupsAttr().getInt())); return WalkResult::advance(); } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if 
(failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 27a994ba55..acb687eed0 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -1344,3 +1344,70 @@ VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, return VMIBitcastSupport{VMIBitcastSupportKind::PerPartVbitcast}; } + +template +static FailureOr +getHistogramSupportImpl(OpTy op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + VMILayoutAttr accLayout = accType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!accLayout || !sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned acc/source/mask/result layouts"); + if (!accLayout.isContiguous() || !sourceLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return 
fail("requires contiguous acc, source, mask, and result layouts"); + if (maskType.getGranularity() != "b8") + return fail("requires b8 mask granularity"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return fail("requires mask lane count to match source lane count"); + + auto accElem = dyn_cast(accType.getElementType()); + auto sourceElem = dyn_cast(sourceType.getElementType()); + if (!accElem || !accElem.isUnsigned() || accElem.getWidth() != 16 || + accType.getElementCount() != 256 || resultType != accType) + return fail("requires contiguous 256xui16 acc/result"); + if (!sourceElem || !sourceElem.isUnsigned() || sourceElem.getWidth() != 8) + return fail("requires unsigned 8-bit source elements"); + + FailureOr accArity = getVMIPhysicalArity(accType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(accArity) || failed(resultArity) || failed(sourceArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*accArity != 2 || *resultArity != 2) + return fail("requires acc/result to physicalize to two 128xui16 parts"); + if (*sourceArity != *maskArity) + return fail("requires source and mask physical arity to match"); + if (*sourceArity < 1) + return fail("requires at least one source physical chunk"); + + return VMIHistogramSupport{VMIHistogramSupportKind::Full256BinDhist}; +} + +FailureOr +VMILayoutSupport::getDhistSupport(VMIDhistOp op, + std::string *reason) const { + return getHistogramSupportImpl(op, reason); +} + +FailureOr +VMILayoutSupport::getChistSupport(VMIChistOp op, + std::string *reason) const { + if (reason) + *reason = "CHISTv2 cumulative high-range semantics are not classified"; + return failure(); +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index ea286520bf..4115b90324 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ 
b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -5986,6 +5986,76 @@ struct OneToNVMIGroupBroadcastOpPattern } }; +struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIDhistOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange accParts = adaptor.getAcc(); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + if (accParts.size() != 2 || sourceParts.empty() || + sourceParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "expected two accumulator parts and matching source/mask chunks"); + + auto loType = dyn_cast(accParts[0].getType()); + auto hiType = dyn_cast(accParts[1].getType()); + if (!loType || loType != hiType) + return rewriter.notifyMatchFailure(op, + "expected matching ui16 acc parts"); + auto sourceType = cast(op.getSource().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure(op, + "failed to compute source lanes"); + + Location loc = op.getLoc(); + Value bin0 = createI32Constant(loc, 0, rewriter); + Value bin1 = createI32Constant(loc, 1, rewriter); + Value lo = accParts[0]; + Value hi = accParts[1]; + + for (size_t index = 0, e = sourceParts.size(); index < e; ++index) { + Value source = sourceParts[index]; + Value userMask = maskParts[index]; + auto maskType = dyn_cast(userMask.getType()); + if (!maskType || !maskType.isB8()) + return rewriter.notifyMatchFailure(op, "expected b8 source mask"); + + Value chunkMask = userMask; + int64_t firstLane = static_cast(index) * *lanesPerPart; + int64_t activeLanes = + std::min(*lanesPerPart, + sourceType.getElementCount() - firstLane); + if (activeLanes < *lanesPerPart) { + FailureOr validMask = + createPrefixMaskForActiveLanes(loc, maskType, activeLanes, + rewriter); + FailureOr allMask = createAllTrueMask(loc, 
maskType, rewriter); + if (failed(validMask) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize tail-valid b8 mask"); + chunkMask = rewriter + .create(loc, maskType, chunkMask, *validMask, + *allMask) + .getResult(); + } + + lo = rewriter.create(loc, loType, lo, source, chunkMask, bin0) + .getResult(); + hi = rewriter.create(loc, hiType, hi, source, chunkMask, bin1) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{lo, hi}, + adaptor.getResultMapping()); + return success(); + } +}; + template struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -6937,7 +7007,7 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupBroadcastOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIDhistOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, @@ -7420,6 +7490,22 @@ LogicalResult checkSupportedGroupBroadcastShape( return success(); } +LogicalResult checkSupportedDhistShape(VMIDhistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getDhistSupport(op, reason))) + return success(); + return failure(); +} + +LogicalResult checkSupportedChistShape(VMIChistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getChistSupport(op, reason))) + return success(); + return failure(); +} + LogicalResult checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, VMIFmaOp op, std::string *reason = nullptr) { @@ -7579,6 +7665,28 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << reason << ")"; return WalkResult::interrupt(); } + if (auto hist = dyn_cast(op)) { + std::string reason; + if 
(succeeded(checkSupportedDhistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.dhist requires contiguous Nxui8 source, contiguous b8 " + "mask, and contiguous 256xui16 acc/result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedChistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.chist requires a verified CHISTv2 range semantics " + "contract before lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { std::optional explicitFullReadElems; diff --git a/test/lit/vmi/vmi_layout_assignment_dhist.pto b/test/lit/vmi/vmi_layout_assignment_dhist.pto new file mode 100644 index 0000000000..89d9aee9e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dhist.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_dhist( + %acc: !pto.vmi.vreg<256xui16>, + %source: !pto.vmi.vreg<300xui8>, + %mask: !pto.vmi.mask<300xpred>) + -> !pto.vmi.vreg<256xui16> { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<300xui8>, + !pto.vmi.mask<300xpred> -> !pto.vmi.vreg<256xui16> + return %hist : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_dhist( +// CHECK-SAME: %[[ACC:.*]]: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: %[[MASK_B8:.*]] = pto.vmi.ensure_mask_granularity %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK: %[[HIST:.*]] = pto.vmi.dhist %[[ACC]], %[[SRC]], %[[MASK_B8]] +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: return %[[HIST]] diff --git a/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto new file mode 100644 index 0000000000..1049cbdf2e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_chist_semantics_invalid( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %hist = pto.vmi.chist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUP{{.*}} pto.vmi.chist requires a verified CHISTv2 range semantics contract before lowering diff --git a/test/lit/vmi/vmi_to_vpto_dhist.pto b/test/lit/vmi/vmi_to_vpto_dhist.pto new file mode 100644 index 0000000000..b8a1113534 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LO:.*]] = pto.dhistv2 %[[ACC0]], %[[SRC]], %[[MASK]], %[[BIN0]] +// CHECK: %[[HI:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC]], %[[MASK]], %[[BIN1]] +// CHECK: return %[[LO]], %[[HI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto new file mode 100644 index 0000000000..4aada7a188 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist_tail_mask( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<300xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + !pto.vmi.mask<300xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist_tail_mask( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC0:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[SRC1:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[MASK1:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C44:.*]] = arith.constant 44 : i32 +// CHECK: %[[LO0:.*]] = pto.dhistv2 %[[ACC0]], %[[SRC0]], %[[MASK0]], %[[BIN0]] +// CHECK: %[[HI0:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC0]], %[[MASK0]], %[[BIN1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b8 %[[C44]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[MASK1_VALID:.*]] = pto.pand %[[MASK1]], %[[TAIL]], %[[ALL]] +// CHECK: %[[LO1:.*]] = pto.dhistv2 %[[LO0]], %[[SRC1]], 
%[[MASK1_VALID]], %[[BIN0]] +// CHECK: %[[HI1:.*]] = pto.dhistv2 %[[HI0]], %[[SRC1]], %[[MASK1_VALID]], %[[BIN1]] +// CHECK: return %[[LO1]], %[[HI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py new file mode 100644 index 0000000000..22aff69b5d --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + output = np.fromfile("v3.bin", dtype=np.uint16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed v3.bin: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3.bin idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py new file mode 100644 index 0000000000..0c09bb49d7 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import argparse +from pathlib import Path + +import numpy as np + +SOURCE_ELEMS = 512 +LOGICAL_LANES = 300 +BINS = 256 + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + src = (np.arange(SOURCE_ELEMS, dtype=np.uint16) % BINS).astype(np.uint8) + acc = (np.arange(BINS, dtype=np.uint16) % np.uint16(5)).astype(np.uint16) + dst = np.full(BINS, np.uint16(0xcccc), dtype=np.uint16) + + counts = np.bincount(src[:LOGICAL_LANES].astype(np.int64), minlength=BINS) + golden = (acc.astype(np.uint32) + counts.astype(np.uint32)).astype(np.uint16) + + src.tofile(output_dir / "v1.bin") + acc.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto new file mode 100644 index 0000000000..4fb1fe531c --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dhist_tail_mask_store_kernel( + %src_gm: !pto.ptr, %acc_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c300 = arith.constant 300 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_acc = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %acc_gm, %ub_acc, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %source = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<512xui8> + %acc = pto.vmi.load %ub_acc[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xui16> + %mask = pto.vmi.create_mask %c300 : index -> !pto.vmi.mask<512xpred> + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<512xui8>, + !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<256xui16> + pto.vmi.store %hist, %ub_dst[%c0] + : !pto.vmi.vreg<256xui16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp new file mode 100644 index 0000000000..4031c8131e --- /dev/null +++ 
b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dhist_tail_mask_store_kernel(__gm__ uint8_t *src, __gm__ uint16_t *acc, + __gm__ uint16_t *dst); + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream) { + vmi_dhist_tail_mask_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint16_t *)acc, (__gm__ uint16_t *)dst); +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp new file mode 100644 index 0000000000..aa1288ab26 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream); + +int main() { + constexpr size_t kSourceElems = 512; + constexpr size_t kBins = 256; + size_t srcBytes = kSourceElems * sizeof(uint8_t); + size_t accBytes = kBins * sizeof(uint16_t); + size_t dstBytes = kBins * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint16_t *accHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint16_t *accDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&accHost), accBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&accDevice, accBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + 
ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", accBytes, accHost, accBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(accDevice, accBytes, accHost, accBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVmi_dhist_tail_mask_store_kernel(srcDevice, accDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(accDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(accHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 0f32642344d31566fb388118e155282e8bc34465 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 14:44:35 +0800 Subject: [PATCH 28/31] Remove VMI load full read attribute --- .../vmi-layout-assignment-implementation.md | 22 ++++--- .../vmi-layout-assignment-lowering-design.md | 13 ++-- docs/designs/vmi-layout-lowering-cases.md | 14 ++--- include/PTO/IR/VMIOps.td | 3 +- lib/PTO/IR/VMI.cpp | 4 -- lib/PTO/Transforms/VMIToVPTO.cpp | 59 +++++++------------ ...gnment_group_reduce_s32_tail_full_tile.pto | 30 +++++----- 
.../vmi/vmi_load_full_read_elems_invalid.pto | 20 ------- .../golden.py | 2 +- .../kernel.pto | 16 ++--- 10 files changed, 70 insertions(+), 113 deletions(-) delete mode 100644 test/lit/vmi/vmi_load_full_read_elems_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 1162818712..fcdf7fe292 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -311,7 +311,7 @@ ensure_mask_granularity always carries source/result granularities Layout/attr-only decisions today: ```text -load result layout plus full_read_elems/full chunk proof +load result layout plus full chunk or shaped memref proof group_store source group_slots layout plus explicit output stride masked_load explicit passthrough, mask layout, and memory proof masked_store/select operand/result layouts plus mask granularity @@ -410,9 +410,8 @@ Important semantic split: ```text load: - optional full_read_elems=N is a memory-safety contract for pointer sources. - It states that source[offset : offset + N) may be physically read even if the - VMI logical result has fewer active lanes. + pointer sources must load full physical chunks directly. Partial logical + loads require a shaped memref proof or a future guarded/scratch fallback. group_load: loads group_size data elements per group @@ -806,7 +805,7 @@ helpers with cheaper equivalent IR. 
```text cheap rematerializable producers: load when address operands dominate the clone site, no intervening may-alias - write exists, and any full_read_elems proof is preserved + write exists, and any shaped memory proof is preserved broadcast create_mask create_group_mask @@ -1036,7 +1035,7 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.21 S=32 safe full-read tail buildMaskRequests full_read_elems memory proof +3.21 S=32 rounded tail mask buildMaskRequests rounded vector plus mask 3.24 mask/select/store buildMaskRequests explicit mask layout/granularity 3.12 scf.if before reduce buildControlFlowRequests common yielded layout 3.20 group_slots scf.if buildControlFlowRequests common group_slots layout @@ -1469,7 +1468,7 @@ Current audit result: masked_load: direct lowering is load + vsel. It does not inspect the mask producer to choose a different load form; memory safety is provided by full physical - chunks, shaped memref proof, or load full_read_elems. + chunks or shaped memref proof. memref.subview: mentioned only after identity lane-to-address planning fails. It is not used @@ -2187,11 +2186,10 @@ private physical function ABI: rejected until a stable VMI ABI is defined. memory-proof runtime coverage: - 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses - `pto.vmi.load {full_read_elems = 256}` on a UB pointer source. The attr is - the explicit safe-read proof consumed by `vmi-to-vpto`; no surrounding MTE, - caller/body context, or producer/user scan is inspected to justify the - rounded-up physical reads. + 3.21 S=32 rounded tail-mask coverage is provided by a runtime case that loads + a full 256xf32 UB pointer vector and uses a 192-lane mask to define the active + logical rows. No surrounding MTE, caller/body context, or producer/user scan is + inspected to justify partial pointer reads. ``` ## 12. 
Implementation Slices diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 00f69aae05..497f6cad8c 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -422,8 +422,8 @@ group_reduce layout fact: group_slots(G, slots=1) result. memory safety fact: - full_read_elems, shaped safe-tail memref, or explicit fallback option - proves whether rounded-up physical reads are legal. + full physical chunks are legal for pointer sources. Partial logical loads + need a shaped safe-tail memref proof or an explicit fallback option. ``` These helpers return semantic layout requirements and capability diagnostics. @@ -596,10 +596,11 @@ full_tile_readable: ``` The full-tile-readable proof must be explicit. It may be carried by a -statically shaped memref source, or by `pto.vmi.load {full_read_elems = N}` for -pointer sources. `vmi-to-vpto` consumes only this proof carrier; it does not -inspect surrounding MTE copies, producer bodies, callers, or later consumers to -decide whether inactive physical lanes are safe to read. +statically shaped memref source. Pointer-source runtime kernels should load a +rounded physical vector and use a mask to express logical active lanes. +`vmi-to-vpto` consumes only the op/type-local proof carrier; it does not inspect +surrounding MTE copies, producer bodies, callers, or later consumers to decide +whether inactive physical lanes are safe to read. Example: diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index efb2a7c502..3fd0c4b7eb 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -2845,11 +2845,10 @@ for r = 0..7: This is the positive counterpart to section 3.11.2. 
Tail participation is still expressed by masks, but the source must provide a static proof that reading the rounded-up 8-row physical tile is memory-safe. That proof is -explicit: it can come from a statically shaped memref source, or from -`pto.vmi.load {full_read_elems = N}` on a pointer source. The pointer attr -means the memory interval starting at the load offset is safe to read for `N` -logical elements; it is not inferred from surrounding MTE copies or caller -context. +explicit for partial logical loads: it can come from a statically shaped memref +source. Pointer-source runtime kernels should instead load the rounded physical +vector and use a mask to express active logical lanes; this is not inferred from +surrounding MTE copies or caller context. VMI input: @@ -2864,8 +2863,9 @@ pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} Equivalent pointer-source VMI input for runtime kernels: ```text -%x = pto.vmi.load %base[%off] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> +%x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> ``` Assigned layouts: diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 98083fb687..263146ec1f 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -504,8 +504,7 @@ def VMIBitcastOp : VMI_Op<"bitcast"> { def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical vector load"; - let arguments = (ins PtrOrMemRef:$source, Index:$offset, - OptionalAttr:$full_read_elems); + let arguments = (ins PtrOrMemRef:$source, Index:$offset); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index edd55a2bb3..3589ec603e 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1394,10 +1394,6 @@ 
LogicalResult VMIBitcastOp::verify() { } LogicalResult VMILoadOp::verify() { - if (auto fullReadElems = getFullReadElemsAttr()) { - if (fullReadElems.getInt() <= 0) - return emitOpError("requires full_read_elems to be positive"); - } return verifyMemoryElementMatches(getOperation(), getSource().getType(), cast(getResult().getType()), "source"); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 4115b90324..b0da679232 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -923,8 +923,7 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, VMIMemorySafeReadProof computeSafeFullReadProof( Type sourceType, std::optional constantOffset, - VMIVRegType resultType, - std::optional explicitFullReadElems = std::nullopt) { + VMIVRegType resultType) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -937,15 +936,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( if (!constantOffset) return fail("requires constant index offset"); - std::optional elements = explicitFullReadElems; - if (!elements) { - FailureOr staticElements = getStaticMemRefElementCount(sourceType); - if (failed(staticElements)) - return fail("requires statically shaped memref source or explicit " - "full_read_elems"); - elements = *staticElements; - } - proof.staticElementCount = *elements; + FailureOr staticElements = getStaticMemRefElementCount(sourceType); + if (failed(staticElements)) + return fail("requires statically shaped memref source"); + int64_t elements = *staticElements; + proof.staticElementCount = elements; if (*constantOffset < 0) return fail("requires non-negative offset"); @@ -959,11 +954,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( proof.laneAddressMap = *addressMap; proof.physicalFootprint = addressMap->physicalLaneFootprint; - if (addressMap->getExclusiveEndElement() > *elements) + if (addressMap->getExclusiveEndElement() > elements) return 
fail(Twine("full physical read footprint [") + Twine(addressMap->baseElementOffset) + ", " + Twine(addressMap->getExclusiveEndElement()) + - ") exceeds static memref element count " + Twine(*elements)); + ") exceeds static memref element count " + Twine(elements)); proof.proven = true; return proof; @@ -972,8 +967,7 @@ VMIMemorySafeReadProof computeSafeFullReadProof( VMIMemoryAccessPlan buildReadAccessPlan( const VMITargetCapabilityRegistry &capabilities, Value source, Type sourceType, VMIVRegType resultType, - std::optional constantOffset, VMIMemoryValidMaskKind validMask, - std::optional explicitFullReadElems = std::nullopt) { + std::optional constantOffset, VMIMemoryValidMaskKind validMask) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -982,8 +976,8 @@ VMIMemoryAccessPlan buildReadAccessPlan( plan.validMask = validMask; plan.permutation = VMIMemoryPermutationKind::Identity; plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; - plan.safeReadProof = computeSafeFullReadProof( - sourceType, constantOffset, resultType, explicitFullReadElems); + plan.safeReadProof = + computeSafeFullReadProof(sourceType, constantOffset, resultType); plan.laneAddressMap = plan.safeReadProof.laneAddressMap; plan.targetCapability = capabilities.supportsDirectMemory(sourceType, "source"); @@ -1040,16 +1034,15 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { FailureOr verifyFullOrSafeReadVRegChunks( Operation *op, VMIVRegType type, Type sourceType, Value offset, - PatternRewriter &rewriter, - std::optional explicitFullReadElems = std::nullopt) { + PatternRewriter &rewriter) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); if (succeeded(lanesPerPart)) return *lanesPerPart; - VMIMemorySafeReadProof safeReadProof = computeSafeFullReadProof( - sourceType, getConstantIndexValue(offset), type, explicitFullReadElems); + VMIMemorySafeReadProof safeReadProof = + 
computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); if (safeReadProof.proven) { lanesPerPart = getDataLanesPerPart(type.getElementType()); if (succeeded(lanesPerPart)) @@ -1065,7 +1058,7 @@ FailureOr verifyFullOrSafeReadVRegChunks( LogicalResult checkSupportedLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, Value source, Type sourceType, std::optional constantOffset, - std::optional explicitFullReadElems, std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1074,7 +1067,7 @@ LogicalResult checkSupportedLoadShape( VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( capabilities, source, sourceType, type, constantOffset, - VMIMemoryValidMaskKind::AllTrue, explicitFullReadElems); + VMIMemoryValidMaskKind::AllTrue); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -1195,7 +1188,7 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isContiguous()) { if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), op.getSource().getType(), std::nullopt, - std::nullopt, reason))) + reason))) return failure(); return checkSupportedGroupChunkShape(resultType, *groupSize, reason); } @@ -3750,12 +3743,8 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { "load offset must convert to one value", rewriter); if (failed(source) || failed(offset)) return failure(); - std::optional explicitFullReadElems; - if (auto attr = op.getFullReadElemsAttr()) - explicitFullReadElems = attr.getInt(); FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( - op, resultVMIType, op.getSource().getType(), *offset, rewriter, - explicitFullReadElems); + op, resultVMIType, op.getSource().getType(), *offset, rewriter); if (failed(lanesPerPart)) return failure(); @@ -7585,13 +7574,11 @@ verifySupportedVMIToVPTOOps(ModuleOp module, bool 
enableStableGatherMaskedLoad) { auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, VMIVRegType type, Value source, - std::optional constantOffset, - std::optional explicitFullReadElems = - std::nullopt) -> WalkResult { + std::optional constantOffset) -> WalkResult { std::string reason; if (succeeded(checkSupportedLoadShape(capabilities, type, source, source.getType(), constantOffset, - explicitFullReadElems, &reason))) + &reason))) return WalkResult::advance(); op->emitError() @@ -7689,13 +7676,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, } if (auto load = dyn_cast(op)) { - std::optional explicitFullReadElems; - if (auto attr = load.getFullReadElemsAttr()) - explicitFullReadElems = attr.getInt(); return emitMemoryUnsupported( op, "pto.vmi.load", cast(load.getResult().getType()), - load.getSource(), getConstantIndexValue(load.getOffset()), - explicitFullReadElems); + load.getSource(), getConstantIndexValue(load.getOffset())); } if (auto load = dyn_cast(op)) { std::string reason; diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index 602ac579ad..31e83e37d7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -26,19 +26,19 @@ module { return } - func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( %src: !pto.ptr, %dst: !pto.ptr, %off: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c192 = arith.constant 192 : index - %x = pto.vmi.load %src[%c0] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> - %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> - %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} - : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> 
!pto.vmi.vreg<192xf32> - pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + %x = pto.vmi.load %src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr return } } @@ -69,16 +69,16 @@ module { // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast -// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( // ASSIGN: %[[PX:.*]] = pto.vmi.load -// ASSIGN-SAME: {full_read_elems = 256 : i64} -// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] -// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( // LOWER-COUNT-4: pto.vlds // LOWER-COUNT-3: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd diff --git a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto 
b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto deleted file mode 100644 index 102efd4f0e..0000000000 --- a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s - -module { - func.func @vmi_load_full_read_elems_invalid(%src: !pto.ptr) { - %c0 = arith.constant 0 : index - %value = pto.vmi.load %src[%c0] {full_read_elems = 0} - : !pto.ptr -> !pto.vmi.vreg<100xf32> - return - } -} - -// CHECK: 'pto.vmi.load' op requires full_read_elems to be positive diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py index cf80936861..a521122803 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py @@ -22,7 +22,7 @@ def generate(output_dir: Path) -> None: src = np.empty(INPUT_ELEMS, dtype=np.float32) dst = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) - golden = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) + golden = np.zeros(PHYSICAL_ROWS, dtype=np.float32) base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) for row in range(PHYSICAL_ROWS): diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto 
b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto index fabed4ee8b..4e311c0703 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto @@ -32,14 +32,14 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<192xpred> - %x = pto.vmi.load %ub_src[%c0] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> - %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} - : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> !pto.vmi.vreg<192xf32> - pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] From fbdda7b71bc543df9c03d99b8cefdce54f50ffe7 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:56:04 +0800 Subject: [PATCH 29/31] Define VMI scatter as unique-index op --- docs/designs/vmi-dialect-design.md | 19 ++++++------- docs/designs/vmi-implementation-manual.md | 8 +----- include/PTO/IR/VMIOps.td | 3 +-- lib/PTO/Transforms/VMIToVPTO.cpp | 11 +++----- .../lit/vmi/vmi_layout_assignment_scatter.pto | 3 +-- test/lit/vmi/vmi_scatter_indices_invalid.pto | 2 +- ...i_to_vpto_gather_scatter_shape_invalid.pto | 4 +-- test/lit/vmi/vmi_to_vpto_scatter.pto | 2 +- ...to_vpto_scatter_missing_unique_invalid.pto | 27 ------------------- 9 files changed, 20 insertions(+), 59 deletions(-) delete mode 100644 
test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md index 7569b787a0..897b1661cc 100644 --- a/docs/designs/vmi-dialect-design.md +++ b/docs/designs/vmi-dialect-design.md @@ -1178,7 +1178,7 @@ interleave/deinterleave boundary: vldsx2/vstsx2 dist or explicit rearrangement indexed memory: - gather/scatter if inactive and duplicate-index semantics match + gather/scatter; ordinary scatter requires pairwise-distinct active indices ``` GM-backed VMI memory is semantic input, not a direct vector load/store target. @@ -1356,20 +1356,21 @@ lanes to preserve passthru, so the `vsel` is semantically required, not an optim gather, tail gather, non-contiguous layout, memref/gm source, and fallback through guarded scalar load or scratch are future target-capability paths. -当前 `scatter` direct lowering 只在 VMI IR 携带显式 no-conflict proof 时启用: +`scatter` 的基础语义要求所有 active logical lanes 的 `%indices` 两两不同。inactive lane 不写内存, +因此不参与这个唯一性约束。如果两个 active lane 的 index 相同,程序违反 `pto.vmi.scatter` 的 +语义前置条件;VMI 不为这种输入定义 logical lane order 或 winner。 ```mlir -pto.vmi.scatter %v, %base[%indices], %mask {indices_unique} +pto.vmi.scatter %v, %base[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> ``` -`indices_unique` 的含义是:所有 active logical lanes 的 `%indices` 两两不同。这个 proof 可以来自 -producer 的静态分析、前端语义或上游 canonicalization;VMI lowering 不从 runtime 值猜测它。direct -path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、32-bit value -element、i32 indices 和 b32 mask。没有 `indices_unique` 时,`vmi-to-vpto` 必须诊断,而不能直接发 -`VSCATTER`,因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于 VMI -logical lane order。 +当前 direct path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、 +32-bit value element、i32 indices 和 b32 mask。允许冲突的 scatter 不能复用普通 `pto.vmi.scatter`, +因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于确定的 VMI logical +lane 
order。后续如果需要定义 duplicate-index scatter,需要新增显式语义,例如 ordered fallback、 +atomic scatter、reduce-scatter 或 target-specific unordered scatter。 `expand_load/compress_store` 表达 masked contiguous stream,不是 arbitrary indexed access: diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 2cd72208a6..f86a812f05 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -3730,13 +3730,11 @@ vmi.scatter: if mask[lane] is true, memory[base + indices[lane]] = value[lane] if mask[lane] is false, no memory write occurs for that lane indices are interpreted in element units, not bytes - if two active lanes have the same index, VMI logical semantics require an ordered conflict policy or an explicit - no-conflict proof before direct target lowering + all active lanes must have pairwise-distinct indices; duplicate active indices violate the VMI scatter contract layout assignment: value and indices uses are requested as contiguous mask use is requested as contiguous with granularity derived from value element width current direct path: - op must carry {indices_unique} destination must be !pto.ptr T must be a 32-bit element type indices must be signless or unsigned i32 @@ -3744,11 +3742,7 @@ vmi.scatter: mask granularity must be b32 for each physical chunk i: pto.vscatter value_i, destination, indices_i, mask_i - reason for indices_unique: - VSCATTER false predicate lanes do not write, but duplicate active indices have target-defined/undefined grant - behavior. VMI cannot lower duplicate-index logical order semantics to VSCATTER without a proof or fallback. 
unsupported cases: - missing indices_unique proof f16/b16/f8/i8 value element types partial/tail chunks non-contiguous layouts diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 263146ec1f..7eccc093a5 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -590,8 +590,7 @@ def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods:$indices_unique); + VMI_MaskTypeConstraint:$mask); let results = (outs); let hasVerifier = 1; let assemblyFormat = "$value `,` $destination `[` $indices `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($indices) `,` type($mask)"; diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index b0da679232..39ca049a1e 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1374,10 +1374,6 @@ checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, return failure(); }; - if (!op->hasAttr("indices_unique")) - return fail("requires indices_unique proof because pto.vscatter does not " - "define logical-lane-order duplicate-index semantics"); - auto valueType = cast(op.getValue().getType()); auto indicesType = cast(op.getIndices().getType()); auto maskType = cast(op.getMask().getType()); @@ -7800,10 +7796,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::advance(); scatter.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.scatter lowers through pto.vscatter only with an " - "indices_unique proof, UB pointer destination, contiguous full " - "physical chunks, 32-bit value elements, i32 indices, and b32 " - "masks (" + << "pto.vmi.scatter lowers through pto.vscatter only with a UB " + "pointer destination, contiguous full physical chunks, 32-bit " + "value elements, i32 indices, and b32 masks (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_scatter.pto b/test/lit/vmi/vmi_layout_assignment_scatter.pto index 9560cfa981..b920cf4da4 100644 
--- a/test/lit/vmi/vmi_layout_assignment_scatter.pto +++ b/test/lit/vmi/vmi_layout_assignment_scatter.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32>, %mask: !pto.vmi.mask<64xpred>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> return @@ -26,7 +26,6 @@ module { // CHECK-SAME: %arg2: !pto.vmi.vreg<64xi32, #pto.vmi.layout> // CHECK-SAME: %arg3: !pto.vmi.mask<64xb32, #pto.vmi.layout> // CHECK: pto.vmi.scatter %arg0, %arg1[%arg2], %arg3 -// CHECK-SAME: indices_unique // CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> // CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_scatter_indices_invalid.pto b/test/lit/vmi/vmi_scatter_indices_invalid.pto index bd59b81b04..e16d6905f0 100644 --- a/test/lit/vmi/vmi_scatter_indices_invalid.pto +++ b/test/lit/vmi/vmi_scatter_indices_invalid.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xf32>, %mask: !pto.vmi.mask<64xpred>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> return diff --git a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto index c271e9f446..2e5afb7708 100644 --- a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto @@ -57,7 +57,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<64xi32, 
#pto.vmi.layout>, @@ -77,7 +77,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<32xi32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_scatter.pto b/test/lit/vmi/vmi_to_vpto_scatter.pto index 12799c01fc..4f898e3571 100644 --- a/test/lit/vmi/vmi_to_vpto_scatter.pto +++ b/test/lit/vmi/vmi_to_vpto_scatter.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<64xi32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto deleted file mode 100644 index 027162ac68..0000000000 --- a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. 
- -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s - -module { - func.func @vmi_to_vpto_scatter_missing_unique_invalid( - %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %dst: !pto.ptr, - %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask - : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - !pto.ptr, - !pto.vmi.vreg<64xi32, #pto.vmi.layout>, - !pto.vmi.mask<64xb32, #pto.vmi.layout> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only with an indices_unique proof -// CHECK-SAME: requires indices_unique proof From f9013c29c74f1a9e1d88184b1da33d6cefee3601 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:16:41 +0800 Subject: [PATCH 30/31] Add VMI group max quant kernel case --- include/PTO/IR/VMIOps.td | 10 + include/PTO/Transforms/VMILayoutSupport.h | 75 +++-- .../PTO/Transforms/VMITargetCapabilities.h | 35 +- lib/PTO/IR/VMI.cpp | 54 ++-- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 113 ++++--- lib/PTO/Transforms/VMILayoutAssignment.cpp | 156 ++++++--- lib/PTO/Transforms/VMILayoutSupport.cpp | 193 ++++++----- lib/PTO/Transforms/VMIToVPTO.cpp | 300 ++++++++++-------- ...out_assignment_group_reduce_maxf_quant.pto | 78 +++++ .../simdvf-per-token-cast-to-fp8/compare.py | 49 +++ .../simdvf-per-token-cast-to-fp8/golden.py | 62 ++++ .../simdvf-per-token-cast-to-fp8/kernel.pto | 79 +++++ .../simdvf-per-token-cast-to-fp8/launch.cpp | 43 +++ .../simdvf-per-token-cast-to-fp8/main.cpp | 91 ++++++ .../simdvf-per-token-cast-to-fp8/ptoas.flags | 1 + 15 files changed, 945 insertions(+), 394 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py create mode 
100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 7eccc093a5..9acce8cd7b 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -418,6 +418,16 @@ def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceMaxFOp : VMI_Op<"group_reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { let summary = "VMI masked integer add reduction within fixed logical groups"; let arguments = (ins VMI_VRegTypeConstraint:$source, diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 41b686b322..429a20bf0d 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -200,82 +200,93 @@ class VMILayoutSupport { public: FailureOr getContiguousStoreSupport(VMIVRegType valueType, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; - LogicalResult canFoldContiguousStoreMaterialization( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; + LogicalResult + canFoldContiguousStoreMaterialization(VMIVRegType sourceType, + VMIVRegType resultType, + 
std::string *reason = nullptr) const; FailureOr getDataLayoutMaterializationSupport(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; + VMIVRegType resultType, + std::string *reason = nullptr) const; LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; + VMIVRegType resultType, + std::string *reason = nullptr) const; FailureOr getMaskLayoutMaterializationSupport(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; LogicalResult canMaterializeMaskLayout(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; FailureOr getMaskGranularityMaterializationSupport(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; - LogicalResult canMaterializeMaskGranularity( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason = nullptr) const; + LogicalResult + canMaterializeMaskGranularity(VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason = nullptr) const; FailureOr getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; FailureOr getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupSlotLoadOp op, - std::string *reason = nullptr) const; + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; FailureOr getGroupLoadSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupLoadOp op, - std::string *reason = nullptr) const; + VMIGroupLoadOp op, std::string *reason = nullptr) const; FailureOr getGroupSlotsStoreSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupStoreOp op, - 
std::string *reason = nullptr) const; + VMIGroupStoreOp op, + std::string *reason = nullptr) const; FailureOr getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; FailureOr getGroupReduceAddFSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddFOp op, - std::string *reason = nullptr) const; + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceMaxFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceMaxFOp op, + std::string *reason = nullptr) const; FailureOr getGroupReduceAddISupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddIOp op, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, std::string *reason = nullptr) const; FailureOr getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupBroadcastOp op, - std::string *reason = nullptr) const; + VMIVRegType sourceType, VMIVRegType resultType, + int64_t numGroups, + std::string *reason = nullptr) const; FailureOr getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; - FailureOr - getExtFSupport(VMIExtFOp op, std::string *reason = nullptr) const; + FailureOr getExtFSupport(VMIExtFOp op, + std::string *reason = nullptr) const; FailureOr getExtSISupport(VMIExtSIOp op, std::string *reason = nullptr) const; diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index a96a73a6d0..043da612e6 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. 
-// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
//===- VMITargetCapabilities.h - VMI target capability registry -*- C++ -*-===// //===----------------------------------------------------------------------===// @@ -44,6 +46,7 @@ enum class VMIReductionKind { AddF, GroupAddI, GroupAddF, + GroupMaxF, MaxF, MinF, }; @@ -66,9 +69,7 @@ struct VMICapabilityResult { return result; } - bool isSupported() const { - return status == VMICapabilityStatus::supported; - } + bool isSupported() const { return status == VMICapabilityStatus::supported; } LogicalResult toLogicalResult(std::string *outReason = nullptr) const { if (isSupported()) @@ -188,8 +189,9 @@ class VMITargetCapabilityRegistry { "unsupported source/result layout pair"); } - VMICapabilityResult supportsMaskGranularityConversion( - StringRef sourceGranularity, StringRef resultGranularity) const { + VMICapabilityResult + supportsMaskGranularityConversion(StringRef sourceGranularity, + StringRef resultGranularity) const { if (!VMIMaskType::isConcreteGranularity(sourceGranularity) || !VMIMaskType::isConcreteGranularity(resultGranularity)) return VMICapabilityResult::missingCapability( @@ -207,8 +209,8 @@ class VMITargetCapabilityRegistry { "current VPTO pto.vlds surface has no mask operand"); } - VMICapabilityResult supportsFallbackResource( - VMIFallbackResourceKind kind) const { + VMICapabilityResult + supportsFallbackResource(VMIFallbackResourceKind kind) const { switch (kind) { case VMIFallbackResourceKind::ScratchMemory: return VMICapabilityResult::missingCapability( @@ -220,8 +222,8 @@ class VMITargetCapabilityRegistry { llvm_unreachable("unhandled VMI fallback resource kind"); } - VMICapabilityResult supportsReductionElementType( - VMIReductionKind kind, Type elementType) const { + VMICapabilityResult supportsReductionElementType(VMIReductionKind kind, + Type elementType) const { switch (kind) { case VMIReductionKind::AddI: if (pto::getPTOStorageElemBitWidth(elementType) == 32 && @@ -246,10 +248,11 @@ class VMITargetCapabilityRegistry { "cast i8/i16 storage 
before grouped reduction"); } case VMIReductionKind::GroupAddF: + case VMIReductionKind::GroupMaxF: if (elementType.isF16() || elementType.isF32()) return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( - "grouped floating-point add reduction supports f16/f32 accumulator " + "grouped floating-point reduction supports f16/f32 accumulator " "elements"); case VMIReductionKind::MaxF: case VMIReductionKind::MinF: diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 3589ec603e..25e08ac381 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -532,10 +532,9 @@ VMILayoutAttr::verify(function_ref emitError, return emitError() << "#pto.vmi.layout requires block_elems to be 1"; if (slots < 0) - return emitError() - << "#pto.vmi.layout requires slots to be omitted or positive"; + return emitError() << "#pto.vmi.layout requires slots to be omitted or positive"; return success(); } @@ -1121,21 +1120,23 @@ LogicalResult VMIReduceMaxFOp::verify() { return verifyReduceMinMaxFOp(*this); } LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } -LogicalResult VMIGroupReduceAddFOp::verify() { - auto sourceType = cast(getSource().getType()); - auto maskType = cast(getMask().getType()); - auto resultType = cast(getResult().getType()); - if (!getOperation()->hasAttr("reassoc")) - return emitOpError( +template +static LogicalResult verifyGroupReduceFloatOp(OpTy op, bool requiresReassoc) { + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (requiresReassoc && !op->hasAttr("reassoc")) + return op.emitOpError( "requires reassoc attr because grouped lowering uses pair-wise " "floating-point reductions"); if (!isVMIFloatLikeType(sourceType.getElementType())) - return emitOpError("requires floating-point-like VMI source element type"); + return op.emitOpError( + "requires floating-point-like VMI source element type"); if 
(sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError( + return op.emitOpError( "requires source and result logical lane counts to match"); if (sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires source and result element types to match"); + return op.emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { bool supportedSourceLayout = sourceLayout.isContiguous() || @@ -1146,21 +1147,29 @@ LogicalResult VMIGroupReduceAddFOp::verify() { (sourceLayout.getBlockElems() == 1 || sourceLayout.getBlockElems() == 8)); if (!supportedSourceLayout) - return emitOpError( + return op.emitOpError( "requires layout-assigned source to use contiguous layout or " "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); } if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() << "requires layout-assigned result to use " - "#pto.vmi.layout"; + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return op.emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; } - if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + if (failed(verifyMaskMatchesData(op.getOperation(), maskType, sourceType))) return failure(); - return verifyNumGroups(getOperation(), sourceType, - getNumGroupsAttr().getInt()); + return verifyNumGroups(op.getOperation(), sourceType, + op.getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupReduceAddFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/true); +} + +LogicalResult VMIGroupReduceMaxFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/false); } LogicalResult VMIGroupReduceAddIOp::verify() { @@ -1231,8 +1240,7 @@ LogicalResult VMIGroupBroadcastOp::verify() { getNumGroupsAttr().getInt()); } 
-template -static LogicalResult verifyVMIHistogramOp(OpTy op) { +template static LogicalResult verifyVMIHistogramOp(OpTy op) { auto accType = cast(op.getAcc().getType()); auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 1186a25e26..7529953fa5 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. 
//===- PTOValidateVMIIR.cpp - VMI boundary verifier ----------------------===// //===----------------------------------------------------------------------===// @@ -50,9 +52,9 @@ bool containsVMIOrPhysicalType(Type type) { return true; if (auto functionType = dyn_cast(type)) { - return llvm::any_of(functionType.getInputs(), [](Type input) { - return containsVMIOrPhysicalType(input); - }) || + return llvm::any_of( + functionType.getInputs(), + [](Type input) { return containsVMIOrPhysicalType(input); }) || llvm::any_of(functionType.getResults(), [](Type result) { return containsVMIOrPhysicalType(result); }); @@ -110,8 +112,8 @@ bool isVMIHelperOp(Operation *op) { StringRef name = op->getName().getStringRef(); return name == "pto.vmi.ensure_layout" || name == "pto.vmi.ensure_mask_layout" || - name == "pto.vmi.ensure_mask_granularity" || - name == "pto.vmi.pack" || name == "pto.vmi.unpack"; + name == "pto.vmi.ensure_mask_granularity" || name == "pto.vmi.pack" || + name == "pto.vmi.unpack"; } bool isVMILayoutHelperOp(Operation *op) { @@ -155,8 +157,8 @@ void mirrorDiagnostic(llvm::raw_ostream *diagOS, Twine message) { LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagPassInvariantPrefix << message; + InFlightDiagnostic diag = op->emitError() + << kVMIDiagPassInvariantPrefix << message; (void)diag; mirrorDiagnostic(diagOS, Twine(kVMIDiagPassInvariantPrefix) + message); return failure(); @@ -164,8 +166,8 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagLayoutContractPrefix << message; + InFlightDiagnostic diag = op->emitError() + << kVMIDiagLayoutContractPrefix << message; (void)diag; mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); return failure(); @@ -198,17 +200,15 @@ LogicalResult 
emitLayoutSupportContract(Operation *op, return emitLayoutContract(op, diagOS, text); } -LogicalResult emitHelperMaterializationContract(Operation *helper, - Type sourceType, - Type resultType, - StringRef helperName, - StringRef reason, - llvm::raw_ostream *diagOS) { +LogicalResult +emitHelperMaterializationContract(Operation *helper, Type sourceType, + Type resultType, StringRef helperName, + StringRef reason, llvm::raw_ostream *diagOS) { auto emitFallback = [&]() { return emitLayoutContract( helper, diagOS, - Twine(helperName) + " has no registered materialization support: " + - reason); + Twine(helperName) + + " has no registered materialization support: " + reason); }; if (helper->getNumResults() != 1 || !helper->getResult(0).hasOneUse()) @@ -223,8 +223,8 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, << helperName << " has no registered materialization support: " << reason; os.flush(); - InFlightDiagnostic diag = - requester->emitError() << kVMIDiagLayoutContractPrefix << message; + InFlightDiagnostic diag = requester->emitError() + << kVMIDiagLayoutContractPrefix << message; diag.attachNote(helper->getLoc()) << "failed helper conversion " << sourceType << " -> " << resultType << " (" << reason << ")"; @@ -340,8 +340,7 @@ bool isFunctionTypeAttr(Operation *op, NamedAttribute attr) { return isa(op) && attr.getName() == "function_type"; } -LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, - NamedAttribute attr, +LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, NamedAttribute attr, llvm::raw_ostream *diagOS) { if (isFunctionTypeAttr(op, attr)) return success(); @@ -424,10 +423,10 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, } LogicalResult verifyLayoutHelperSupport(Operation *op, - llvm::raw_ostream *diagOS); + llvm::raw_ostream *diagOS); LogicalResult verifyLayoutSemanticSupport(Operation *op, - llvm::raw_ostream *diagOS); + llvm::raw_ostream *diagOS); LogicalResult 
verifyOperationBoundary(Operation *op, llvm::raw_ostream *diagOS) { @@ -461,7 +460,7 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return verifyHelperSupports ? verifyLayoutHelperSupport(op, diagOS) - : success(); + : success(); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); @@ -477,15 +476,15 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, } LogicalResult verifyLayoutHelperSupport(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS) { VMILayoutSupport supports; if (auto ensure = dyn_cast(op)) { auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(supports.canMaterializeDataLayout(sourceType, resultType, - &reason))) + if (failed( + supports.canMaterializeDataLayout(sourceType, resultType, &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); return success(); @@ -495,11 +494,11 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, - &reason))) - return emitHelperMaterializationContract( - op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, - diagOS); + if (failed( + supports.canMaterializeMaskLayout(sourceType, resultType, &reason))) + return emitHelperMaterializationContract(op, sourceType, resultType, + "pto.vmi.ensure_mask_layout", + reason, diagOS); return success(); } @@ -508,7 +507,7 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, auto resultType = cast(ensure.getResult().getType()); std::string reason; if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, - &reason))) + &reason))) return 
emitLayoutContract( op, diagOS, Twine("pto.vmi.ensure_mask_granularity has no registered " @@ -521,7 +520,7 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, } LogicalResult verifyLayoutSemanticSupport(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS) { VMILayoutSupport supports; VMITargetCapabilityRegistry capabilities; @@ -587,7 +586,8 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); std::string reason; - if (failed(supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) + if (failed( + supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_store has no registered group_slots layout support", @@ -602,8 +602,8 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); std::string reason; - if (failed(supports.getGroupReduceAddFSupport(capabilities, reduce, - &reason))) + if (failed( + supports.getGroupReduceAddFSupport(capabilities, reduce, &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_reduce_addf has no registered group_slots layout " @@ -612,6 +612,23 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceMaxFSupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_maxf has no registered group_slots layout " + "support", + reason); + return success(); + } + if (auto broadcast = dyn_cast(op)) { auto sourceType = cast(broadcast.getSource().getType()); VMILayoutAttr layout = sourceType.getLayoutAttr(); @@ -620,7 +637,7 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, std::string reason; if 
(failed(supports.getGroupBroadcastSupport(capabilities, broadcast, - &reason))) + &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_broadcast has no registered layout support", reason); @@ -697,8 +714,9 @@ struct PTOValidateVMILayoutIRPass } // namespace -LogicalResult mlir::pto::validateVMIProducerBoundaryIR( - ModuleOp module, llvm::raw_ostream *diagOS) { +LogicalResult +mlir::pto::validateVMIProducerBoundaryIR(ModuleOp module, + llvm::raw_ostream *diagOS) { WalkResult result = module.walk([&](Operation *op) { if (failed(verifyOperationBoundary(op, diagOS))) return WalkResult::interrupt(); @@ -710,8 +728,7 @@ LogicalResult mlir::pto::validateVMIProducerBoundaryIR( LogicalResult mlir::pto::validateVMILayoutAssignedIR( ModuleOp module, llvm::raw_ostream *diagOS, bool verifyHelperSupports) { WalkResult result = module.walk([&](Operation *op) { - if (failed(verifyLayoutAssignedOperation(op, diagOS, - verifyHelperSupports))) + if (failed(verifyLayoutAssignedOperation(op, diagOS, verifyHelperSupports))) return WalkResult::interrupt(); return WalkResult::advance(); }); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 99e4314cf9..f976b0d5a7 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -349,6 +349,7 @@ struct LayoutSolver { solved.getSlots() > 0) return solved; if (value.getDefiningOp() || + value.getDefiningOp() || value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); if (value.getDefiningOp()) @@ -387,7 +388,8 @@ struct LayoutSolver { if (!resultType) continue; unsigned resultBits = getElementBitWidth(resultType.getElementType()); - std::optional vlaneElems = getVLaneElems(sourceType.getElementType()); + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); if (vlaneElems && groupSize == 2 * *vlaneElems && resultBits == 16) return true; if (vlaneElems && groupSize == 4 * *vlaneElems && 
resultBits == 8) @@ -408,8 +410,8 @@ struct LayoutSolver { (layout.getBlockElems() == 1 || layout.getBlockElems() == 8); } - VMILayoutAttr getTruncFCompatibleGroupReduceSourceLayout( - VMIGroupReduceLayoutFact fact) { + VMILayoutAttr + getTruncFCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact) { if (fact.kind == VMIGroupReduceLayoutKind::TwoVLane) return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); if (fact.kind == VMIGroupReduceLayoutKind::FourVLane) @@ -458,6 +460,42 @@ struct LayoutSolver { VMISelectOp, VMIBitcastOp>(op); } + bool canGroupBroadcastProduceLayout(VMIGroupBroadcastOp broadcast, + VMILayoutAttr resultLayout) { + if (!resultLayout) + return false; + auto sourceType = cast(broadcast.getSource().getType()); + auto resultType = cast(broadcast.getResult().getType()); + int64_t numGroups = broadcast.getNumGroupsAttr().getInt(); + auto assignedSourceType = VMIVRegType::get( + ctx, sourceType.getElementCount(), sourceType.getElementType(), + getPreferredGroupSlotsLayout(sourceType, numGroups)); + auto assignedResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + VMILayoutSupport supports; + return succeeded(supports.getGroupBroadcastSupport( + capabilities, assignedSourceType, assignedResultType, numGroups)); + } + + bool canEquivalenceClassAdoptConsumerLayout(Value value, + VMILayoutAttr requestedLayout) { + unsigned id = addDataValue(value); + if (id == ~0u) + return true; + unsigned root = find(id); + for (DataNode &node : dataNodes) { + if (find(dataIds.lookup(node.value)) != root) + continue; + if (auto broadcast = node.value.getDefiningOp()) { + if (node.value == broadcast.getResult() && + !canGroupBroadcastProduceLayout(broadcast, requestedLayout)) + return false; + } + } + return true; + } + bool canAdoptConsumerRequestedLayout(Value value, VMILayoutAttr requestedLayout) { Operation *definingOp = value.getDefiningOp(); @@ -469,6 +507,8 @@ struct LayoutSolver { 
if (!canProducerAdoptConsumerLayout(definingOp)) return false; } + if (!canEquivalenceClassAdoptConsumerLayout(value, requestedLayout)) + return false; if (value.hasOneUse()) return true; @@ -502,9 +542,7 @@ struct LayoutSolver { unsigned root = find(id); VMILayoutAttr existing = dataNodes[root].naturalLayout; if (existing && existing != request.layout) - return request.operand->getOwner()->emitError() - << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " - << existing << " and " << request.layout; + continue; dataNodes[root].naturalLayout = request.layout; } return success(); @@ -560,6 +598,7 @@ struct LayoutSolver { bool sourceIsGroupSlotValue = (sourceLayout && sourceLayout.isGroupSlots()) || truncf.getSource().getDefiningOp() || + truncf.getSource().getDefiningOp() || truncf.getSource().getDefiningOp(); if (!sourceIsGroupSlotValue) return false; @@ -828,8 +867,44 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); - VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + 
reduce.getResult(), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); if (solvedSourceLayout && succeeded(fact) && @@ -850,9 +925,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - succeeded(fact) ? fact->resultLayout - : getPreferredGroupSlotsLayout(resultType, - numGroups), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -864,8 +939,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); - VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); if (solvedSourceLayout && succeeded(fact) && @@ -878,9 +953,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - succeeded(fact) ? fact->resultLayout - : getPreferredGroupSlotsLayout(resultType, - numGroups), + succeeded(fact) + ? 
fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -898,8 +973,8 @@ struct LayoutSolver { if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), "b8", op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -909,8 +984,8 @@ struct LayoutSolver { if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), "b8", op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -920,9 +995,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extf.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extf.getResult(), fact->resultLayout, op))) @@ -936,9 +1010,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extsi.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extsi.getResult(), fact->resultLayout, op))) @@ -952,9 +1025,8 @@ struct LayoutSolver { 
VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extui.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extui.getResult(), fact->resultLayout, op))) @@ -970,16 +1042,15 @@ struct LayoutSolver { supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout && - sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 1) { requestDataUse(truncf.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); VMILayoutAttr resultLayout = succeeded(fact) ? 
fact->resultLayout : getContiguousLayout(); @@ -995,16 +1066,15 @@ struct LayoutSolver { supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout && - sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 1) { requestDataUse(trunci.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); VMILayoutAttr resultLayout = succeeded(fact) ? fact->resultLayout : getContiguousLayout(); @@ -1551,6 +1621,14 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); if (failed(requestMaskUse( diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index acb687eed0..a3babbf7ab 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -313,10 +313,10 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { static FailureOr getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - std::string *reason) { - auto 
fail = [&](const Twine &message) - -> FailureOr { + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -341,8 +341,9 @@ getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, } // namespace FailureOr -VMILayoutSupport::getPreferredGroupReduceLayoutFact( - VMIVRegType sourceType, int64_t numGroups, std::string *reason) const { +VMILayoutSupport::getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, + int64_t numGroups, + std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -407,10 +408,8 @@ VMILayoutSupport::getPreferredGroupReduceLayoutFact( "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); } -FailureOr -VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason) const { +FailureOr VMILayoutSupport::getPreferredCastLayoutFact( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -422,7 +421,8 @@ VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, unsigned resultBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); if (sourceBits == 0 || resultBits == 0) - return fail("requires source/result element types with known storage width"); + return fail( + "requires source/result element types with known storage width"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -472,9 +472,9 @@ VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, FailureOr VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + std::string *reason) const { + auto fail = + [&](const Twine 
&message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -530,10 +530,9 @@ LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( FailureOr VMILayoutSupport::getDataLayoutMaterializationSupport( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -550,29 +549,25 @@ VMILayoutSupport::getDataLayoutMaterializationSupport( getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); if (failed(support)) return failure(); - if (failed(checkLayoutMaterializationShape(sourceType, resultType, - sourceLayout, resultLayout, - reason))) + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); return support; } -LogicalResult -VMILayoutSupport::canMaterializeDataLayout(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason) const { - if (failed(getDataLayoutMaterializationSupport(sourceType, resultType, - reason))) +LogicalResult VMILayoutSupport::canMaterializeDataLayout( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (failed( + getDataLayoutMaterializationSupport(sourceType, resultType, reason))) return failure(); return success(); } FailureOr VMILayoutSupport::getMaskLayoutMaterializationSupport( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -589,27 +584,23 @@ VMILayoutSupport::getMaskLayoutMaterializationSupport( 
getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); if (failed(support)) return failure(); - if (failed(checkLayoutMaterializationShape(sourceType, resultType, - sourceLayout, resultLayout, - reason))) + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); return support; } -LogicalResult -VMILayoutSupport::canMaterializeMaskLayout(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason) const { - if (failed(getMaskLayoutMaterializationSupport(sourceType, resultType, - reason))) +LogicalResult VMILayoutSupport::canMaterializeMaskLayout( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + if (failed( + getMaskLayoutMaterializationSupport(sourceType, resultType, reason))) return failure(); return success(); } FailureOr VMILayoutSupport::getMaskGranularityMaterializationSupport( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -633,16 +624,14 @@ VMILayoutSupport::getMaskGranularityMaterializationSupport( } LogicalResult VMILayoutSupport::canMaterializeMaskGranularity( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { if (failed(getMaskGranularityMaterializationSupport(sourceType, resultType, - reason))) + reason))) return failure(); return success(); } -FailureOr -VMILayoutSupport::getGroupSlotLoadSupport( +FailureOr VMILayoutSupport::getGroupSlotLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { @@ -711,9 +700,8 @@ FailureOr VMILayoutSupport::getGroupLoadSupport( !resultType.getElementType().isF32()) return fail("requires deinterleaved 
block8 f32 result layout"); - FailureOr groupSize = - getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), - reason); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); @@ -815,13 +803,11 @@ VMILayoutSupport::getGroupSlotsStoreSupport( "unit-stride slots=8"); } -FailureOr -getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, - Operation *op, VMIVRegType sourceType, - VMIMaskType maskType, VMIVRegType resultType, - int64_t numGroups, bool requiresReassoc, - VMIReductionKind reductionKind, - std::string *reason) { +FailureOr getGroupReduceAddSupportImpl( + const VMITargetCapabilityRegistry &capabilities, Operation *op, + VMIVRegType sourceType, VMIMaskType maskType, VMIVRegType resultType, + int64_t numGroups, bool requiresReassoc, VMIReductionKind reductionKind, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -845,9 +831,9 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, getGroupSizeFromNumGroups(sourceType, numGroups, reason); FailureOr lanesPerPart = getDataLanesPerPart(sourceType.getElementType()); - int64_t vlaneElems = - succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 ? *lanesPerPart / 8 - : -1; + int64_t vlaneElems = succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 + ? 
*lanesPerPart / 8 + : -1; if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && *groupSize != 4 * vlaneElems)) @@ -933,10 +919,10 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || - *sourceArity != *resultArity * 2) - return fail("two-vlane group_reduce_add requires two source/mask parts per " - "result part"); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2) + return fail( + "two-vlane group_reduce_add requires two source/mask parts per " + "result part"); return VMIGroupReduceAddFSupport{ VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd}; } @@ -952,10 +938,10 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || - *sourceArity != *resultArity * 4) - return fail("four-vlane group_reduce_add requires four source/mask parts per " - "result part"); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4) + return fail( + "four-vlane group_reduce_add requires four source/mask parts per " + "result part"); return VMIGroupReduceAddFSupport{ VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree}; } @@ -969,29 +955,52 @@ VMILayoutSupport::getGroupReduceAddFSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, std::string *reason) const { return getGroupReduceAddSupportImpl( - capabilities, op.getOperation(), cast(op.getSource().getType()), + capabilities, op.getOperation(), + 
cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/true, VMIReductionKind::GroupAddF, reason); } +FailureOr +VMILayoutSupport::getGroupReduceMaxFSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceMaxFOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupMaxF, reason); +} + FailureOr VMILayoutSupport::getGroupReduceAddISupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, std::string *reason) const { return getGroupReduceAddSupportImpl( - capabilities, op.getOperation(), cast(op.getSource().getType()), + capabilities, op.getOperation(), + cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, VMIReductionKind::GroupAddI, reason); } -FailureOr -VMILayoutSupport::getGroupBroadcastSupport( +FailureOr VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, std::string *reason) const { + return getGroupBroadcastSupport(capabilities, + cast(op.getSource().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), reason); +} + +FailureOr VMILayoutSupport::getGroupBroadcastSupport( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType sourceType, + VMIVRegType resultType, int64_t numGroups, std::string *reason) const { (void)capabilities; auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -999,15 +1008,12 @@ VMILayoutSupport::getGroupBroadcastSupport( return failure(); }; - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); if 
(sourceType.getElementType() != resultType.getElementType() || sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result shape and element type to match"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != numGroups) @@ -1068,8 +1074,7 @@ VMILayoutSupport::getGroupBroadcastSupport( } FailureOr -VMILayoutSupport::getTruncFSupport(VMITruncFOp op, - std::string *reason) const { +VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1126,8 +1131,7 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, } FailureOr -VMILayoutSupport::getExtFSupport(VMIExtFOp op, - std::string *reason) const { +VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1159,13 +1163,11 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, if (fact->kind == VMICastLayoutKind::Widen2x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtFSupport{ - VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; if (fact->kind == VMICastLayoutKind::Widen4x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtFSupport{ - VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; return fail("unsupported extf source element width, result factor, or " "physical arity"); @@ -1173,7 +1175,7 @@ 
VMILayoutSupport::getExtFSupport(VMIExtFOp op, template static FailureOr getExtISupportImpl(OpT op, - std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1207,33 +1209,28 @@ static FailureOr getExtISupportImpl(OpT op, if (fact->kind == VMICastLayoutKind::Widen2x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtISupport{ - VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + return VMIExtISupport{VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; if (fact->kind == VMICastLayoutKind::Widen4x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtISupport{ - VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; + return VMIExtISupport{VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; return fail("unsupported integer extension source/result element width, " "result factor, or physical arity"); } FailureOr -VMILayoutSupport::getExtSISupport(VMIExtSIOp op, - std::string *reason) const { +VMILayoutSupport::getExtSISupport(VMIExtSIOp op, std::string *reason) const { return getExtISupportImpl(op, reason); } FailureOr -VMILayoutSupport::getExtUISupport(VMIExtUIOp op, - std::string *reason) const { +VMILayoutSupport::getExtUISupport(VMIExtUIOp op, std::string *reason) const { return getExtISupportImpl(op, reason); } FailureOr -VMILayoutSupport::getTruncISupport(VMITruncIOp op, - std::string *reason) const { +VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1302,7 +1299,7 @@ VMILayoutSupport::getTruncISupport(VMITruncIOp op, FailureOr VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, - std::string *reason) const { + std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ 
-1399,14 +1396,12 @@ getHistogramSupportImpl(OpTy op, std::string *reason) { } FailureOr -VMILayoutSupport::getDhistSupport(VMIDhistOp op, - std::string *reason) const { +VMILayoutSupport::getDhistSupport(VMIDhistOp op, std::string *reason) const { return getHistogramSupportImpl(op, reason); } FailureOr -VMILayoutSupport::getChistSupport(VMIChistOp op, - std::string *reason) const { +VMILayoutSupport::getChistSupport(VMIChistOp op, std::string *reason) const { if (reason) *reason = "CHISTv2 cumulative high-range semantics are not classified"; return failure(); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 39ca049a1e..806f6c67fc 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -921,9 +921,9 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, return VMICapabilityResult::missingCapability(reason); } -VMIMemorySafeReadProof computeSafeFullReadProof( - Type sourceType, std::optional constantOffset, - VMIVRegType resultType) { +VMIMemorySafeReadProof +computeSafeFullReadProof(Type sourceType, std::optional constantOffset, + VMIVRegType resultType) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -964,10 +964,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( return proof; } -VMIMemoryAccessPlan buildReadAccessPlan( - const VMITargetCapabilityRegistry &capabilities, Value source, - Type sourceType, VMIVRegType resultType, - std::optional constantOffset, VMIMemoryValidMaskKind validMask) { +VMIMemoryAccessPlan +buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value source, Type sourceType, VMIVRegType resultType, + std::optional constantOffset, + VMIMemoryValidMaskKind validMask) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -1032,9 +1033,10 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { maskedLoadReason + scratchReason + guardedReason); } -FailureOr 
verifyFullOrSafeReadVRegChunks( - Operation *op, VMIVRegType type, Type sourceType, Value offset, - PatternRewriter &rewriter) { +FailureOr verifyFullOrSafeReadVRegChunks(Operation *op, + VMIVRegType type, + Type sourceType, Value offset, + PatternRewriter &rewriter) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); @@ -1055,19 +1057,20 @@ FailureOr verifyFullOrSafeReadVRegChunks( return failure(); } -LogicalResult checkSupportedLoadShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - Value source, Type sourceType, std::optional constantOffset, - std::string *reason) { +LogicalResult +checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value source, Type sourceType, + std::optional constantOffset, + std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( - capabilities, source, sourceType, type, constantOffset, - VMIMemoryValidMaskKind::AllTrue); + VMIMemoryAccessPlan accessPlan = + buildReadAccessPlan(capabilities, source, sourceType, type, + constantOffset, VMIMemoryValidMaskKind::AllTrue); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -2112,8 +2115,7 @@ FailureOr> materializeDynamicContiguousGroupMask( shiftScalar, *allMask) .getResult(); col = rewriter - .create(loc, indexVectorType, lane, groupBase, - *allMask) + .create(loc, indexVectorType, lane, groupBase, *allMask) .getResult(); } @@ -3057,10 +3059,11 @@ struct OneToNVMIEnsureLayoutOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeDataLayout(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( - op, Twine("ensure_layout has no registered materialization support: ") + - supportReason); + op, + 
Twine("ensure_layout has no registered materialization support: ") + + supportReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3091,11 +3094,11 @@ struct OneToNVMIEnsureMaskLayoutOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( - op, - Twine("ensure_mask_layout has no registered materialization support: ") + - supportReason); + op, Twine("ensure_mask_layout has no registered materialization " + "support: ") + + supportReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3130,7 +3133,7 @@ struct OneToNVMIEnsureMaskGranularityOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( op, Twine("ensure_mask_granularity has no registered materialization " "support: ") + @@ -3623,8 +3626,8 @@ struct OneToNVMICreateGroupMaskOpPattern contiguousMaterializations = computeGroupMaskMaterializationForType( op, contiguousType, &contiguousReason); if (failed(contiguousMaterializations)) - return rewriter.notifyMatchFailure( - op, Twine("create_group_mask ") + contiguousReason); + return rewriter.notifyMatchFailure(op, Twine("create_group_mask ") + + contiguousReason); contiguousParts.reserve(contiguousMaterializations->size()); for (const ConstantMaskChunkMaterialization &materialization : @@ -3807,23 +3810,20 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { Value firstOffset = createChunkOffset( op.getLoc(), *offset, group * 4 * *lanesPerPart, rewriter); Value secondOffset = createChunkOffset( - op.getLoc(), *offset, (group * 
4 + 2) * *lanesPerPart, - rewriter); - auto first = - rewriter.create(op.getLoc(), part0Type, part1Type, - *source, firstOffset, - rewriter.getStringAttr(*dist)); - auto second = - rewriter.create(op.getLoc(), part2Type, part3Type, - *source, secondOffset, - rewriter.getStringAttr(*dist)); - - auto even = rewriter.create( - op.getLoc(), part0Type, part2Type, first.getLow(), - second.getLow()); - auto odd = rewriter.create( - op.getLoc(), part1Type, part3Type, first.getHigh(), - second.getHigh()); + op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, rewriter); + auto first = rewriter.create( + op.getLoc(), part0Type, part1Type, *source, firstOffset, + rewriter.getStringAttr(*dist)); + auto second = rewriter.create( + op.getLoc(), part2Type, part3Type, *source, secondOffset, + rewriter.getStringAttr(*dist)); + + auto even = + rewriter.create(op.getLoc(), part0Type, part2Type, + first.getLow(), second.getLow()); + auto odd = + rewriter.create(op.getLoc(), part1Type, part3Type, + first.getHigh(), second.getHigh()); part0.push_back(even.getLow()); part1.push_back(odd.getLow()); part2.push_back(even.getHigh()); @@ -5449,16 +5449,17 @@ struct OneToNVMIReduceAddFOpPattern } }; -template -struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { - OneToNVMIGroupReduceAddOpPattern( - TypeConverter &typeConverter, MLIRContext *context, - const VMITargetCapabilityRegistry &capabilities) +template +struct OneToNVMIGroupReduceOpPattern : OneToNOpConversionPattern { + OneToNVMIGroupReduceOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) : OneToNOpConversionPattern(typeConverter, context), capabilities(capabilities) {} LogicalResult - matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, + matchAndRewrite(OpTy op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); auto 
resultVMIType = cast(op.getResult().getType()); @@ -5472,15 +5473,14 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { getSupport(supports, op, &supportReason); if (failed(support)) return rewriter.notifyMatchFailure( - op, Twine("group_reduce_add has no layout support: ") + - supportReason); + op, Twine(op->getName().getStringRef()) + + " has no layout support: " + supportReason); FailureOr groupSize = getGroupSizeFromNumGroups( sourceVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( - op, - "group_reduce_addf requires num_groups to evenly divide lane count"); + op, "group reduce requires num_groups to evenly divide lane count"); if (support->kind == VMIGroupReduceAddFSupportKind::OneVLaneVcgadd) { if (sourceParts.size() != maskParts.size() || @@ -5506,9 +5506,9 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { results.reserve(resultTypes.size()); for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { results.push_back(rewriter - .create(op.getLoc(), resultType, - sourcePart, - maskParts[sourceIndex]) + .create(op.getLoc(), resultType, + sourcePart, + maskParts[sourceIndex]) .getResult()); } @@ -5554,16 +5554,18 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "s16 block8 group_reduce_addf requires uniform physical " "types"); - Value lo = - rewriter.create(op.getLoc(), resultType, loSource, loMask) - .getResult(); - Value hi = - rewriter.create(op.getLoc(), resultType, hiSource, hiMask) - .getResult(); - results.push_back( - rewriter - .create(op.getLoc(), resultType, lo, hi, *combineMask) - .getResult()); + Value lo = rewriter + .create(op.getLoc(), resultType, + loSource, loMask) + .getResult(); + Value hi = rewriter + .create(op.getLoc(), resultType, + hiSource, hiMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, lo, + hi, *combineMask) + 
.getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5608,21 +5610,24 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "s32 block8 group_reduce_addf requires uniform physical " "types"); - partials.push_back( - rewriter.create(op.getLoc(), resultType, source, mask) - .getResult()); + partials.push_back(rewriter + .create( + op.getLoc(), resultType, source, mask) + .getResult()); } - Value sum01 = rewriter - .create(op.getLoc(), resultType, partials[0], - partials[1], *combineMask) - .getResult(); - Value sum23 = rewriter - .create(op.getLoc(), resultType, partials[2], - partials[3], *combineMask) - .getResult(); + Value sum01 = + rewriter + .create(op.getLoc(), resultType, partials[0], + partials[1], *combineMask) + .getResult(); + Value sum23 = + rewriter + .create(op.getLoc(), resultType, partials[2], + partials[3], *combineMask) + .getResult(); results.push_back(rewriter - .create(op.getLoc(), resultType, sum01, - sum23, *combineMask) + .create(op.getLoc(), resultType, + sum01, sum23, *combineMask) .getResult()); } @@ -5642,10 +5647,9 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { &chunksPerGroup, rewriter))) return failure(); VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); - bool rowLocalSlots1Result = - resultLayout && resultLayout.isGroupSlots() && - resultLayout.getNumGroups() == groupCount && - resultLayout.getSlots() == 1; + bool rowLocalSlots1Result = resultLayout && resultLayout.isGroupSlots() && + resultLayout.getNumGroups() == groupCount && + resultLayout.getSlots() == 1; int64_t expectedResultParts = rowLocalSlots1Result ? 
groupCount : groupCount * chunksPerGroup; if (sourceParts.size() != maskParts.size() || @@ -5682,11 +5686,7 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { op, "failed to create group_reduce_addf masks"); for (int64_t group = 0; group < groupCount; ++group) { - FailureOr accumulator = - createZeroVector(op.getLoc(), resultType, rewriter); - if (failed(accumulator)) - return rewriter.notifyMatchFailure( - op, "failed to create group_reduce_addf accumulator"); + Value accumulator; for (int64_t chunk = 0; chunk < chunksPerGroup; ++chunk) { int64_t index = group * chunksPerGroup + chunk; @@ -5696,19 +5696,23 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { op, "group_reduce_addf requires uniform physical chunk types"); Value reduced = rewriter - .create(op.getLoc(), resultType, sourceParts[index], - maskParts[index]) + .create(op.getLoc(), resultType, + sourceParts[index], maskParts[index]) .getResult(); - *accumulator = rewriter - .create(op.getLoc(), resultType, reduced, - *accumulator, *firstLaneMask) - .getResult(); + if (!accumulator) { + accumulator = reduced; + continue; + } + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } int64_t destChunk = rowLocalSlots1Result ? 
group : group * chunksPerGroup; results[destChunk] = rewriter - .create(op.getLoc(), resultType, *accumulator, + .create(op.getLoc(), resultType, accumulator, results[destChunk], *firstLaneMask) .getResult(); } @@ -5718,18 +5722,24 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { } private: - FailureOr - getSupport(VMILayoutSupport &supports, VMIGroupReduceAddFOp op, - std::string *reason) const { + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddFOp op, + std::string *reason) const { return supports.getGroupReduceAddFSupport(capabilities, op, reason); } - FailureOr - getSupport(VMILayoutSupport &supports, VMIGroupReduceAddIOp op, - std::string *reason) const { + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddIOp op, + std::string *reason) const { return supports.getGroupReduceAddISupport(capabilities, op, reason); } + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceMaxFOp op, + std::string *reason) const { + return supports.getGroupReduceMaxFSupport(capabilities, op, reason); + } + const VMITargetCapabilityRegistry &capabilities; }; @@ -5994,8 +6004,7 @@ struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { FailureOr lanesPerPart = getDataLanesPerPart(sourceType.getElementType()); if (failed(lanesPerPart)) - return rewriter.notifyMatchFailure(op, - "failed to compute source lanes"); + return rewriter.notifyMatchFailure(op, "failed to compute source lanes"); Location loc = op.getLoc(); Value bin0 = createI32Constant(loc, 0, rewriter); @@ -6012,21 +6021,19 @@ struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { Value chunkMask = userMask; int64_t firstLane = static_cast(index) * *lanesPerPart; - int64_t activeLanes = - std::min(*lanesPerPart, - sourceType.getElementCount() - firstLane); + int64_t activeLanes = std::min( + *lanesPerPart, sourceType.getElementCount() - firstLane); if (activeLanes < *lanesPerPart) { - FailureOr validMask = - 
createPrefixMaskForActiveLanes(loc, maskType, activeLanes, - rewriter); + FailureOr validMask = createPrefixMaskForActiveLanes( + loc, maskType, activeLanes, rewriter); FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); if (failed(validMask) || failed(allMask)) return rewriter.notifyMatchFailure( op, "failed to materialize tail-valid b8 mask"); - chunkMask = rewriter - .create(loc, maskType, chunkMask, *validMask, - *allMask) - .getResult(); + chunkMask = + rewriter + .create(loc, maskType, chunkMask, *validMask, *allMask) + .getResult(); } lo = rewriter.create(loc, loType, lo, source, chunkMask, bin0) @@ -6301,9 +6308,9 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite( - OpT op, typename OneToNOpConversionPattern::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(OpT op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.empty()) @@ -6330,8 +6337,7 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { !isa(resultVRegType.getElementType()) || (resultVRegTypes.empty() ? 
pto::getPTOStorageElemBitWidth( resultVRegType.getElementType()) != 32 - : resultVRegType != - resultVRegTypes.front())) + : resultVRegType != resultVRegTypes.front())) return rewriter.notifyMatchFailure( op, "unsupported physical integer extension result type"); resultVRegTypes.push_back(resultVRegType); @@ -6996,15 +7002,15 @@ void populateVMIOneToNConversionPatterns( OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, - OneToNVMIExtIOpPattern, - OneToNVMIExtIOpPattern, OneToNVMITruncIOpPattern, - OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, - OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( - typeConverter, patterns.getContext()); - patterns - .add, - OneToNVMIGroupReduceAddOpPattern>( - typeConverter, patterns.getContext(), capabilities); + OneToNVMIExtIOpPattern, OneToNVMIExtIOpPattern, + OneToNVMITruncIOpPattern, OneToNVMIBitcastOpPattern, + OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, + OneToNVMIShuffleOpPattern>(typeConverter, patterns.getContext()); + patterns.add< + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern>( + typeConverter, patterns.getContext(), capabilities); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -7384,13 +7390,16 @@ checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, } template -LogicalResult checkSupportedGroupReduceAddShape( - const VMITargetCapabilityRegistry &capabilities, OpTy op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedGroupReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, std::string *reason = nullptr) { VMILayoutSupport supports; if constexpr (std::is_same_v) { if (succeeded(supports.getGroupReduceAddFSupport(capabilities, op, reason))) return success(); + } else if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceMaxFSupport(capabilities, op, reason))) + return 
success(); } else { if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) return success(); @@ -7642,7 +7651,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, broadcast.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_broadcast requires full source chunks with " - "#pto.vmi.layout, a dense full result layout, " + "#pto.vmi.layout, a dense full result " + "layout, " "and num_groups deriving a group size that divides or is a " "multiple of physical chunk lanes (" << reason << ")"; @@ -8062,13 +8072,14 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded( - checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " "VLane groups or through pto.vcadd with reassoc, contiguous full " - "source/mask chunks, #pto.vmi.layout result " + "source/mask chunks, #pto.vmi.layout " + "result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" << reason << ")"; @@ -8078,7 +8089,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded( - checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix @@ -8090,6 +8101,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_maxf lowers through pto.vcgmax/vmax only " + "for f16/f32 values, matching source/mask chunks, " 
+ "#pto.vmi.layout result chunks, and " + "num_groups deriving a group size aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto new file mode 100644 index 0000000000..1ae3f90a15 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_maxf_quant( + %src: !pto.ptr, + %scale_out: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %c256 = arith.constant 256 : index + %eps = arith.constant 1.000000e-04 : f32 + %fp8_max = arith.constant 4.480000e+02 : f32 + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %eps2 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.maxf %amax_raw, %eps2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %fp8_max2 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %scale_out[%off], %c8 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[ABS:.*]] = pto.vmi.absf %[[X]] +// ASSIGN: %[[AMAX_RAW:.*]] = 
pto.vmi.group_reduce_maxf %[[ABS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.divf +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALE]] +// ASSIGN: %[[SCALE_VEC:.*]] = pto.vmi.group_broadcast %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q:.*]] = pto.vmi.divf %[[X]], %[[SCALE_VEC]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q_SPLIT:.*]] = pto.vmi.ensure_layout %[[Q]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// LOWER: pto.vcgmax +// LOWER: pto.vmax +// LOWER: pto.vsel +// LOWER: pto.vdiv +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py new file mode 100644 index 0000000000..c6e34633b5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import sys + +import numpy as np + + +def check_f32(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + close = golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol) + if close: + return True + diff = np.nonzero(~np.isclose(golden, output, atol=atol, rtol=rtol))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + return False + + +def main() -> None: + if not check_f32("v2", 1e-5, 1e-5) or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py new file mode 100644 index 0000000000..39f0af76f7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +GROUPS = 2 +GROUP_SIZE = ELEMS // GROUPS +FP8_MAX = np.float32(448.0) +SCALES = np.array([0.25, 0.5], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (GROUP_SIZE + len(Q_VALUES) - 1) // len(Q_VALUES) + q_group = np.tile(Q_VALUES, repeats)[:GROUP_SIZE].astype(np.float32) + q = np.concatenate([q_group, q_group]).astype(np.float32) + src = np.empty(ELEMS, dtype=np.float32) + golden_scale = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) + for group in range(GROUPS): + begin = group * GROUP_SIZE + end = begin + GROUP_SIZE + src[begin:end] = (q_group * SCALES[group]).astype(np.float32) + amax = np.max(np.abs(src[begin:end])).astype(np.float32) + scale = np.maximum(amax, np.float32(1.0e-4)) / FP8_MAX + golden_scale[group * 8] = scale + golden_out8_group = np.tile(F8E4M3FN_BYTES, repeats)[:GROUP_SIZE].astype(np.uint8) + golden_out8 = np.concatenate([golden_out8_group, golden_out8_group]).astype(np.uint8) + + scale_out = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) + out8 = np.full(ELEMS, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale_out.tofile(output_dir / "v2.bin") + out8.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out8.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = 
argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto new file mode 100644 index 0000000000..f2dcc0cd16 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_simdvf_per_token_cast_to_fp8_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %eps = arith.constant 1.000000e-04 : f32 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out8_gm, %ub_out8_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %eps1 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.maxf %amax_raw, %eps1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %fp8_max1 = 
pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c8 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale + {num_groups = 2} : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out8_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp new file mode 100644 index 0000000000..630c7d55af --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_simdvf_per_token_cast_to_fp8_kernel(__gm__ float *src, + __gm__ float *scale, + __gm__ uint8_t *out8); + +void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, + uint8_t *out8, + void *stream) { + vmi_simdvf_per_token_cast_to_fp8_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp new file mode 100644 index 0000000000..cbb7149b86 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, + uint8_t *out8, + void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t srcBytes = kElems * sizeof(float); + size_t scaleBytes = kElems * sizeof(float); + size_t out8Bytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *out8Host = nullptr; + float *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *out8Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", 
srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(srcDevice, scaleDevice, + out8Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", out8Host, out8Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(out8Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(out8Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 500c2fd1e3768d9f421101ef1c14cfe1ea696b0f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 26 Jun 2026 09:29:18 +0000 Subject: [PATCH 31/31] Add a vmi version of per block cast to fp8 --- .../simdvf-per-block-cast-to-fp8/kernel.pto | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto 
b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto
new file mode 100644
index 0000000000..b279df6a90
--- /dev/null
+++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto
@@ -0,0 +1,81 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} {  // NOTE(review): #pto.kernel_kind, !pto.ptr and #pto.pipe appear without parameters in this file - confirm against the upstream patch
+  func.func @vmi_simdvf_per_block_cast_to_fp8_kernel(%src_gm: !pto.ptr,  // Per-block FP8 cast: f16 rows are extended to f32, multiplied by broadcast scale factors, truncated to f8E4M3FN
+                                                     %scale_gm: !pto.ptr,  // scale factors; staged into %ub_scale and read as f32 lanes
+                                                     %out8_gm: !pto.ptr) attributes {pto.kernel} {  // output; staged into and written back from %ub_out8_u8
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %num_per_tokens = arith.constant 4 : index  // inner-loop trip count: rows visited per scale-factor row
+    %num_sf_rows_per_block = arith.constant 4 : index  // outer-loop trip count
+    %num_sf_cols_per_block = arith.constant 8 : index  // scale factors per scale row; same value as the num_groups = 8 attributes in the loop
+    %num_per_channels = arith.constant 32 : index  // lanes covered by one scale factor
+    %block_k = arith.muli %num_sf_cols_per_block, %num_per_channels : index  // 8 * 32 = 256; used as the per-row element stride
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %block_m_i64 = arith.constant 16 : i64  // first nburst operand of the src/out8 transfers; the loops visit 4 * 4 = 16 rows
+    %block_k_bytes_f16 = arith.constant 128 : i64  // size operand of every transfer. NOTE(review): 256 f16 elements would be 512 bytes, not 128 - confirm the units these ops expect
+    %c4096_i64 = arith.constant 4096 : i64  // integer cast to the scale pointer
+    %c8192_i64 = arith.constant 8192 : i64  // integer cast to both out8 pointers
+
+    %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr
+    %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr
+    %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr  // used by the transfer ops
+    %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr  // same integer as %ub_out8_u8; used by the vector store (presumably an f8 view of one buffer - confirm)
+    // Stage src, scale and the existing out8 contents in before computing (presumably GM -> UB, going by the op name).
+    pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %block_k_bytes_f16
+      nburst(%block_m_i64, %block_k_bytes_f16, %block_k_bytes_f16)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %block_k_bytes_f16
+      nburst(%c1_i64, %block_k_bytes_f16, %block_k_bytes_f16)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+    pto.mte_gm_ub %out8_gm, %ub_out8_u8, %c0_i64, %block_k_bytes_f16
+      nburst(%block_m_i64, %block_k_bytes_f16, %block_k_bytes_f16)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64
+
+    pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]  // presumably orders the inbound transfers before the vector ops - confirm
+    pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+
+    pto.vecscope {
+      scf.for %sf_i = %c0 to %num_sf_rows_per_block step %c1 {  // one iteration per scale-factor row
+        %sf_row_offset = arith.muli %sf_i, %num_sf_cols_per_block : index  // sf_i * 8: offset of this row's scale factors
+        // VMI group_broadcast uses num_groups as the number of scale factors.
+        // Each scale factor covers block_k / num_sf_cols_per_block = num_per_channels lanes.
+        %sf_slots = pto.vmi.group_slot_load %ub_scale[%sf_row_offset], %c1 {num_groups = 8}
+          : !pto.ptr -> !pto.vmi.vreg<256xf32>
+        %sf = pto.vmi.group_broadcast %sf_slots {num_groups = 8}
+          : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32>
+        scf.for %token_j = %c0 to %num_per_tokens step %c1 {  // %sf is reused for each of the 4 rows of this scale row
+          %sf_token_row_base = arith.muli %sf_i, %num_per_tokens : index
+          %token_row = arith.addi %sf_token_row_base, %token_j : index  // sf_i * 4 + token_j
+          %row_elem_offset = arith.muli %token_row, %block_k : index  // token_row * 256; shared by the load and the store
+          %x16 = pto.vmi.load %ub_src[%row_elem_offset]
+            : !pto.ptr -> !pto.vmi.vreg<256xf16>
+          %x32 = pto.vmi.extf %x16  // widen f16 -> f32 before scaling
+            : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32>
+          %scaled = pto.vmi.mulf %x32, %sf
+            : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32>
+            -> !pto.vmi.vreg<256xf32>
+          %out8 = pto.vmi.truncf %scaled  // narrow f32 -> f8E4M3FN
+            : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN>
+          pto.vmi.store %out8, %ub_out8_f8[%row_elem_offset]
+            : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr
+        }
+      }
+    }
+
+    pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]  // presumably orders the vector stores before the outbound transfer - confirm
+    pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]
+    pto.mte_ub_gm %ub_out8_u8, %out8_gm, %block_k_bytes_f16  // write the results back to %out8_gm (presumably UB -> GM)
+      nburst(%block_m_i64, %block_k_bytes_f16, %block_k_bytes_f16)
+      : !pto.ptr, !pto.ptr, i64, i64, i64, i64
+    pto.barrier #pto.pipe
+    return
+  }
+}