From c2fb532d78edd593d58ccbc5dbafd8ba8d620cbf Mon Sep 17 00:00:00 2001 From: likai00 Date: Mon, 29 Jun 2026 15:56:15 +0800 Subject: [PATCH] fix error in tload template for cube kernel --- lib/TileOps/tload_template.py | 121 +++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 18 deletions(-) diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index 9e40bedb3..bcfe21a84 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -353,15 +353,17 @@ def _constraint_tload_mat_nd2nz(src, dst) -> bool: # COL_MAJOR + ROW_MAJOR corresponds to NZ format if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}: return False - # ND2NZ: source is in ND (row-major) format where the inner dimension (g4) - # corresponds to the tile column count. Disambiguates from DN format where - # g4 corresponds to the tile row count. - if hasattr(src, 'rank') and src.rank == 5: - dst_valid_cols = dst.valid_shape[1] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None - if dst_valid_cols is not None and hasattr(src, 'shape') and src.shape is not None: - src_inner = src.shape[4] if len(src.shape) >= 5 else None - if src_inner is not None: - if not _known_eq(dst_valid_cols, src_inner): + # Disambiguate ND vs DN via the source view layout attribute + # (e.g. layout = #pto.layout). When the layout is not annotated the + # constraint conservatively returns True to avoid false rejections. + # src.config is always a ViewConfig for view operands — never None. + if hasattr(src, 'config') and src.config is not None: + view_config = src.config + if hasattr(view_config, 'layout'): + layout = view_config.layout + if layout is not None: + layout_val = layout.value + if layout_val.upper() not in {'ND', 'NZ'}: return False return True @@ -381,15 +383,17 @@ def _constraint_tload_mat_dn2nz(src, dst) -> bool: s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}: return False - # DN2NZ: source is in DN (col-major) format where the inner dimension (g4) - # corresponds to the tile row count. Disambiguates from ND format where - # g4 corresponds to the tile column count. - if hasattr(src, 'rank') and src.rank == 5: - dst_valid_rows = dst.valid_shape[0] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None - if dst_valid_rows is not None and hasattr(src, 'shape') and src.shape is not None: - src_inner = src.shape[4] if len(src.shape) >= 5 else None - if src_inner is not None: - if not _known_eq(dst_valid_rows, src_inner): + # Disambiguate DN vs ND via the source view layout attribute + # (e.g. layout = #pto.layout). When the layout is not annotated the + # constraint conservatively returns True to avoid false rejections. + # src.config is always a ViewConfig for view operands — never None. + if hasattr(src, 'config') and src.config is not None: + view_config = src.config + if hasattr(view_config, 'layout'): + layout = view_config.layout + if layout is not None: + layout_val = layout.value + if layout_val.upper() not in {'DN', 'NZ'}: return False return True @@ -501,3 +505,84 @@ def template_tload_gm_to_mat_dn2nz(src: pto.PartitionTensorView, dst: pto.Tile): ctrl=(0, False) ) return + + +def _constraint_tload_mat_dn2zn(src, dst) -> bool: + """TLOAD.MAT DN2ZN fractal load constraint (transposed DN→ZN) + + DN2ZN loads col-major (DN) data from GM into L1 MAT in ZN format + (blayout=row_major, slayout=col_major). Internally reuses the nd2nz + hardware path with transposed parameters — see TLoadCubeDN2ZN. + """ + if not _constraint_tload_mat_base(src, dst): + return False + config = dst.config + if config is None: + return False + b_layout = config.b_layout + s_layout = config.s_layout + if b_layout is None or s_layout is None: + return False + b_layout_value = b_layout.value if hasattr(b_layout, "value") else b_layout + s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout + # ZN format: blayout=row_major, slayout=col_major + if b_layout_value not in {"row_major", "ROW_MAJOR"} or s_layout_value not in {"col_major", "COL_MAJOR"}: + return False + # Use the view layout to disambiguate: DN2ZN only matches DN or NZ sources. + # src.config is always a ViewConfig for view operands — never None. + if hasattr(src, 'config') and src.config is not None: + view_config = src.config + if hasattr(view_config, 'layout'): + layout = view_config.layout + if layout is not None: + layout_val = layout.value + if layout_val.upper() not in {'DN', 'NZ'}: + return False + return True + + +@pto.ckernel( + target="a5", + op="pto.tload", + priority=1, + dtypes=[ + (pto.f16, pto.f16), + (pto.bf16, pto.bf16), + (pto.f32, pto.f32), + ], + constraints=[_constraint_tload_mat_dn2zn], + name="tload_gm_to_mat_dn2zn", +) +def template_tload_gm_to_mat_dn2zn(src: pto.PartitionTensorView, dst: pto.Tile): + """GM -> MAT DN2ZN fractal load template (transposed DN → ZN) + + Load Col-Major (DN) format data from GM into L1 MAT Buffer in ZN format + (transposed NZ: blayout=row_major, slayout=col_major). + + Reuses the ND2NZ hardware mode with transposed parameter mapping, matching + the C++ ISA TLoadCubeDN2ZN (TLoad.hpp:472) which calls + TLoadCubeInstr with nValue=gShape4, dValue=validRow. + """ + m, k = dst.valid_shape + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + gm_ptr = src.as_ptr() + mat_ptr = dst.as_ptr() + elem_bytes = pto.bytewidth(dst.element_type) + + # TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow + n_value = g4 + d_value = m + + # TLoadCubeDN2ZN: loop1SrcStride = GetByteSize(gStride4) + src_inner_stride = s4 * elem_bytes + + pto.mte_gm_l1_frac( + gm_ptr, mat_ptr, pto.FractalMode.ND2NZ, + shape=(n_value, d_value), + src_layout=(src_inner_stride,), + dst_group=(1, 1, k, 0), + ctrl=(0, False) + ) + return