fix error in tload template for cube kernel #879
Fixes #877
Code Review
This pull request updates the tload_template.py file by refactoring the nd2nz and dn2nz constraints to use the source view layout attribute for disambiguation, and introduces a new dn2zn constraint and template (template_tload_gm_to_mat_dn2zn) to support transposed DN to ZN fractal loads. The code review feedback highlights two critical issues: first, parsing MLIR custom attributes (e.g., #pto.layout<nd>) directly as strings will fail the layout checks, so the layout name needs to be extracted from the angle brackets; second, the stride passed to pto.mte_gm_l1_frac in the new template should be in elements rather than bytes, as multiplying by elem_bytes leads to an incorrect, duplicated multiplication.
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            if str(layout_val).upper() not in {'ND', 'NZ'}:
                return False
If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<nd>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'ND', 'NZ'} directly will fail because "#PTO.LAYOUT<ND>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.
Suggested change:
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            layout_str = str(layout_val).upper()
            if '<' in layout_str:
                layout_str = layout_str.split('<')[-1].strip('>')
            if layout_str not in {'ND', 'NZ'}:
                return False
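The extraction step in this suggestion can also be factored into a small helper so that all three constraints share it. The following is a minimal sketch, not code from the PR: `layout_name` and `_LAYOUT_ATTR` are hypothetical names, and it assumes the attribute prints in the `#pto.layout<nd>` form quoted in the review comment.

```python
import re

# Hypothetical helper, not part of tload_template.py.
# Matches the name inside angle brackets, e.g. "nd" in "#pto.layout<nd>".
_LAYOUT_ATTR = re.compile(r"<\s*([A-Za-z0-9_]+)\s*>")


def layout_name(layout):
    """Return the upper-cased layout name, or None when layout is None."""
    if layout is None:
        return None
    # Enum-like objects expose .value; anything else is stringified as-is.
    raw = layout.value if hasattr(layout, "value") else layout
    text = str(raw).strip()
    match = _LAYOUT_ATTR.search(text)
    if match:
        # Assembly-format attribute: keep only the bracketed name.
        text = match.group(1)
    return text.upper()
```

With such a helper, each constraint would reduce to a membership test such as `layout_name(layout) in {'ND', 'NZ'}`, and the bracket handling would live in one place instead of three.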
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            if str(layout_val).upper() not in {'DN', 'NZ'}:
                return False
If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<dn>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'DN', 'NZ'} directly will fail because "#PTO.LAYOUT<DN>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.
Suggested change:
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            layout_str = str(layout_val).upper()
            if '<' in layout_str:
                layout_str = layout_str.split('<')[-1].strip('>')
            if layout_str not in {'DN', 'NZ'}:
                return False
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            if str(layout_val).upper() not in {'DN', 'NZ'}:
                return False
If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<dn>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'DN', 'NZ'} directly will fail because "#PTO.LAYOUT<DN>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.
Suggested change:
if hasattr(src, 'config') and src.config is not None:
    view_config = src.config
    if hasattr(view_config, 'layout'):
        layout = view_config.layout
        if layout is not None:
            layout_val = layout.value if hasattr(layout, 'value') else str(layout)
            layout_str = str(layout_val).upper()
            if '<' in layout_str:
                layout_str = layout_str.split('<')[-1].strip('>')
            if layout_str not in {'DN', 'NZ'}:
                return False
elem_bytes = pto.bytewidth(dst.element_type)

# TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow
n_value = g4
d_value = m

# TLoadCubeDN2ZN: loop1SrcStride = GetByteSize(gStride4)
src_inner_stride = s4 * elem_bytes
In the TileLang DSL, pto.mte_gm_l1_frac expects the src_layout stride to be specified in terms of elements, not bytes. The DSL compiler internally handles the multiplication by elem_bytes based on the data type. Multiplying s4 by elem_bytes here results in a duplicate multiplication, causing the stride to be incorrect (too large) and leading to memory corruption or incorrect data loading. We should pass s4 directly.
Suggested change:
# TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow
n_value = g4
d_value = m
# TLoadCubeDN2ZN: loop1SrcStride = gStride4
src_inner_stride = s4
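The arithmetic behind this comment can be checked with a toy model. The sketch below takes the reviewer's claim at face value (that the callee scales an element stride by the element size); `effective_byte_stride` is a hypothetical stand-in for that internal step, not the real `pto.mte_gm_l1_frac`, and the numbers are illustrative only.

```python
def effective_byte_stride(stride_arg, elem_bytes):
    """Model of the claimed behavior: the callee multiplies by elem_bytes."""
    return stride_arg * elem_bytes


s4 = 128        # hypothetical gStride4, in elements
elem_bytes = 2  # e.g. a 16-bit element type

# Passing the element stride yields the intended byte stride.
correct = effective_byte_stride(s4, elem_bytes)

# Passing a pre-multiplied byte stride gets scaled a second time.
doubled = effective_byte_stride(s4 * elem_bytes, elem_bytes)

print(correct, doubled)
```

Under that assumption the pre-multiplied version is too large by a factor of `elem_bytes`, which is the duplicated multiplication the comment describes. If the callee actually expects bytes, the original code would be correct, so the claim is worth confirming against the DSL before applying the suggestion.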
1fedd23 to c2fb532
Codex Review
This comment is updated automatically by the review bot.
Summary
Layout-based TLOAD.MAT matching in PR #879 regresses existing DN2NZ selection and also breaks MAT-NZ template resolution when view layout inference is unavailable.
Findings
Both MAT-NZ constraints now intentionally return |
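Changes to lib/TileOps/tload_template.py, in two parts:
1. Disambiguation logic now prefers layout
Disambiguation was previously done through the rank-5 shape (src.shape[4] == dst.valid_shape[1]); it now uses the source view's layout attribute (src.config.layout). Each constraint function checks whether layout matches what it expects (ND2NZ accepts only ND/NZ sources, DN2NZ accepts only DN/NZ sources, DN2ZN accepts only DN/NZ sources). When layout is unknown, the constraint conservatively returns True (no disambiguation, preserving the old behavior).
2. New tload_gm_to_mat_dn2zn template
Constraints:
src.layout is DN or NZ
dst.blayout=ROW_MAJOR, dst.slayout=COL_MAJOR
The template implementation follows TLoadCubeDN2ZN at TLoad.hpp:472: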