-
Notifications
You must be signed in to change notification settings - Fork 68
fix error in tload template for cube kernel #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -353,15 +353,17 @@ def _constraint_tload_mat_nd2nz(src, dst) -> bool: | |||||||||||||||||||||||||||||||||||||||
| # COL_MAJOR + ROW_MAJOR corresponds to NZ format | ||||||||||||||||||||||||||||||||||||||||
| if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| # ND2NZ: source is in ND (row-major) format where the inner dimension (g4) | ||||||||||||||||||||||||||||||||||||||||
| # corresponds to the tile column count. Disambiguates from DN format where | ||||||||||||||||||||||||||||||||||||||||
| # g4 corresponds to the tile row count. | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(src, 'rank') and src.rank == 5: | ||||||||||||||||||||||||||||||||||||||||
| dst_valid_cols = dst.valid_shape[1] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None | ||||||||||||||||||||||||||||||||||||||||
| if dst_valid_cols is not None and hasattr(src, 'shape') and src.shape is not None: | ||||||||||||||||||||||||||||||||||||||||
| src_inner = src.shape[4] if len(src.shape) >= 5 else None | ||||||||||||||||||||||||||||||||||||||||
| if src_inner is not None: | ||||||||||||||||||||||||||||||||||||||||
| if not _known_eq(dst_valid_cols, src_inner): | ||||||||||||||||||||||||||||||||||||||||
| # Disambiguate ND vs DN via the source view layout attribute | ||||||||||||||||||||||||||||||||||||||||
| # (e.g. layout = #pto.layout<nd>). When the layout is not annotated the | ||||||||||||||||||||||||||||||||||||||||
| # constraint conservatively returns True to avoid false rejections. | ||||||||||||||||||||||||||||||||||||||||
| # src.config is always a ViewConfig for view operands — never None. | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(src, 'config') and src.config is not None: | ||||||||||||||||||||||||||||||||||||||||
| view_config = src.config | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(view_config, 'layout'): | ||||||||||||||||||||||||||||||||||||||||
| layout = view_config.layout | ||||||||||||||||||||||||||||||||||||||||
| if layout is not None: | ||||||||||||||||||||||||||||||||||||||||
| layout_val = layout.value | ||||||||||||||||||||||||||||||||||||||||
| if layout_val.upper() not in {'ND', 'NZ'}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -381,15 +383,17 @@ def _constraint_tload_mat_dn2nz(src, dst) -> bool: | |||||||||||||||||||||||||||||||||||||||
| s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout | ||||||||||||||||||||||||||||||||||||||||
| if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| # DN2NZ: source is in DN (col-major) format where the inner dimension (g4) | ||||||||||||||||||||||||||||||||||||||||
| # corresponds to the tile row count. Disambiguates from ND format where | ||||||||||||||||||||||||||||||||||||||||
| # g4 corresponds to the tile column count. | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(src, 'rank') and src.rank == 5: | ||||||||||||||||||||||||||||||||||||||||
| dst_valid_rows = dst.valid_shape[0] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None | ||||||||||||||||||||||||||||||||||||||||
| if dst_valid_rows is not None and hasattr(src, 'shape') and src.shape is not None: | ||||||||||||||||||||||||||||||||||||||||
| src_inner = src.shape[4] if len(src.shape) >= 5 else None | ||||||||||||||||||||||||||||||||||||||||
| if src_inner is not None: | ||||||||||||||||||||||||||||||||||||||||
| if not _known_eq(dst_valid_rows, src_inner): | ||||||||||||||||||||||||||||||||||||||||
| # Disambiguate DN vs ND via the source view layout attribute | ||||||||||||||||||||||||||||||||||||||||
| # (e.g. layout = #pto.layout<dn>). When the layout is not annotated the | ||||||||||||||||||||||||||||||||||||||||
| # constraint conservatively returns True to avoid false rejections. | ||||||||||||||||||||||||||||||||||||||||
| # src.config is always a ViewConfig for view operands — never None. | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(src, 'config') and src.config is not None: | ||||||||||||||||||||||||||||||||||||||||
| view_config = src.config | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(view_config, 'layout'): | ||||||||||||||||||||||||||||||||||||||||
| layout = view_config.layout | ||||||||||||||||||||||||||||||||||||||||
| if layout is not None: | ||||||||||||||||||||||||||||||||||||||||
| layout_val = layout.value | ||||||||||||||||||||||||||||||||||||||||
| if layout_val.upper() not in {'DN', 'NZ'}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+390
to
397
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -501,3 +505,84 @@ def template_tload_gm_to_mat_dn2nz(src: pto.PartitionTensorView, dst: pto.Tile): | |||||||||||||||||||||||||||||||||||||||
| ctrl=(0, False) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _constraint_tload_mat_dn2zn(src, dst) -> bool: | ||||||||||||||||||||||||||||||||||||||||
| """TLOAD.MAT DN2ZN fractal load constraint (transposed DN→ZN) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| DN2ZN loads col-major (DN) data from GM into L1 MAT in ZN format | ||||||||||||||||||||||||||||||||||||||||
| (blayout=row_major, slayout=col_major). Internally reuses the nd2nz | ||||||||||||||||||||||||||||||||||||||||
| hardware path with transposed parameters — see TLoadCubeDN2ZN. | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
| if not _constraint_tload_mat_base(src, dst): | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| config = dst.config | ||||||||||||||||||||||||||||||||||||||||
| if config is None: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| b_layout = config.b_layout | ||||||||||||||||||||||||||||||||||||||||
| s_layout = config.s_layout | ||||||||||||||||||||||||||||||||||||||||
| if b_layout is None or s_layout is None: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| b_layout_value = b_layout.value if hasattr(b_layout, "value") else b_layout | ||||||||||||||||||||||||||||||||||||||||
| s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout | ||||||||||||||||||||||||||||||||||||||||
| # ZN format: blayout=row_major, slayout=col_major | ||||||||||||||||||||||||||||||||||||||||
| if b_layout_value not in {"row_major", "ROW_MAJOR"} or s_layout_value not in {"col_major", "COL_MAJOR"}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
| # Use the view layout to disambiguate: DN2ZN only matches DN or NZ sources. | ||||||||||||||||||||||||||||||||||||||||
| # src.config is always a ViewConfig for view operands — never None. | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(src, 'config') and src.config is not None: | ||||||||||||||||||||||||||||||||||||||||
| view_config = src.config | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(view_config, 'layout'): | ||||||||||||||||||||||||||||||||||||||||
| layout = view_config.layout | ||||||||||||||||||||||||||||||||||||||||
| if layout is not None: | ||||||||||||||||||||||||||||||||||||||||
| layout_val = layout.value | ||||||||||||||||||||||||||||||||||||||||
| if layout_val.upper() not in {'DN', 'NZ'}: | ||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+533
to
+540
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @pto.ckernel( | ||||||||||||||||||||||||||||||||||||||||
| target="a5", | ||||||||||||||||||||||||||||||||||||||||
| op="pto.tload", | ||||||||||||||||||||||||||||||||||||||||
| priority=1, | ||||||||||||||||||||||||||||||||||||||||
| dtypes=[ | ||||||||||||||||||||||||||||||||||||||||
| (pto.f16, pto.f16), | ||||||||||||||||||||||||||||||||||||||||
| (pto.bf16, pto.bf16), | ||||||||||||||||||||||||||||||||||||||||
| (pto.f32, pto.f32), | ||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||
| constraints=[_constraint_tload_mat_dn2zn], | ||||||||||||||||||||||||||||||||||||||||
| name="tload_gm_to_mat_dn2zn", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| def template_tload_gm_to_mat_dn2zn(src: pto.PartitionTensorView, dst: pto.Tile): | ||||||||||||||||||||||||||||||||||||||||
| """GM -> MAT DN2ZN fractal load template (transposed DN → ZN) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Load Col-Major (DN) format data from GM into L1 MAT Buffer in ZN format | ||||||||||||||||||||||||||||||||||||||||
| (transposed NZ: blayout=row_major, slayout=col_major). | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Reuses the ND2NZ hardware mode with transposed parameter mapping, matching | ||||||||||||||||||||||||||||||||||||||||
| the C++ ISA TLoadCubeDN2ZN (TLoad.hpp:472) which calls | ||||||||||||||||||||||||||||||||||||||||
| TLoadCubeInstr<Layout::ND> with nValue=gShape4, dValue=validRow. | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
| m, k = dst.valid_shape | ||||||||||||||||||||||||||||||||||||||||
| g0, g1, g2, g3, g4 = src.shape | ||||||||||||||||||||||||||||||||||||||||
| s0, s1, s2, s3, s4 = src.strides | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| gm_ptr = src.as_ptr() | ||||||||||||||||||||||||||||||||||||||||
| mat_ptr = dst.as_ptr() | ||||||||||||||||||||||||||||||||||||||||
| elem_bytes = pto.bytewidth(dst.element_type) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow | ||||||||||||||||||||||||||||||||||||||||
| n_value = g4 | ||||||||||||||||||||||||||||||||||||||||
| d_value = m | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # TLoadCubeDN2ZN: loop1SrcStride = GetByteSize(gStride4) | ||||||||||||||||||||||||||||||||||||||||
| src_inner_stride = s4 * elem_bytes | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+572
to
+579
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the TileLang DSL,
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| pto.mte_gm_l1_frac( | ||||||||||||||||||||||||||||||||||||||||
| gm_ptr, mat_ptr, pto.FractalMode.ND2NZ, | ||||||||||||||||||||||||||||||||||||||||
| shape=(n_value, d_value), | ||||||||||||||||||||||||||||||||||||||||
| src_layout=(src_inner_stride,), | ||||||||||||||||||||||||||||||||||||||||
| dst_group=(1, 1, k, 0), | ||||||||||||||||||||||||||||||||||||||||
| ctrl=(0, False) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the
layoutattribute is an MLIR custom attribute (e.g.,#pto.layout<nd>), its string representation will contain the full assembly format. Checkingstr(layout_val).upper() not in {'ND', 'NZ'}directly will fail because"#PTO.LAYOUT<ND>"is not in the set, causing the constraint to always returnFalseand reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.