Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 103 additions & 18 deletions lib/TileOps/tload_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,17 @@ def _constraint_tload_mat_nd2nz(src, dst) -> bool:
# COL_MAJOR + ROW_MAJOR corresponds to NZ format
if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}:
return False
# ND2NZ: source is in ND (row-major) format where the inner dimension (g4)
# corresponds to the tile column count. Disambiguates from DN format where
# g4 corresponds to the tile row count.
if hasattr(src, 'rank') and src.rank == 5:
dst_valid_cols = dst.valid_shape[1] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None
if dst_valid_cols is not None and hasattr(src, 'shape') and src.shape is not None:
src_inner = src.shape[4] if len(src.shape) >= 5 else None
if src_inner is not None:
if not _known_eq(dst_valid_cols, src_inner):
# Disambiguate ND vs DN via the source view layout attribute
# (e.g. layout = #pto.layout<nd>). When the layout is not annotated the
# constraint conservatively returns True to avoid false rejections.
# src.config is always a ViewConfig for view operands — never None.
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value
if layout_val.upper() not in {'ND', 'NZ'}:
return False
Comment on lines +360 to 367

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<nd>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'ND', 'NZ'} directly will fail because "#PTO.LAYOUT<ND>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.

Suggested change
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
if str(layout_val).upper() not in {'ND', 'NZ'}:
return False
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
layout_str = str(layout_val).upper()
if '<' in layout_str:
layout_str = layout_str.split('<')[-1].strip('>')
if layout_str not in {'ND', 'NZ'}:
return False

return True

Expand All @@ -381,15 +383,17 @@ def _constraint_tload_mat_dn2nz(src, dst) -> bool:
s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout
if b_layout_value not in {"col_major", "COL_MAJOR"} or s_layout_value not in {"row_major", "ROW_MAJOR"}:
return False
# DN2NZ: source is in DN (col-major) format where the inner dimension (g4)
# corresponds to the tile row count. Disambiguates from ND format where
# g4 corresponds to the tile column count.
if hasattr(src, 'rank') and src.rank == 5:
dst_valid_rows = dst.valid_shape[0] if hasattr(dst, 'valid_shape') and dst.valid_shape is not None else None
if dst_valid_rows is not None and hasattr(src, 'shape') and src.shape is not None:
src_inner = src.shape[4] if len(src.shape) >= 5 else None
if src_inner is not None:
if not _known_eq(dst_valid_rows, src_inner):
# Disambiguate DN vs ND via the source view layout attribute
# (e.g. layout = #pto.layout<dn>). When the layout is not annotated the
# constraint conservatively returns True to avoid false rejections.
# src.config is always a ViewConfig for view operands — never None.
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value
if layout_val.upper() not in {'DN', 'NZ'}:
return False
Comment on lines +390 to 397

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<dn>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'DN', 'NZ'} directly will fail because "#PTO.LAYOUT<DN>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.

Suggested change
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
if str(layout_val).upper() not in {'DN', 'NZ'}:
return False
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
layout_str = str(layout_val).upper()
if '<' in layout_str:
layout_str = layout_str.split('<')[-1].strip('>')
if layout_str not in {'DN', 'NZ'}:
return False

return True

Expand Down Expand Up @@ -501,3 +505,84 @@ def template_tload_gm_to_mat_dn2nz(src: pto.PartitionTensorView, dst: pto.Tile):
ctrl=(0, False)
)
return


def _constraint_tload_mat_dn2zn(src, dst) -> bool:
"""TLOAD.MAT DN2ZN fractal load constraint (transposed DN→ZN)

DN2ZN loads col-major (DN) data from GM into L1 MAT in ZN format
(blayout=row_major, slayout=col_major). Internally reuses the nd2nz
hardware path with transposed parameters — see TLoadCubeDN2ZN.
"""
if not _constraint_tload_mat_base(src, dst):
return False
config = dst.config
if config is None:
return False
b_layout = config.b_layout
s_layout = config.s_layout
if b_layout is None or s_layout is None:
return False
b_layout_value = b_layout.value if hasattr(b_layout, "value") else b_layout
s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout
# ZN format: blayout=row_major, slayout=col_major
if b_layout_value not in {"row_major", "ROW_MAJOR"} or s_layout_value not in {"col_major", "COL_MAJOR"}:
return False
# Use the view layout to disambiguate: DN2ZN only matches DN or NZ sources.
# src.config is always a ViewConfig for view operands — never None.
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value
if layout_val.upper() not in {'DN', 'NZ'}:
return False
Comment on lines +533 to +540

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If the layout attribute is an MLIR custom attribute (e.g., #pto.layout<dn>), its string representation will contain the full assembly format. Checking str(layout_val).upper() not in {'DN', 'NZ'} directly will fail because "#PTO.LAYOUT<DN>" is not in the set, causing the constraint to always return False and reject valid layouts. We should extract the layout name from the angle brackets if present to ensure robustness.

Suggested change
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
if str(layout_val).upper() not in {'DN', 'NZ'}:
return False
if hasattr(src, 'config') and src.config is not None:
view_config = src.config
if hasattr(view_config, 'layout'):
layout = view_config.layout
if layout is not None:
layout_val = layout.value if hasattr(layout, 'value') else str(layout)
layout_str = str(layout_val).upper()
if '<' in layout_str:
layout_str = layout_str.split('<')[-1].strip('>')
if layout_str not in {'DN', 'NZ'}:
return False

return True


@pto.ckernel(
target="a5",
op="pto.tload",
priority=1,
dtypes=[
(pto.f16, pto.f16),
(pto.bf16, pto.bf16),
(pto.f32, pto.f32),
],
constraints=[_constraint_tload_mat_dn2zn],
name="tload_gm_to_mat_dn2zn",
)
def template_tload_gm_to_mat_dn2zn(src: pto.PartitionTensorView, dst: pto.Tile):
"""GM -> MAT DN2ZN fractal load template (transposed DN → ZN)

Load Col-Major (DN) format data from GM into L1 MAT Buffer in ZN format
(transposed NZ: blayout=row_major, slayout=col_major).

Reuses the ND2NZ hardware mode with transposed parameter mapping, matching
the C++ ISA TLoadCubeDN2ZN (TLoad.hpp:472) which calls
TLoadCubeInstr<Layout::ND> with nValue=gShape4, dValue=validRow.
"""
m, k = dst.valid_shape
g0, g1, g2, g3, g4 = src.shape
s0, s1, s2, s3, s4 = src.strides

gm_ptr = src.as_ptr()
mat_ptr = dst.as_ptr()
elem_bytes = pto.bytewidth(dst.element_type)

# TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow
n_value = g4
d_value = m

# TLoadCubeDN2ZN: loop1SrcStride = GetByteSize(gStride4)
src_inner_stride = s4 * elem_bytes
Comment on lines +572 to +579

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In the TileLang DSL, pto.mte_gm_l1_frac expects the src_layout stride to be specified in terms of elements, not bytes. The DSL compiler internally handles the multiplication by elem_bytes based on the data type. Multiplying s4 by elem_bytes here results in a duplicate multiplication, causing the stride to be incorrect (too large) and leading to memory corruption or incorrect data loading. We should pass s4 directly.

Suggested change
elem_bytes = pto.bytewidth(dst.element_type)
# TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow
n_value = g4
d_value = m
# TLoadCubeDN2ZN: loop1SrcStride = GetByteSize(gStride4)
src_inner_stride = s4 * elem_bytes
# TLoadCubeDN2ZN: nValue = gShape4, dValue = validRow
n_value = g4
d_value = m
# TLoadCubeDN2ZN: loop1SrcStride = gStride4
src_inner_stride = s4


pto.mte_gm_l1_frac(
gm_ptr, mat_ptr, pto.FractalMode.ND2NZ,
shape=(n_value, d_value),
src_layout=(src_inner_stride,),
dst_group=(1, 1, k, 0),
ctrl=(0, False)
)
return
Loading