From a4078975b3f8e75ff2ad25507bc4d6e719d16d29 Mon Sep 17 00:00:00 2001 From: FBSourceBlackLinter Bot Date: Tue, 28 Apr 2026 05:52:49 -0700 Subject: [PATCH] Daily `arc lint --take BLACK` Reviewed By: jeremychan Differential Revision: D102768587 --- docs/generate_figures.py | 188 +++++--- examples/composed.py | 15 +- examples/layouts.py | 72 ++- examples/tensor.py | 42 +- examples/viz.py | 205 +++++++-- src/tensor_layouts/__init__.py | 6 +- src/tensor_layouts/analysis.py | 476 ++++++++++--------- src/tensor_layouts/atoms.py | 2 + src/tensor_layouts/atoms_amd.py | 635 ++++++++++++++++++-------- src/tensor_layouts/atoms_amx.py | 38 +- src/tensor_layouts/atoms_nv.py | 706 ++++++++++++++++++++--------- src/tensor_layouts/atoms_xe.py | 31 +- src/tensor_layouts/layout_utils.py | 20 +- src/tensor_layouts/layouts.py | 188 ++++++-- src/tensor_layouts/tensor.py | 4 +- src/tensor_layouts/viz.py | 182 +++++--- tests/analysis.py | 573 ++++++++++++----------- tests/composed.py | 25 +- tests/conftest.py | 9 +- tests/external.py | 187 ++++---- tests/layouts.py | 23 +- tests/oracle_amd.py | 264 +++++++---- tests/oracle_cute_cpp.py | 91 +++- tests/oracle_nv.py | 235 ++++++---- tests/oracle_rdna.py | 27 +- tests/oracle_xe.py | 24 +- tests/paper_examples.py | 67 ++- tests/tensor.py | 23 +- tests/viz.py | 224 ++++----- 29 files changed, 2915 insertions(+), 1667 deletions(-) diff --git a/docs/generate_figures.py b/docs/generate_figures.py index 693a5a7..715fde9 100644 --- a/docs/generate_figures.py +++ b/docs/generate_figures.py @@ -35,18 +35,17 @@ import matplotlib.patches as patches import matplotlib.pyplot as plt - -from tensor_layouts import Layout, Swizzle, compose +from tensor_layouts import compose, Layout, Swizzle from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN from tensor_layouts.layout_utils import tile_mma_grid from tensor_layouts.viz import ( draw_composite, draw_layout, + draw_mma_layout, draw_slice, draw_swizzle, draw_tiled_grid, draw_tv_layout, - draw_mma_layout, ) IMAGES = Path(__file__).resolve().parent / "images" @@ -65,8 +64,10 @@ def _generate_intile_oftile(path: Path) -> None: M, K = 4, 8 tm, tk = 2, 4 tile_colors = { - (0, 0): "#DBEAFE", (0, 1): "#FEE2E2", - (1, 0): "#D1FAE5", (1, 1): "#EDE9FE", + (0, 0): "#DBEAFE", + (0, 1): "#FEE2E2", + (1, 0): "#D1FAE5", + (1, 1): "#EDE9FE", } fig, axes = plt.subplots(1, 3, figsize=(18, 5.2)) @@ -79,40 +80,71 @@ def _draw_grid(ax, cell_text_fn, title, subtitle): ok = c // tk y = M - 1 - r rect = patches.Rectangle( - (c, y), 1, 1, + (c, y), + 1, + 1, facecolor=tile_colors[(om, ok)], - edgecolor="#D1D5DB", linewidth=0.5, + edgecolor="#D1D5DB", + linewidth=0.5, ) ax.add_patch(rect) ax.text( - c + 0.5, y + 0.5, cell_text_fn(r, c), - ha="center", va="center", fontsize=9, - color="#374151", family="monospace", + c + 0.5, + y + 0.5, + cell_text_fn(r, c), + ha="center", + va="center", + fontsize=9, + color="#374151", + family="monospace", ) # thick tile borders for i in range(0, M + 1, tm): - ax.plot([0, K], [i, i], color="#1F2937", lw=2.5, - solid_capstyle="butt") + ax.plot([0, K], [i, i], color="#1F2937", lw=2.5, solid_capstyle="butt") for j in range(0, K + 1, tk): - ax.plot([j, j], [0, M], color="#1F2937", lw=2.5, - solid_capstyle="butt") + ax.plot([j, j], [0, M], color="#1F2937", lw=2.5, solid_capstyle="butt") # oftile margin labels for om in range(M // tm): y_c = M - om * tm - tm / 2 - ax.text(-0.3, y_c, f"oftile\u2080={om}", ha="right", va="center", - fontsize=8, fontweight="bold", color="#7C3AED", - family="monospace") + ax.text( + -0.3, + y_c, + f"oftile\u2080={om}", + ha="right", + va="center", + fontsize=8, + fontweight="bold", + color="#7C3AED", + family="monospace", + ) for ok in range(K // tk): x_c = ok * tk + tk / 2 - ax.text(x_c, M + 0.15, f"oftile\u2081={ok}", ha="center", - va="bottom", fontsize=8, fontweight="bold", color="#7C3AED", - family="monospace") + ax.text( + x_c, + M + 0.15, + f"oftile\u2081={ok}", + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + color="#7C3AED", + family="monospace", + ) ax.set_xlim(-2.5, K + 0.5) ax.set_ylim(-1.8, M + 0.8) ax.axis("off") ax.set_title(title, fontsize=10.5, fontweight="bold", pad=10) - ax.text(K / 2, -0.3, subtitle, ha="center", va="top", fontsize=8, - color="#6B7280", family="monospace", linespacing=1.6) + ax.text( + K / 2, + -0.3, + subtitle, + ha="center", + va="top", + fontsize=8, + color="#6B7280", + family="monospace", + linespacing=1.6, + ) # ── Panel 1: linear index ──────────────────────────────────── _draw_grid( @@ -129,8 +161,7 @@ def _draw_grid(ax, cell_text_fn, title, subtitle): axes[1], lambda r, c: f"{r},{c}", "2D index (row, col)", - "intile = (row % 2, col % 4)\n" - "oftile = (row // 2, col // 4)", + "intile = (row % 2, col % 4)\noftile = (row // 2, col // 4)", ) # ── Panel 3: layout algebra ────────────────────────────────── @@ -142,9 +173,17 @@ def _draw_grid(ax, cell_text_fn, title, subtitle): "mode 0: (intile\u2080, oftile\u2080)\n" "mode 1: (intile\u2081, oftile\u2081)", ) - axes[2].text(K / 2, -1.45, "cells show (intile\u2080, intile\u2081)", - ha="center", va="top", fontsize=9, fontweight="bold", - color="#2563EB", family="monospace") + axes[2].text( + K / 2, + -1.45, + "cells show (intile\u2080, intile\u2081)", + ha="center", + va="top", + fontsize=9, + fontweight="bold", + color="#2563EB", + family="monospace", + ) plt.tight_layout() fig.savefig(path, dpi=150, bbox_inches="tight") @@ -188,7 +227,7 @@ def _generate_im2col( window_colors_rgb.append((r, g, b)) def rgb_to_hex(rgb): - return f"#{int(rgb[0]*255):02X}{int(rgb[1]*255):02X}{int(rgb[2]*255):02X}" + return f"#{int(rgb[0] * 255):02X}{int(rgb[1] * 255):02X}{int(rgb[2] * 255):02X}" window_colors = [rgb_to_hex(c) for c in window_colors_rgb] @@ -205,7 +244,9 @@ def blend(win_indices): # -- figure -- fig, axes = plt.subplots( - 1, 2, figsize=(4 + n_taps * 1.2, max(H, n_windows) * 0.8 + 1.5), + 1, + 2, + figsize=(4 + n_taps * 1.2, max(H, n_windows) * 0.8 + 1.5), gridspec_kw={"width_ratios": [W, n_taps * 1.1]}, ) @@ -217,19 +258,33 @@ def blend(win_indices): y = H - 1 - row color = blend(cell_windows[idx]) rect = patches.Rectangle( - (col, y), 1, 1, - facecolor=color, edgecolor="#D1D5DB", linewidth=0.5, + (col, y), + 1, + 1, + facecolor=color, + edgecolor="#D1D5DB", + linewidth=0.5, ) ax.add_patch(rect) ax.text( - col + 0.5, y + 0.5, labels[idx], - ha="center", va="center", fontsize=13, - color="#374151", family="monospace", fontweight="bold", + col + 0.5, + y + 0.5, + labels[idx], + ha="center", + va="center", + fontsize=13, + color="#374151", + family="monospace", + fontweight="bold", ) # thick outer border rect = patches.Rectangle( - (0, 0), W, H, - facecolor="none", edgecolor="#1F2937", linewidth=2.5, + (0, 0), + W, + H, + facecolor="none", + edgecolor="#1F2937", + linewidth=2.5, ) ax.add_patch(rect) ax.set_xlim(-0.5, W + 0.5) @@ -245,27 +300,45 @@ def blend(win_indices): for col_idx in range(n_taps): cell_label = labels[im2col_rows[row_idx][col_idx]] rect = patches.Rectangle( - (col_idx, y), 1, 1, + (col_idx, y), + 1, + 1, facecolor=window_colors[row_idx], - edgecolor="#D1D5DB", linewidth=0.5, + edgecolor="#D1D5DB", + linewidth=0.5, ) ax.add_patch(rect) ax.text( - col_idx + 0.5, y + 0.5, cell_label, - ha="center", va="center", fontsize=12, - color="#374151", family="monospace", fontweight="bold", + col_idx + 0.5, + y + 0.5, + cell_label, + ha="center", + va="center", + fontsize=12, + color="#374151", + family="monospace", + fontweight="bold", ) # row label: window position (p, q) p, q = divmod(row_idx, Q) ax.text( - -0.2, y + 0.5, f"({p},{q})", - ha="right", va="center", fontsize=8, - color="#6B7280", family="monospace", + -0.2, + y + 0.5, + f"({p},{q})", + ha="right", + va="center", + fontsize=8, + color="#6B7280", + family="monospace", ) # thick outer border rect = patches.Rectangle( - (0, 0), n_taps, n_windows, - facecolor="none", edgecolor="#1F2937", linewidth=2.5, + (0, 0), + n_taps, + n_windows, + facecolor="none", + edgecolor="#1F2937", + linewidth=2.5, ) ax.add_patch(rect) ax.set_xlim(-1.5, n_taps + 0.5) @@ -274,21 +347,32 @@ def blend(win_indices): ax.axis("off") ax.set_title( f"im2col output ({n_windows}\u00d7{n_taps})", - fontsize=11, fontweight="bold", pad=10, + fontsize=11, + fontweight="bold", + pad=10, ) # ── Arrow connecting the panels ───────────────────────────────── arrow = patches.FancyArrowPatch( - (0.44, 0.5), (0.52, 0.5), + (0.44, 0.5), + (0.52, 0.5), transform=fig.transFigure, arrowstyle="->,head_width=6,head_length=5", - color="#374151", linewidth=2, + color="#374151", + linewidth=2, ) fig.patches.append(arrow) fig.text( - 0.48, 0.54, f"im2col({R}\u00d7{S})", - ha="center", va="bottom", fontsize=11, fontweight="bold", - color="#374151", family="monospace", transform=fig.transFigure, + 0.48, + 0.54, + f"im2col({R}\u00d7{S})", + ha="center", + va="bottom", + fontsize=11, + fontweight="bold", + color="#374151", + family="monospace", + transform=fig.transFigure, ) plt.tight_layout() diff --git a/examples/composed.py b/examples/composed.py index 657274c..e8d56b2 100644 --- a/examples/composed.py +++ b/examples/composed.py @@ -62,7 +62,9 @@ def example_fast_path() -> None: def example_exact_fallback() -> LayoutExpr: - _banner("2. Exact fallback: compositions that do not collapse return ComposedLayout") + _banner( + "2. Exact fallback: compositions that do not collapse return ComposedLayout" + ) base = Layout((4, 4), (4, 1)) swizzled = compose(Swizzle(2, 0, 2), base) @@ -142,7 +144,12 @@ def maybe_draw(exact: LayoutExpr, outdir: Path | None) -> None: from tensor_layouts.viz import draw_layout, draw_slice outdir.mkdir(parents=True, exist_ok=True) - draw_layout(exact, outdir / "composed_exact.png", title="Exact composed layout", colorize=True) + draw_layout( + exact, + outdir / "composed_exact.png", + title="Exact composed layout", + colorize=True, + ) draw_slice( exact, (None, 1), @@ -155,7 +162,9 @@ def maybe_draw(exact: LayoutExpr, outdir: Path | None) -> None: def main() -> None: parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--draw", type=Path, default=None, help="Optional directory for PNG output") + parser.add_argument( + "--draw", type=Path, default=None, help="Optional directory for PNG output" + ) args = parser.parse_args() example_fast_path() diff --git a/examples/layouts.py b/examples/layouts.py index f8d9fff..4beca69 100644 --- a/examples/layouts.py +++ b/examples/layouts.py @@ -43,6 +43,7 @@ # Section 1: Layout Construction # ============================================================================= + def example_construction(): """Building layouts from shape and stride. @@ -91,6 +92,7 @@ def example_construction(): # Section 2: Querying Layouts # ============================================================================= + def example_querying(): """Query functions for shape, size, rank, and depth. @@ -106,10 +108,10 @@ def example_querying(): print(f" Layout: {layout}") print(f" shape: {layout.shape}") print(f" stride: {layout.stride}") - print(f" size: {size(layout)}") # Total number of elements - print(f" cosize: {cosize(layout)}") # Span: max offset + 1 - print(f" rank: {rank(layout)}") # Number of top-level modes - print(f" depth: {depth(layout)}") # Maximum nesting depth + print(f" size: {size(layout)}") # Total number of elements + print(f" cosize: {cosize(layout)}") # Span: max offset + 1 + print(f" rank: {rank(layout)}") # Number of top-level modes + print(f" depth: {depth(layout)}") # Maximum nesting depth # mode() extracts a single mode as a Layout print(f" mode 0: {mode(layout, 0)}") @@ -125,6 +127,7 @@ def example_querying(): # Section 3: Coordinate Mapping # ============================================================================= + def example_coordinate_mapping(): """Calling a layout to map coordinates to memory offsets. @@ -139,17 +142,17 @@ def example_coordinate_mapping(): print(f" Layout: {layout}") # Multi-dimensional coordinates - print(f" (0, 0) -> {layout(0, 0)}") # 0 - print(f" (2, 3) -> {layout(2, 3)}") # 2 + 12 = 14 - print(f" (3, 7) -> {layout(3, 7)}") # 3 + 28 = 31 + print(f" (0, 0) -> {layout(0, 0)}") # 0 + print(f" (2, 3) -> {layout(2, 3)}") # 2 + 12 = 14 + print(f" (3, 7) -> {layout(3, 7)}") # 3 + 28 = 31 # Flat index: column-major traversal of the domain - print(f" flat 0 -> {layout(0)}") # Same as (0,0) -> 0 - print(f" flat 5 -> {layout(5)}") # Same as (1,1) -> 5 - print(f" flat 31 -> {layout(31)}") # Same as (3,7) -> 31 + print(f" flat 0 -> {layout(0)}") # Same as (0,0) -> 0 + print(f" flat 5 -> {layout(5)}") # Same as (1,1) -> 5 + print(f" flat 31 -> {layout(31)}") # Same as (3,7) -> 31 # idx2crd: convert flat index to multi-dimensional coordinate - print(f"\n idx2crd(5, (4, 8)): {idx2crd(5, (4, 8))}") # (1, 1) + print(f"\n idx2crd(5, (4, 8)): {idx2crd(5, (4, 8))}") # (1, 1) print(f" idx2crd(14, (4, 8)): {idx2crd(14, (4, 8))}") # (2, 3) # crd2flat: convert coordinate to flat index @@ -160,6 +163,7 @@ def example_coordinate_mapping(): # Section 4: Tuple Arithmetic # ============================================================================= + def example_tuple_arithmetic(): """Arithmetic on nested tuples — the foundation of layout algebra. @@ -174,12 +178,12 @@ def example_tuple_arithmetic(): # column-major strides from a shape shape = (2, 3, 4) pp = prefix_product(shape) - print(f" prefix_product({shape}): {pp}") # (1, 2, 6) + print(f" prefix_product({shape}): {pp}") # (1, 2, 6) # This is exactly the column-major stride for shape (2,3,4) # suffix_product: running product from the right sp = suffix_product(shape) - print(f" suffix_product({shape}): {sp}") # (12, 4, 1) + print(f" suffix_product({shape}): {sp}") # (12, 4, 1) # This is exactly the row-major stride for shape (2,3,4) # inner_product: sum of element-wise products @@ -200,6 +204,7 @@ def example_tuple_arithmetic(): # Section 5: Layout Manipulation # ============================================================================= + def example_manipulation(): """Reshape and reorganize layouts without changing the mapping. @@ -244,6 +249,7 @@ def example_manipulation(): # Section 6: Composition # ============================================================================= + def example_composition(): """compose(A, B) — function composition: C(i) = A(B(i)). @@ -289,6 +295,7 @@ def example_composition(): # Section 7: Complement # ============================================================================= + def example_complement(): """complement(L) — the layout that fills in L's gaps. @@ -333,6 +340,7 @@ def example_complement(): # Section 8: Division # ============================================================================= + def example_division(): """logical_divide — split a layout into (tile, rest). @@ -382,6 +390,7 @@ def example_division(): # Section 9: Product # ============================================================================= + def example_product(): """logical_product — replicate A's pattern across B's domain. @@ -422,6 +431,7 @@ def example_product(): # Section 10: Inverse # ============================================================================= + def example_inverse(): """right_inverse, left_inverse — undo a layout's mapping. @@ -461,6 +471,7 @@ def example_inverse(): # Section 11: Swizzle # ============================================================================= + def example_swizzle(): """Swizzle(bits, base, shift) — XOR-based bank conflict avoidance. @@ -504,6 +515,7 @@ def example_swizzle(): # Section 12: Tensor # ============================================================================= + def example_tensor(): """Tensor — a Layout combined with a base offset. @@ -548,6 +560,7 @@ def example_tensor(): # Section 13: Tile # ============================================================================= + def example_tile(): """Tile — a tuple of Layouts for mode-by-mode composition. @@ -586,6 +599,7 @@ def example_tile(): # Section 14: Iteration # ============================================================================= + def example_iteration(): """Iterating over layouts. @@ -627,6 +641,7 @@ def example_iteration(): # Section 15: Image and Injectivity # ============================================================================= + def example_image_injectivity(): """Analyzing a layout as a function. @@ -668,6 +683,7 @@ def example_image_injectivity(): # Section 16: Functional Equivalence # ============================================================================= + def example_functional_equivalence(): """Checking if two layouts compute the same mapping. @@ -704,6 +720,7 @@ def example_functional_equivalence(): # Section 17: GPU Analysis # ============================================================================= + def example_analysis(): """Analyzing layouts for GPU performance. @@ -735,7 +752,9 @@ def example_analysis(): result = bank_conflicts(col_access, element_bytes=4) print(f" Column access (stride 8): {col_access}") print(f" conflict_free: {result['conflict_free']}") - print(f" max_ways: {result['max_ways']} (threads serialize {result['max_ways']}x)") + print( + f" max_ways: {result['max_ways']} (threads serialize {result['max_ways']}x)" + ) # --- Swizzle fix --- sw_tile = compose(Swizzle(3, 0, 3), tile) @@ -759,13 +778,17 @@ def example_analysis(): print(" " + "-" * 40) coalesced = Layout(32, 1) r1 = coalescing_efficiency(coalesced, element_bytes=4) - print(f" Stride-1 (fp32): {r1['transactions']} transaction, " - f"efficiency {r1['efficiency']:.0%}") + print( + f" Stride-1 (fp32): {r1['transactions']} transaction, " + f"efficiency {r1['efficiency']:.0%}" + ) scattered = Layout(32, 64) r2 = coalescing_efficiency(scattered, element_bytes=2) - print(f" Stride-64 (fp16): {r2['transactions']} transactions, " - f"efficiency {r2['efficiency']:.1%}") + print( + f" Stride-64 (fp16): {r2['transactions']} transactions, " + f"efficiency {r2['efficiency']:.1%}" + ) # --- Permutation structure --- print(f"\n Permutation Structure") @@ -792,11 +815,11 @@ def example_analysis(): print(" " + "-" * 40) layouts = [ - ("Contiguous 1D", Layout(8, 1)), - ("Strided 1D", Layout(8, 2)), - ("Col-major 4x8", Layout((4, 8), (1, 4))), - ("Gapped 4x8", Layout((4, 8), (1, 8))), - ("Row-major 4x8", Layout((4, 8), (8, 1))), + ("Contiguous 1D", Layout(8, 1)), + ("Strided 1D", Layout(8, 2)), + ("Col-major 4x8", Layout((4, 8), (1, 4))), + ("Gapped 4x8", Layout((4, 8), (1, 8))), + ("Row-major 4x8", Layout((4, 8), (8, 1))), ] for label, layout in layouts: print(f" {label:20s} {str(layout):25s} contiguity={contiguity(layout)}") @@ -805,8 +828,8 @@ def example_analysis(): print(f"\n Atom Summary") print(" " + "-" * 40) - from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN from tensor_layouts.atoms_amd import CDNA_32x32x8_F32F16F16_MFMA + from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN print() atom_summary(SM80_16x8x16_F16F16F16F16_TN) @@ -818,6 +841,7 @@ def example_analysis(): # Main # ============================================================================= + def main(): """Run all layout algebra examples.""" print("=" * 70) diff --git a/examples/tensor.py b/examples/tensor.py index b5c38a7..eac2d71 100644 --- a/examples/tensor.py +++ b/examples/tensor.py @@ -36,8 +36,8 @@ viz.ipynb — Jupyter notebook gallery """ -from pathlib import Path import sys +from pathlib import Path # Prefer the local repo sources when running this script from a checkout. REPO_ROOT = Path(__file__).resolve().parents[1] @@ -46,7 +46,7 @@ sys.path.insert(0, str(SRC_DIR)) from tensor_layouts import * -from tensor_layouts.viz import draw_layout, draw_composite +from tensor_layouts.viz import draw_composite, draw_layout def setup_output_dir(name: str = "examples_output") -> Path: @@ -59,6 +59,7 @@ def setup_output_dir(name: str = "examples_output") -> Path: # Section 1: Attaching Storage # ============================================================================= + def example_storage(): """Attach storage to a Tensor for element-level access. @@ -97,6 +98,7 @@ def example_storage(): # Section 2: Reading Through Coordinates # ============================================================================= + def example_reading(): """Access elements using logical coordinates. @@ -135,6 +137,7 @@ def example_reading(): # Section 3: Writing Through Coordinates # ============================================================================= + def example_writing(): """Write to storage through logical coordinates. @@ -166,6 +169,7 @@ def example_writing(): # Section 4: View Semantics (Slicing Shares Data) # ============================================================================= + def example_views(): """Slicing produces sub-Tensors that share storage. @@ -206,6 +210,7 @@ def example_views(): # Section 5: Swapping Storage # ============================================================================= + def example_swap(): """The data property is writable — swap storage without rebuilding. @@ -237,6 +242,7 @@ def example_swap(): # Section 6: Two Layouts, One Buffer # ============================================================================= + def example_two_layouts(): """Attach the same storage to two Tensors with different layouts. @@ -264,6 +270,7 @@ def example_two_layouts(): # Section 7: Auto-Labeling in draw_layout # ============================================================================= + def example_auto_label(output: Path): """draw_layout auto-labels cells with data values. @@ -280,27 +287,37 @@ def example_auto_label(output: Path): t = Tensor(layout, data=buf) # Auto-label: cells show A, B, C, ... - draw_layout(t, output / "tensor_data_auto.svg", - title="Auto-labeled from data") + draw_layout(t, output / "tensor_data_auto.svg", title="Auto-labeled from data") print(f" tensor_data_auto.svg — cells show data values") # Override: show raw offsets - draw_layout(t, output / "tensor_data_offsets.svg", - title="cell_labels='offset'", cell_labels="offset") + draw_layout( + t, + output / "tensor_data_offsets.svg", + title="cell_labels='offset'", + cell_labels="offset", + ) print(f" tensor_data_offsets.svg — cells show offsets") # Override: no labels - draw_layout(t, output / "tensor_data_nolabels.svg", - title="cell_labels=False", colorize=True, cell_labels=False) + draw_layout( + t, + output / "tensor_data_nolabels.svg", + title="cell_labels=False", + colorize=True, + cell_labels=False, + ) print(f" tensor_data_nolabels.svg — no cell text") # Side-by-side: same data, two layouts row_t = Tensor(Layout((4, 8), (8, 1)), data=buf) col_t = Tensor(Layout((4, 8), (1, 4)), data=buf) - draw_composite([row_t, col_t], - output / "tensor_data_compare.svg", - titles=["Row-major", "Col-major"], - main_title="Same data, different layouts") + draw_composite( + [row_t, col_t], + output / "tensor_data_compare.svg", + titles=["Row-major", "Col-major"], + main_title="Same data, different layouts", + ) print(f" tensor_data_compare.svg — row vs col-major, same data") @@ -308,6 +325,7 @@ def example_auto_label(output: Path): # Main # ============================================================================= + def main(): """Run all Tensor storage examples.""" output = setup_output_dir() diff --git a/examples/viz.py b/examples/viz.py index e02e982..f9c6a16 100644 --- a/examples/viz.py +++ b/examples/viz.py @@ -35,8 +35,8 @@ Reference: https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/ """ -from pathlib import Path import sys +from pathlib import Path # Prefer the local repo sources when running this script from a checkout. # An installed `tensor-layouts` distribution is still required for package metadata. @@ -105,7 +105,9 @@ def example_output_formats(output: Path): layout_8x8 = Layout((8, 8), (8, 1)) # Color by value (default) - same value = same color - draw_layout(layout_8x8, output / "color_by_value.svg", title="color_layout=None (by value)") + draw_layout( + layout_8x8, output / "color_by_value.svg", title="color_layout=None (by value)" + ) print(f"✓ Color by value: color_by_value.svg") # Color by column - darker across columns (cute-viz style) @@ -147,7 +149,10 @@ def example_output_formats(output: Path): # color_by shorthand — equivalent to the manual color_layout above draw_layout( - layout_8x8, output / "color_by_row_shorthand.svg", title='color_by="row"', color_by="row" + layout_8x8, + output / "color_by_row_shorthand.svg", + title='color_by="row"', + color_by="row", ) draw_layout( layout_8x8, @@ -155,14 +160,18 @@ def example_output_formats(output: Path): title='color_by="column"', color_by="column", ) - print(f"✓ color_by shorthand: color_by_row_shorthand.svg, color_by_col_shorthand.svg") + print( + f"✓ color_by shorthand: color_by_row_shorthand.svg, color_by_col_shorthand.svg" + ) # Swizzle comparison showing row-group coloring (reveals permutation effect) base = Layout((8, 8), (8, 1)) sw = Swizzle(3, 0, 3) draw_swizzle(base, sw, output / "swizzle_example.svg") draw_swizzle(base, sw, output / "swizzle_example_color.svg", colorize=True) - print(f"✓ Swizzle with row-group coloring: swizzle_example.svg, swizzle_example_color.svg") + print( + f"✓ Swizzle with row-group coloring: swizzle_example.svg, swizzle_example_color.svg" + ) # ============================================================================= @@ -184,7 +193,9 @@ def example_1d_layouts(output: Path): # Contiguous 1D layout: 8 elements, stride 1 layout_1d_contiguous = Layout(8, 1) - draw_layout(layout_1d_contiguous, output / "1d_contiguous.svg", title="1D Contiguous: 8:1") + draw_layout( + layout_1d_contiguous, output / "1d_contiguous.svg", title="1D Contiguous: 8:1" + ) print(f"✓ 1D Contiguous: 8:1") print(f" Maps index i → offset i (e.g., 3 → 3)") @@ -230,21 +241,33 @@ def example_2d_layouts(output: Path): # Row-major 4x3: shape (4 rows, 3 cols), stride (3, 1) # Row i, Col j → offset = i*3 + j row_major_4x3 = Layout((4, 3), (3, 1)) - draw_layout(row_major_4x3, output / "2d_row_major_4x3.svg", title="Row-Major 4×3: (4,3):(3,1)") + draw_layout( + row_major_4x3, + output / "2d_row_major_4x3.svg", + title="Row-Major 4×3: (4,3):(3,1)", + ) print(f"✓ Row-Major 4×3: (4,3):(3,1)") print(f" offset(i,j) = i*3 + j*1") # Column-major 4x3: shape (4 rows, 3 cols), stride (1, 4) # Row i, Col j → offset = i*1 + j*4 col_major_4x3 = Layout((4, 3), (1, 4)) - draw_layout(col_major_4x3, output / "2d_col_major_4x3.svg", title="Col-Major 4×3: (4,3):(1,4)") + draw_layout( + col_major_4x3, + output / "2d_col_major_4x3.svg", + title="Col-Major 4×3: (4,3):(1,4)", + ) print(f"✓ Col-Major 4×3: (4,3):(1,4)") print(f" offset(i,j) = i*1 + j*4") # 8x8 Row-major: shape (8 rows, 8 cols), stride (8, 1) # This is the common layout for matrix operations row_major_8x8 = Layout((8, 8), (8, 1)) - draw_layout(row_major_8x8, output / "2d_row_major_8x8.svg", title="Row-Major 8×8: (8,8):(8,1)") + draw_layout( + row_major_8x8, + output / "2d_row_major_8x8.svg", + title="Row-Major 8×8: (8,8):(8,1)", + ) draw_layout( row_major_8x8, output / "2d_row_major_8x8_color.svg", @@ -256,7 +279,11 @@ def example_2d_layouts(output: Path): # 8x8 Column-major: shape (8 rows, 8 cols), stride (1, 8) col_major_8x8 = Layout((8, 8), (1, 8)) - draw_layout(col_major_8x8, output / "2d_col_major_8x8.svg", title="Col-Major 8×8: (8,8):(1,8)") + draw_layout( + col_major_8x8, + output / "2d_col_major_8x8.svg", + title="Col-Major 8×8: (8,8):(1,8)", + ) draw_layout( col_major_8x8, output / "2d_col_major_8x8_color.svg", @@ -419,12 +446,16 @@ def example_hierarchical_layouts(output: Path): # Flatten the hierarchical layout (algebra operation) flat_layout = flatten(logo_layout) - draw_layout(flat_layout, output / "hier_flattened.svg", title=f"flatten(): {flat_layout}") + draw_layout( + flat_layout, output / "hier_flattened.svg", title=f"flatten(): {flat_layout}" + ) print(f"✓ Flattened (algebra): {flat_layout}") # Coalesce to merge contiguous dimensions coal_layout = coalesce(logo_layout) - draw_layout(coal_layout, output / "hier_coalesced.svg", title=f"coalesce(): {coal_layout}") + draw_layout( + coal_layout, output / "hier_coalesced.svg", title=f"coalesce(): {coal_layout}" + ) print(f"✓ Coalesced: {coal_layout}") # ========================================================================= @@ -605,7 +636,10 @@ def example_hierarchical_layouts(output: Path): draw_layout(morton2, output / "hier_morton_4x4.svg", title=f"Morton 4×4: {morton2}") draw_layout(morton3, output / "hier_morton_8x8.svg", title=f"Morton 8×8: {morton3}") draw_layout( - morton3, output / "hier_morton_8x8_color.svg", title=f"Morton 8×8: {morton3}", colorize=True + morton3, + output / "hier_morton_8x8_color.svg", + title=f"Morton 8×8: {morton3}", + colorize=True, ) print(f"✓ Morton 2×2: {morton1}") print(f"✓ Morton 4×4: {morton2}") @@ -678,7 +712,9 @@ def example_swizzled_layouts(output: Path): # Column-major variant base_8x8_col = Layout((8, 8), (1, 8)) - draw_swizzle(base_8x8_col, sw_303, output / "swizzle_8x8_col_303.svg", colorize=True) + draw_swizzle( + base_8x8_col, sw_303, output / "swizzle_8x8_col_303.svg", colorize=True + ) print(f"✓ Swizzle(3,0,3) on 8×8 col-major") # 16x8 variant (common for tensor core) @@ -720,7 +756,9 @@ def example_swizzled_layouts(output: Path): # Canonical byte layout: 8 rows × 128 columns (128 bytes per row) sw_343 = Swizzle(3, 4, 3) base_8x128 = Layout((8, 128), (128, 1)) - draw_swizzle(base_8x128, sw_343, output / "swizzle_8x128_343_SW128.svg", colorize=True) + draw_swizzle( + base_8x128, sw_343, output / "swizzle_8x128_343_SW128.svg", colorize=True + ) print(f"✓ Swizzle(3,4,3) SW128 on 8×128: XOR bits [4,7) with [7,10)") # ========================================================================= @@ -730,7 +768,9 @@ def example_swizzled_layouts(output: Path): # Swizzle(0, M, S) is identity - no XOR applied sw_043 = Swizzle(0, 4, 3) - draw_swizzle(base_8x128, sw_043, output / "swizzle_8x128_043_none.svg", colorize=True) + draw_swizzle( + base_8x128, sw_043, output / "swizzle_8x128_043_none.svg", colorize=True + ) print(f"✓ Swizzle(0,4,3) on 8×128: Identity (no XOR)") @@ -763,7 +803,10 @@ def example_thread_value_layouts(output: Path): title="TV: (4,2):(2,1) - 4 threads, 2 values each", ) draw_tv_layout( - tv_4x2, output / "tv_4threads_2values_color.svg", title="TV: (4,2):(2,1)", colorize=True + tv_4x2, + output / "tv_4threads_2values_color.svg", + title="TV: (4,2):(2,1)", + colorize=True, ) print(f"✓ TV Layout 4×2: 4 threads, 2 values each") print(f" Thread t owns values V0, V1 at offsets 2*t and 2*t+1") @@ -779,19 +822,29 @@ def example_thread_value_layouts(output: Path): # 8x4 TV layout (smaller than full warp for clarity) tv_8x4 = Layout((8, 4), (4, 1)) - draw_tv_layout(tv_8x4, output / "tv_8x4.svg", title="TV: (8,4):(4,1) - 8 threads, 4 values") - draw_tv_layout(tv_8x4, output / "tv_8x4_color.svg", title="TV: (8,4):(4,1)", colorize=True) + draw_tv_layout( + tv_8x4, output / "tv_8x4.svg", title="TV: (8,4):(4,1) - 8 threads, 4 values" + ) + draw_tv_layout( + tv_8x4, output / "tv_8x4_color.svg", title="TV: (8,4):(4,1)", colorize=True + ) print(f"✓ TV Layout 8×4: 8 threads, 4 values each") # 8x8 TV layout (common for LDMATRIX) tv_8x8 = Layout((8, 8), (8, 1)) - draw_tv_layout(tv_8x8, output / "tv_8x8.svg", title="TV: (8,8):(8,1) - 8 threads, 8 values") - draw_tv_layout(tv_8x8, output / "tv_8x8_color.svg", title="TV: (8,8):(8,1)", colorize=True) + draw_tv_layout( + tv_8x8, output / "tv_8x8.svg", title="TV: (8,8):(8,1) - 8 threads, 8 values" + ) + draw_tv_layout( + tv_8x8, output / "tv_8x8_color.svg", title="TV: (8,8):(8,1)", colorize=True + ) print(f"✓ TV Layout 8×8: 8 threads, 8 values each (LDMATRIX style)") # Also show the regular layout view for comparison draw_layout( - tv_8x8, output / "tv_8x8_offsets.svg", title="TV: (8,8):(8,1) - Memory offsets view" + tv_8x8, + output / "tv_8x8_offsets.svg", + title="TV: (8,8):(8,1) - Memory offsets view", ) print(f" (Also showing memory offset view for comparison)") @@ -836,7 +889,9 @@ def example_copy_atoms(output: Path): ] for atom in ldsm_atoms: # draw_copy_atom handles upcast from bit to element coords automatically - draw_copy_atom(atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg") + draw_copy_atom( + atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg" + ) dst = upcast(atom.dst_layout_bits, element_bits) n_thr = size(atom.thr_id) @@ -851,7 +906,9 @@ def example_copy_atoms(output: Path): stsm_atoms = [SM90_U32x4_STSM_N, SM90_U16x8_STSM_T] for atom in stsm_atoms: - draw_copy_atom(atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg") + draw_copy_atom( + atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg" + ) print(f"✓ {atom.name} ({atom.ptx})") # ===================================================================== @@ -874,13 +931,18 @@ def example_copy_atoms(output: Path): # For fp16: Swizzle<3,4,3> ∘ (8, 64):(64, 1) = 8 rows × 64 cols print("\n TMA target: GMMA K-major SW128 smem layout (fp16):") base_tma = Layout((8, 64), (64, 1)) - draw_swizzle(base_tma, Swizzle(3, 4, 3), output / "SM90_TMA_GMMA_K_SW128.svg", colorize=True) + draw_swizzle( + base_tma, Swizzle(3, 4, 3), output / "SM90_TMA_GMMA_K_SW128.svg", colorize=True + ) print(f"✓ SM90 TMA → GMMA K-major SW128: Swizzle(3,4,3) ∘ (8,64):(64,1)") print("\n TMA target: GMMA M|N-major SW128 smem layout (fp16):") base_tma_mn = Layout((64, 8), (1, 64)) draw_swizzle( - base_tma_mn, Swizzle(3, 4, 3), output / "SM90_TMA_GMMA_MN_SW128.svg", colorize=True + base_tma_mn, + Swizzle(3, 4, 3), + output / "SM90_TMA_GMMA_MN_SW128.svg", + colorize=True, ) print(f"✓ SM90 TMA → GMMA M|N-major SW128: Swizzle(3,4,3) ∘ (64,8):(1,64)") @@ -889,7 +951,9 @@ def example_copy_atoms(output: Path): # ===================================================================== print("\n --- LDMATRIX Shared Memory with Swizzle ---") smem_8x8 = Layout((8, 8), (8, 1)) - draw_swizzle(smem_8x8, Swizzle(3, 0, 3), output / "ldmatrix_smem_swizzle.svg", colorize=True) + draw_swizzle( + smem_8x8, Swizzle(3, 0, 3), output / "ldmatrix_smem_swizzle.svg", colorize=True + ) print(f"✓ LDMATRIX shared memory with Swizzle(3,0,3)") @@ -996,8 +1060,12 @@ def _draw_tiled_mma(atom, atom_layout, output: Path, tile_mnk=None): a_grid, _ = tile_mma_grid(atom, atom_layout, "A", tile_mnk=tile_mnk) b_grid, _ = tile_mma_grid(atom, atom_layout, "B", tile_mnk=tile_mnk) - draw_tiled_grid(c_grid, M, N, output / f"{label}_C.svg", title=f"{label} C ({M}×{N})") - draw_tiled_grid(a_grid, M, K, output / f"{label}_A.svg", title=f"{label} A ({M}×{K})") + draw_tiled_grid( + c_grid, M, N, output / f"{label}_C.svg", title=f"{label} C ({M}×{N})" + ) + draw_tiled_grid( + a_grid, M, K, output / f"{label}_A.svg", title=f"{label} A ({M}×{K})" + ) # B displayed as K×N (transposed) b_display = {} for (r, c), val in b_grid.items(): @@ -1005,11 +1073,20 @@ def _draw_tiled_mma(atom, atom_layout, output: Path, tile_mnk=None): n_coord = r k_coord = c b_display[(k_coord, n_coord)] = val - draw_tiled_grid(b_display, K, N, output / f"{label}_B.svg", title=f"{label} B ({K}×{N})") + draw_tiled_grid( + b_display, K, N, output / f"{label}_B.svg", title=f"{label} B ({K}×{N})" + ) # Combined figure: A (left), B (top-right), C (bottom-right) draw_combined_mma_grid( - a_grid, b_display, c_grid, M, N, K, output / f"{label}_combined.svg", title=label + a_grid, + b_display, + c_grid, + M, + N, + K, + output / f"{label}_combined.svg", + title=label, ) print(f"✓ Tiled MMA: {label}") @@ -1161,17 +1238,23 @@ def example_slicing(output: Path): base = Layout((8, 8), (8, 1)) # Row slice: select row 3 (all columns) - draw_slice(base, (3, None), output / "slice_row3.svg", title="Row Slice: layout(3, :)") + draw_slice( + base, (3, None), output / "slice_row3.svg", title="Row Slice: layout(3, :)" + ) print(f"✓ Row slice: layout(3, :)") print(f" Selects all 8 elements in row 3") # Column slice: select column 5 (all rows) - draw_slice(base, (None, 5), output / "slice_col5.svg", title="Column Slice: layout(:, 5)") + draw_slice( + base, (None, 5), output / "slice_col5.svg", title="Column Slice: layout(:, 5)" + ) print(f"✓ Column slice: layout(:, 5)") print(f" Selects all 8 elements in column 5") # Single element - draw_slice(base, (4, 6), output / "slice_element.svg", title="Single Element: layout(4, 6)") + draw_slice( + base, (4, 6), output / "slice_element.svg", title="Single Element: layout(4, 6)" + ) print(f"✓ Single element: layout(4, 6)") # Rectangular region: rows 2-5, columns 1-4 @@ -1190,22 +1273,34 @@ def example_slicing(output: Path): # Divide 8 rows into 4 groups of 2 divided = logical_divide(base, Layout((2, 4), (1, 2))) - draw_layout(divided, output / "slice_divided_base.svg", title="Divided: 2-row groups") + draw_layout( + divided, output / "slice_divided_base.svg", title="Divided: 2-row groups" + ) print(f"✓ Divided layout: groups of 2 rows") # Tile-based slicing: select every other 2x2 tile tiled = Layout(((2, 4), (2, 4)), ((1, 16), (2, 8))) - draw_layout(tiled, output / "slice_tiled.svg", title="Tiled: ((2,4),(2,4)):((1,16),(2,8))") + draw_layout( + tiled, output / "slice_tiled.svg", title="Tiled: ((2,4),(2,4)):((1,16),(2,8))" + ) print(f"✓ Tiled layout: 2×2 tiles in 8×8") # Strided row access (every other row) strided_rows = Layout((4, 8), (16, 1)) - draw_layout(strided_rows, output / "slice_strided_rows.svg", title="Strided Rows: (4,8):(16,1)") + draw_layout( + strided_rows, + output / "slice_strided_rows.svg", + title="Strided Rows: (4,8):(16,1)", + ) print(f"✓ Strided rows: every other row") # Strided column access (every other column) strided_cols = Layout((8, 4), (8, 2)) - draw_layout(strided_cols, output / "slice_strided_cols.svg", title="Strided Cols: (8,4):(8,2)") + draw_layout( + strided_cols, + output / "slice_strided_cols.svg", + title="Strided Cols: (8,4):(8,2)", + ) print(f"✓ Strided columns: every other column") # Diagonal-like pattern using hierarchical layout @@ -1245,13 +1340,19 @@ def example_slicing(output: Path): # Slice (2, ((0,None),None)) — fix mode-0 to 2, partially slice mode-1 draw_slice( - cecka_t, (2, ((0, None), None)), output / "cecka_slice_2_0NN.svg", title="(2,((0,:),:))" + cecka_t, + (2, ((0, None), None)), + output / "cecka_slice_2_0NN.svg", + title="(2,((0,:),:))", ) print(f"✓ Slice (2, ((0,:),:)) — fix row=2, inner-col-0=0, rest free") # Slice ((None,1),(None,0)) — fix outer-row=1, inner-col-outer=0 draw_slice( - cecka_t, ((None, 1), (None, 0)), output / "cecka_slice_N1_N0.svg", title="((:,1),(:,0))" + cecka_t, + ((None, 1), (None, 0)), + output / "cecka_slice_N1_N0.svg", + title="((:,1),(:,0))", ) print(f"✓ Slice ((:,1), (:,0)) — outer-row=1, mode-1 partially fixed") @@ -1297,7 +1398,9 @@ def example_algebra_operations(output: Path): composed = compose(outer, inner) draw_layout(inner, output / "algebra_inner.svg", title=f"Inner: {inner}") draw_layout( - composed, output / "algebra_composed.svg", title=f"Composed: compose({outer}, {inner})" + composed, + output / "algebra_composed.svg", + title=f"Composed: compose({outer}, {inner})", ) print(f"✓ Composition: compose({outer}, {inner}) = {composed}") @@ -1306,7 +1409,9 @@ def example_algebra_operations(output: Path): comp = complement(base, 16) draw_layout(base, output / "algebra_base.svg", title=f"Base: {base}") draw_layout( - comp, output / "algebra_complement.svg", title=f"Complement: complement({base}, 16)" + comp, + output / "algebra_complement.svg", + title=f"Complement: complement({base}, 16)", ) print(f"✓ Complement: complement({base}, 16) = {comp}") @@ -1316,7 +1421,9 @@ def example_algebra_operations(output: Path): divided = logical_divide(matrix, tiler) draw_layout(matrix, output / "algebra_matrix.svg", title=f"Matrix: {matrix}") draw_layout( - divided, output / "algebra_divided.svg", title=f"Divided: logical_divide by {tiler}" + divided, + output / "algebra_divided.svg", + title=f"Divided: logical_divide by {tiler}", ) print(f"✓ Logical divide: 8×8 by 2×2 tiler") @@ -1326,7 +1433,9 @@ def example_algebra_operations(output: Path): product = logical_product(tile, grid) draw_layout(tile, output / "algebra_tile.svg", title=f"Tile: {tile}") draw_layout( - product, output / "algebra_product.svg", title=f"Product: logical_product({tile}, {grid})" + product, + output / "algebra_product.svg", + title=f"Product: logical_product({tile}, {grid})", ) print(f"✓ Logical product: {tile} × {grid}") @@ -1334,13 +1443,17 @@ def example_algebra_operations(output: Path): # that are now automatically rendered as multi-panel 2D grids fd = flat_divide(matrix, Layout(2, 1)) draw_layout( - fd, output / "algebra_flat_divide.svg", title=f"flat_divide result (rank {rank(fd)})" + fd, + output / "algebra_flat_divide.svg", + title=f"flat_divide result (rank {rank(fd)})", ) print(f"✓ flat_divide: shape={fd.shape}, rank={rank(fd)} → multi-panel") fp = flat_product(Layout((2, 2), (1, 2)), Layout(4, 1)) draw_layout( - fp, output / "algebra_flat_product.svg", title=f"flat_product result (rank {rank(fp)})" + fp, + output / "algebra_flat_product.svg", + title=f"flat_product result (rank {rank(fp)})", ) print(f"✓ flat_product: shape={fp.shape}, rank={rank(fp)} → multi-panel") diff --git a/src/tensor_layouts/__init__.py b/src/tensor_layouts/__init__.py index c53e3a1..360bdb3 100644 --- a/src/tensor_layouts/__init__.py +++ b/src/tensor_layouts/__init__.py @@ -23,9 +23,9 @@ """Pure-Python implementation of GPU layout algebra.""" from .layouts import * # noqa: F401,F403 -from .tensor import Tensor # noqa: F401 -from .atoms import MMAAtom, CopyAtom # noqa: F401 - from importlib.metadata import version # noqa: F401 +from .atoms import CopyAtom, MMAAtom # noqa: F401 +from .tensor import Tensor # noqa: F401 + __version__ = version("tensor-layouts") diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 8e216f8..ba8f7fc 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -209,6 +209,7 @@ def _normalize_explain_compose_tiler(tiler): # Inverse mapping # ============================================================================= + def offset_table(layout: LayoutExpr) -> dict: """Return {offset: [coord, ...]} mapping each offset to its coordinates. @@ -272,12 +273,12 @@ def aliasing_profile(layout: Layout) -> dict: n_total = sum(ways_per_offset.values()) return { - 'has_aliasing': bool(aliased_offsets), - 'max_alias_ways': max(ways_per_offset.values(), default=0), - 'aliased_offset_count': len(aliased_offsets), - 'duplicate_elements': n_total - n_unique, - 'reuse_histogram': dict(sorted(reuse_histogram.items())), - 'aliased_offsets': aliased_offsets, + "has_aliasing": bool(aliased_offsets), + "max_alias_ways": max(ways_per_offset.values(), default=0), + "aliased_offset_count": len(aliased_offsets), + "duplicate_elements": n_total - n_unique, + "reuse_histogram": dict(sorted(reuse_histogram.items())), + "aliased_offsets": aliased_offsets, } @@ -321,13 +322,13 @@ def footprint(layout: LayoutExpr) -> dict: span = max_off - min_off + 1 if offsets else 0 return { - 'min_offset': min_off, - 'max_offset': max_off, - 'span': span, - 'unique_offsets': n_unique, - 'total_elements': n_total, - 'reuse_factor': n_total / n_unique if n_unique > 0 else 0.0, - 'holes': span - n_unique, + "min_offset": min_off, + "max_offset": max_off, + "span": span, + "unique_offsets": n_unique, + "total_elements": n_total, + "reuse_factor": n_total / n_unique if n_unique > 0 else 0.0, + "holes": span - n_unique, } @@ -359,12 +360,12 @@ def gap_profile(layout: LayoutExpr) -> dict: offsets = image(layout) if not offsets: return { - 'runs': [], - 'gap_sizes': [], - 'max_gap': 0, - 'avg_gap': 0.0, - 'run_count': 0, - 'isolated_offsets': 0, + "runs": [], + "gap_sizes": [], + "max_gap": 0, + "avg_gap": 0.0, + "run_count": 0, + "isolated_offsets": 0, } runs = [] @@ -383,12 +384,12 @@ def gap_profile(layout: LayoutExpr) -> dict: runs.append((run_start, prev)) return { - 'runs': runs, - 'gap_sizes': gap_sizes, - 'max_gap': max(gap_sizes, default=0), - 'avg_gap': (sum(gap_sizes) / len(gap_sizes)) if gap_sizes else 0.0, - 'run_count': len(runs), - 'isolated_offsets': sum(1 for s, e in runs if s == e), + "runs": runs, + "gap_sizes": gap_sizes, + "max_gap": max(gap_sizes, default=0), + "avg_gap": (sum(gap_sizes) / len(gap_sizes)) if gap_sizes else 0.0, + "run_count": len(runs), + "isolated_offsets": sum(1 for s, e in runs if s == e), } @@ -396,9 +397,15 @@ def gap_profile(layout: LayoutExpr) -> dict: # Bank conflict analysis # ============================================================================= -def bank_conflicts(layout: LayoutExpr, *, element_bytes: int, - num_banks: int = 32, bank_width_bytes: int = 4, - group_size: int = 32): + +def bank_conflicts( + layout: LayoutExpr, + *, + element_bytes: int, + num_banks: int = 32, + bank_width_bytes: int = 4, + group_size: int = 32, +): """Analyze shared memory bank conflicts for a thread-to-offset layout. Given a layout that maps thread indices to shared memory offsets, @@ -481,9 +488,9 @@ def bank_conflicts(layout: LayoutExpr, *, element_bytes: int, max_ways = ways return { - 'conflict_free': max_ways <= 1, - 'max_ways': max_ways, - 'bank_to_threads': bank_to_threads, + "conflict_free": max_ways <= 1, + "max_ways": max_ways, + "bank_to_threads": bank_to_threads, } @@ -492,8 +499,9 @@ def bank_conflicts(layout: LayoutExpr, *, element_bytes: int, # ============================================================================= -def _group_access_offsets(layout: LayoutExpr, start_thread: int = 0, - end_thread: int | None = None) -> tuple[list[int], int]: +def _group_access_offsets( + layout: LayoutExpr, start_thread: int = 0, end_thread: int | None = None +) -> tuple[list[int], int]: """Return accessed offsets for a thread group and their minimum offset. A bare Layout has no attached base pointer, so address-based analyses use @@ -516,9 +524,13 @@ def _group_access_offsets(layout: LayoutExpr, start_thread: int = 0, return offsets, min(offsets) -def coalescing_efficiency(layout: LayoutExpr, *, element_bytes: int, - warp_size: int = 32, - cache_line_bytes: int = 128): +def coalescing_efficiency( + layout: LayoutExpr, + *, + element_bytes: int, + warp_size: int = 32, + cache_line_bytes: int = 128, +): """Analyze global memory coalescing for a thread-to-offset layout. Given a layout that maps thread indices to global memory offsets, @@ -586,16 +598,20 @@ def coalescing_efficiency(layout: LayoutExpr, *, element_bytes: int, efficiency = useful_bytes / transferred_bytes if transferred_bytes > 0 else 0.0 return { - 'transactions': transactions, - 'efficiency': efficiency, - 'cache_lines': sorted(cache_lines), + "transactions": transactions, + "efficiency": efficiency, + "cache_lines": sorted(cache_lines), } -def segment_analysis(layout: LayoutExpr, *, element_bytes: int, - warp_size: int = 32, - segment_bytes: int = 32, - cache_line_bytes: int = 128): +def segment_analysis( + layout: LayoutExpr, + *, + element_bytes: int, + warp_size: int = 32, + segment_bytes: int = 32, + cache_line_bytes: int = 128, +): """Segment- and alignment-aware global memory transaction analysis. A more detailed model than ``coalescing_efficiency()``. NVIDIA GPUs @@ -658,14 +674,16 @@ def segment_analysis(layout: LayoutExpr, *, element_bytes: int, transferred_bytes = n_segments * segment_bytes return { - 'segments': n_segments, - 'cache_lines': n_lines, - 'unique_bytes': unique_bytes, - 'requested_bytes': requested_bytes, - 'transferred_bytes': transferred_bytes, - 'segment_efficiency': unique_bytes / transferred_bytes if transferred_bytes > 0 else 0.0, - 'first_byte_addr': first_byte, - 'first_alignment': first_byte % segment_bytes, + "segments": n_segments, + "cache_lines": n_lines, + "unique_bytes": unique_bytes, + "requested_bytes": requested_bytes, + "transferred_bytes": transferred_bytes, + "segment_efficiency": unique_bytes / transferred_bytes + if transferred_bytes > 0 + else 0.0, + "first_byte_addr": first_byte, + "first_alignment": first_byte % segment_bytes, } @@ -769,22 +787,26 @@ def _offset(thread: int, value_flat: int) -> int: global_unique = sorted(set(all_deltas)) return { - 'thread_count': thread_count, - 'value_count': value_count, - 'per_value_deltas': per_value_deltas, - 'per_value_unique_strides': per_value_unique, - 'per_value_constant_stride': per_value_constant, - 'per_value_is_constant': per_value_is_constant, - 'global_unique_strides': global_unique, - 'is_uniform': len(global_unique) <= 1, - 'has_broadcast_lane': has_broadcast_lane, + "thread_count": thread_count, + "value_count": value_count, + "per_value_deltas": per_value_deltas, + "per_value_unique_strides": per_value_unique, + "per_value_constant_stride": per_value_constant, + "per_value_is_constant": per_value_is_constant, + "global_unique_strides": global_unique, + "is_uniform": len(global_unique) <= 1, + "has_broadcast_lane": has_broadcast_lane, } -def per_group_bank_conflicts(layout: LayoutExpr, *, element_bytes: int, - group_size: int = 32, - num_banks: int = 32, - bank_width_bytes: int = 4) -> dict: +def per_group_bank_conflicts( + layout: LayoutExpr, + *, + element_bytes: int, + group_size: int = 32, + num_banks: int = 32, + bank_width_bytes: int = 4, +) -> dict: """Analyze bank conflicts per warp/wavefront group across a full layout. Splits the layout into groups of ``group_size`` threads and analyzes @@ -844,9 +866,9 @@ def per_group_bank_conflicts(layout: LayoutExpr, *, element_bytes: int, max_ways = ways result = { - 'conflict_free': max_ways <= 1, - 'max_ways': max_ways, - 'bank_to_threads': bank_to_threads, + "conflict_free": max_ways <= 1, + "max_ways": max_ways, + "bank_to_threads": bank_to_threads, } groups.append(result) if max_ways > worst_ways: @@ -854,15 +876,19 @@ def per_group_bank_conflicts(layout: LayoutExpr, *, element_bytes: int, worst_idx = g return { - 'groups': groups, - 'worst_group': worst_idx, - 'worst_max_ways': worst_ways, + "groups": groups, + "worst_group": worst_idx, + "worst_max_ways": worst_ways, } -def per_group_coalescing(layout: LayoutExpr, *, element_bytes: int, - group_size: int = 32, - cache_line_bytes: int = 128) -> dict: +def per_group_coalescing( + layout: LayoutExpr, + *, + element_bytes: int, + group_size: int = 32, + cache_line_bytes: int = 128, +) -> dict: """Analyze coalescing efficiency per warp/wavefront group across a full layout. Splits the layout into groups of ``group_size`` threads and analyzes @@ -892,7 +918,7 @@ def per_group_coalescing(layout: LayoutExpr, *, element_bytes: int, groups = [] worst_idx = 0 - worst_eff = float('inf') + worst_eff = float("inf") for g in range(num_groups): start = g * group_size @@ -914,9 +940,9 @@ def per_group_coalescing(layout: LayoutExpr, *, element_bytes: int, efficiency = useful_bytes / transferred_bytes if transferred_bytes > 0 else 0.0 result = { - 'transactions': transactions, - 'efficiency': efficiency, - 'cache_lines': sorted(cache_lines), + "transactions": transactions, + "efficiency": efficiency, + "cache_lines": sorted(cache_lines), } groups.append(result) if efficiency < worst_eff: @@ -924,9 +950,9 @@ def per_group_coalescing(layout: LayoutExpr, *, element_bytes: int, worst_idx = g return { - 'groups': groups, - 'worst_group': worst_idx, - 'worst_efficiency': worst_eff, + "groups": groups, + "worst_group": worst_idx, + "worst_efficiency": worst_eff, } @@ -1050,6 +1076,7 @@ def order(layout: LayoutExpr) -> int: # Contiguity # ============================================================================= + def contiguity(layout: Layout) -> int: """Return the longest contiguous vector width from the start of the layout. @@ -1139,6 +1166,7 @@ def slice_contiguity(layout: Layout, coord) -> int: # Atom analysis # ============================================================================= + def atom_summary(atom: MMAAtom) -> dict: """Summarize an MMA atom's key properties. @@ -1181,40 +1209,39 @@ def atom_summary(atom: MMAAtom) -> dict: for v in range(num_v): c_offset_list.append(atom.c_layout(t, v)) c_offsets = set(c_offset_list) - c_coverage_ok = (c_offsets == set(range(M * N)) - and len(c_offset_list) == M * N) + c_coverage_ok = c_offsets == set(range(M * N)) and len(c_offset_list) == M * N # Check for broadcast (stride-0) in A and B a_broadcast = atom.a_layout.filter() != atom.a_layout b_broadcast = atom.b_layout.filter() != atom.b_layout result = { - 'name': atom.name, - 'shape_mnk': atom.shape_mnk, - 'threads': threads, - 'values_a': values_a, - 'values_b': values_b, - 'values_c': values_c, - 'c_coverage_ok': c_coverage_ok, - 'a_broadcast': a_broadcast, - 'b_broadcast': b_broadcast, + "name": atom.name, + "shape_mnk": atom.shape_mnk, + "threads": threads, + "values_a": values_a, + "values_b": values_b, + "values_c": values_c, + "c_coverage_ok": c_coverage_ok, + "a_broadcast": a_broadcast, + "b_broadcast": b_broadcast, } lines = [ atom.name, - f' Shape (M, N, K): {M} x {N} x {K}', - f' Threads: {threads}', - f' Values per thread: A={values_a}, B={values_b}, C={values_c}', - f' C covers M*N: {c_coverage_ok}', + f" Shape (M, N, K): {M} x {N} x {K}", + f" Threads: {threads}", + f" Values per thread: A={values_a}, B={values_b}, C={values_c}", + f" C covers M*N: {c_coverage_ok}", ] if a_broadcast: - lines.append(' A has broadcast (stride-0) modes') + lines.append(" A has broadcast (stride-0) modes") if b_broadcast: - lines.append(' B has broadcast (stride-0) modes') + lines.append(" B has broadcast (stride-0) modes") - text = '\n'.join(lines) + text = "\n".join(lines) print(text) - result['text'] = text + result["text"] = text return result @@ -1236,14 +1263,16 @@ def _operand_coverage(layout: Layout, domain_size: int) -> dict: duplicates = total_accesses - len(unique) return { - 'domain_size': domain_size, - 'unique_offsets': len(unique), - 'total_accesses': total_accesses, - 'duplicates': duplicates, - 'coverage_ok': unique == expected, - 'missing': sorted(missing) if missing else [], - 'extra': sorted(extra) if extra else [], - 'thread_utilization': len(unique) / total_accesses if total_accesses > 0 else 0.0, + "domain_size": domain_size, + "unique_offsets": len(unique), + "total_accesses": total_accesses, + "duplicates": duplicates, + "coverage_ok": unique == expected, + "missing": sorted(missing) if missing else [], + "extra": sorted(extra) if extra else [], + "thread_utilization": len(unique) / total_accesses + if total_accesses > 0 + else 0.0, } @@ -1270,9 +1299,9 @@ def operand_analysis(atom: MMAAtom) -> dict: M, N, K = atom.shape_mnk return { - 'a': _operand_coverage(atom.a_layout, M * K), - 'b': _operand_coverage(atom.b_layout, N * K), - 'c': _operand_coverage(atom.c_layout, M * N), + "a": _operand_coverage(atom.a_layout, M * K), + "b": _operand_coverage(atom.b_layout, N * K), + "c": _operand_coverage(atom.c_layout, M * N), } @@ -1280,34 +1309,35 @@ def operand_analysis(atom: MMAAtom) -> dict: # Algebra explanation # ============================================================================= + def _explain_logical_divide(fn, args): L, T = args if isinstance(T, int): T = Layout(T) - lines = [f'logical_divide({L}, {T})'] + lines = [f"logical_divide({L}, {T})"] actual = logical_divide(L, T) if is_affine(T): - lines.append(' = compose(L, Layout(T, complement(T, shape(coalesce(L)))))') - lines.append('') - lines.append(f' L = {L}') - lines.append(f' T = {T}') + lines.append(" = compose(L, Layout(T, complement(T, shape(coalesce(L)))))") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" T = {T}") coalesced_shape = coalesce(L).shape - lines.append(f' shape(coalesce(L)) = {coalesced_shape}') + lines.append(f" shape(coalesce(L)) = {coalesced_shape}") comp = complement(T, coalesced_shape) - lines.append(f' complement(T, {coalesced_shape}) = {comp}') + lines.append(f" complement(T, {coalesced_shape}) = {comp}") intermediate = Layout(T, comp) - lines.append(f' Layout(T, complement) = {intermediate}') + lines.append(f" Layout(T, complement) = {intermediate}") result = compose(L, intermediate) - lines.append(f' compose(L, {intermediate}) = {result}') + lines.append(f" compose(L, {intermediate}) = {result}") else: - lines.append(' Divides each mode of L by the corresponding tiler element.') - lines.append('') - lines.append(f' L = {L}') - lines.append(f' T = {T}') + lines.append(" Divides each mode of L by the corresponding tiler element.") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" T = {T}") - lines.append('') - lines.append(f' result = {actual}') + lines.append("") + lines.append(f" result = {actual}") return lines @@ -1315,37 +1345,37 @@ def _explain_logical_product(fn, args): A, B = args if isinstance(B, int): B = Layout(B) - lines = [f'logical_product({A}, {B})'] + lines = [f"logical_product({A}, {B})"] if is_affine(B): - lines.append(' = Layout(A, compose(complement(A, size(A)*cosize(B)), B))') - lines.append('') - lines.append(f' A = {A}') - lines.append(f' B = {B}') - lines.append(f' size(A) = {size(A)}') - lines.append(f' cosize(B) = {cosize(B)}') + lines.append(" = Layout(A, compose(complement(A, size(A)*cosize(B)), B))") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") + lines.append(f" size(A) = {size(A)}") + lines.append(f" cosize(B) = {cosize(B)}") bound = size(A) * cosize(B) - lines.append(f' size(A) * cosize(B) = {bound}') + lines.append(f" size(A) * cosize(B) = {bound}") comp = complement(A, bound) - lines.append(f' complement(A, {bound}) = {comp}') + lines.append(f" complement(A, {bound}) = {comp}") comp_b = compose(comp, B) - lines.append(f' compose(complement, B) = {comp_b}') + lines.append(f" compose(complement, B) = {comp_b}") result = Layout(A, comp_b) - lines.append(f' Layout(A, {comp_b}) = {result}') + lines.append(f" Layout(A, {comp_b}) = {result}") else: # Tuple tiler: mode-by-mode decomposition - lines.append(' For tuple tilers, applies logical_product mode-by-mode.') - lines.append('') - lines.append(f' A = {A}') - lines.append(f' B = {B}') + lines.append(" For tuple tilers, applies logical_product mode-by-mode.") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") for i in range(len(B)): mi = mode(A, i) bi = B[i] ri = logical_product(mi, bi) - lines.append(f' mode {i}: logical_product({mi}, {bi}) = {ri}') + lines.append(f" mode {i}: logical_product({mi}, {bi}) = {ri}") - lines.append('') - lines.append(f' result = {logical_product(A, B)}') + lines.append("") + lines.append(f" result = {logical_product(A, B)}") return lines @@ -1353,56 +1383,56 @@ def _explain_complement(fn, args): L = args[0] bound = args[1] if len(args) > 1 else None if bound is not None: - lines = [f'complement({L}, {bound})'] + lines = [f"complement({L}, {bound})"] else: - lines = [f'complement({L})'] + lines = [f"complement({L})"] bound = cosize(L) - lines.append(f' Fills the gaps in L\'s codomain up to bound={bound}.') - lines.append('') - lines.append(f' L = {L}') - lines.append(f' image(L) = {image(L)}') - lines.append(f' codomain = [0, {bound})') + lines.append(f" Fills the gaps in L's codomain up to bound={bound}.") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" image(L) = {image(L)}") + lines.append(f" codomain = [0, {bound})") comp = complement(*args) - lines.append(f' complement = {comp}') - lines.append(f' image(complement) = {image(comp)}') + lines.append(f" complement = {comp}") + lines.append(f" image(complement) = {image(comp)}") return lines def _explain_compose(fn, args): A, B = args - lines = [f'compose({A}, {B})'] + lines = [f"compose({A}, {B})"] if is_layout(B): - lines.append(' C(i) = A(B(i))') + lines.append(" C(i) = A(B(i))") else: - lines.append(' For tuple tilers, composition is applied mode-by-mode.') - lines.append('') - lines.append(f' A = {A}') - lines.append(f' B = {B}') + lines.append(" For tuple tilers, composition is applied mode-by-mode.") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") result = compose(A, B) - lines.append(f' result = {result}') - lines.append('') + lines.append(f" result = {result}") + lines.append("") if is_layout(B): n = min(size(result), 8) - lines.append(f' First {n} values:') + lines.append(f" First {n} values:") for i in range(n): b_i = B(i) - lines.append(f' i={i}: B({i})={b_i}, A({b_i})={result(i)}') + lines.append(f" i={i}: B({i})={b_i}, A({b_i})={result(i)}") else: for i in range(len(B)): ai = mode(A, i) bi = _normalize_explain_compose_tiler(B[i]) ri = compose(ai, bi) - lines.append(f' mode {i}: compose({ai}, {bi}) = {ri}') + lines.append(f" mode {i}: compose({ai}, {bi}) = {ri}") for i in range(len(B), rank(A)): ai = mode(A, i) - lines.append(f' mode {i}: unchanged = {ai}') + lines.append(f" mode {i}: unchanged = {ai}") n = min(size(result), 8) - lines.append('') - lines.append(f' First {n} output offsets:') + lines.append("") + lines.append(f" First {n} output offsets:") for i in range(n): coord = idx2crd(i, result.shape) - lines.append(f' coord={coord}: result({coord})={result(coord)}') + lines.append(f" coord={coord}: result({coord})={result(coord)}") return lines @@ -1411,16 +1441,16 @@ def _explain_right_inverse(fn, args): R = right_inverse(L) n = min(size(R), 8) lines = [ - f'right_inverse({L})', - ' R such that L(R(i)) == i', - '', - f' L = {L}', - f' R = {R}', - '', - f' Verification (first {n}):', + f"right_inverse({L})", + " R such that L(R(i)) == i", + "", + f" L = {L}", + f" R = {R}", + "", + f" Verification (first {n}):", ] for i in range(n): - lines.append(f' R({i})={R(i)}, L(R({i}))={L(R(i))}') + lines.append(f" R({i})={R(i)}, L(R({i}))={L(R(i))}") return lines @@ -1429,16 +1459,16 @@ def _explain_left_inverse(fn, args): R = left_inverse(L) n = min(size(L), 8) lines = [ - f'left_inverse({L})', - ' R such that R(L(i)) == i', - '', - f' L = {L}', - f' R = {R}', - '', - f' Verification (first {n}):', + f"left_inverse({L})", + " R such that R(L(i)) == i", + "", + f" L = {L}", + f" R = {R}", + "", + f" Verification (first {n}):", ] for i in range(n): - lines.append(f' L({i})={L(i)}, R(L({i}))={R(L(i))}') + lines.append(f" L({i})={L(i)}, R(L({i}))={R(L(i))}") return lines @@ -1447,19 +1477,19 @@ def _explain_blocked_product(fn, args): lp = logical_product(A, B) actual = blocked_product(A, B) lines = [ - f'blocked_product({A}, {B})', - ' Like logical_product, but interleaves corresponding modes:', - ' ((A0, B0), (A1, B1), ...) — A varies fastest (block-first).', - '', - f' logical_product(A, B) = {lp}', - f' blocked_product(A, B) = {actual}', - '', - ' Mode structure:', + f"blocked_product({A}, {B})", + " Like logical_product, but interleaves corresponding modes:", + " ((A0, B0), (A1, B1), ...) — A varies fastest (block-first).", + "", + f" logical_product(A, B) = {lp}", + f" blocked_product(A, B) = {actual}", + "", + " Mode structure:", ] n_modes = max(1, len(actual.shape) if isinstance(actual.shape, tuple) else 1) for i in range(n_modes): m = mode(actual, i) if isinstance(actual.shape, tuple) else actual - lines.append(f' mode {i}: {m.shape} : {m.stride}') + lines.append(f" mode {i}: {m.shape} : {m.stride}") return lines @@ -1471,23 +1501,23 @@ def _explain_raked_product(fn, args): bp_vals = [bp(i) for i in range(n)] rp_vals = [actual(i) for i in range(n)] return [ - f'raked_product({A}, {B})', - ' Like blocked_product, but B varies fastest (rake-first):', - ' ((B0, A0), (B1, A1), ...) — elements are interleaved.', - '', - f' blocked_product(A, B) = {bp}', - f' raked_product(A, B) = {actual}', - '', - ' Compare first 8 offsets:', - f' blocked: {bp_vals}', - f' raked: {rp_vals}', + f"raked_product({A}, {B})", + " Like blocked_product, but B varies fastest (rake-first):", + " ((B0, A0), (B1, A1), ...) — elements are interleaved.", + "", + f" blocked_product(A, B) = {bp}", + f" raked_product(A, B) = {actual}", + "", + " Compare first 8 offsets:", + f" blocked: {bp_vals}", + f" raked: {rp_vals}", ] _DIVIDE_VARIANT_STRUCTURE = { - 'zipped_divide': ' Structure: ((tiles), (rests))', - 'tiled_divide': ' Structure: ((tiles), rest0, rest1, ...)', - 'flat_divide': ' Structure: (tile0, tile1, ..., rest0, rest1, ...)', + "zipped_divide": " Structure: ((tiles), (rests))", + "tiled_divide": " Structure: ((tiles), rest0, rest1, ...)", + "flat_divide": " Structure: (tile0, tile1, ..., rest0, rest1, ...)", } @@ -1497,14 +1527,14 @@ def _explain_divide_variant(fn, args): ld = logical_divide(L, T) actual = fn(L, T) return [ - f'{name}({L}, {T})', - ' Rearrangement of logical_divide result.', - '', - f' logical_divide({L}, {T})', - f' = {ld}', - f' {name}:', - f' = {actual}', - '', + f"{name}({L}, {T})", + " Rearrangement of logical_divide result.", + "", + f" logical_divide({L}, {T})", + f" = {ld}", + f" {name}:", + f" = {actual}", + "", _DIVIDE_VARIANT_STRUCTURE[name], ] @@ -1545,15 +1575,15 @@ def explain(fn, *args): """ handler = _EXPLAIN_HANDLERS.get(fn) if handler is None: - supported = ', '.join(sorted(h.__name__ for h in _EXPLAIN_HANDLERS)) + supported = ", ".join(sorted(h.__name__ for h in _EXPLAIN_HANDLERS)) lines = [ - f'explain() does not support {getattr(fn, "__name__", fn)}.', - f'Supported: {supported}.', + f"explain() does not support {getattr(fn, '__name__', fn)}.", + f"Supported: {supported}.", ] else: lines = handler(fn, args) - text = '\n'.join(lines) + text = "\n".join(lines) print(text) return text @@ -1667,8 +1697,10 @@ def to_F2_matrix(layout: Layout) -> list[list[int]]: if layout.swizzle is not None: sw = layout.swizzle # Build swizzle matrix S (identity + XOR connections) - S = [[1 if i == j else 0 for j in range(n_offset_bits)] - for i in range(n_offset_bits)] + S = [ + [1 if i == j else 0 for j in range(n_offset_bits)] + for i in range(n_offset_bits) + ] for k in range(sw.bits): if sw.shift >= 0: src = sw.base + sw.shift + k @@ -1680,8 +1712,10 @@ def to_F2_matrix(layout: Layout) -> list[list[int]]: S[dst][src] = 1 # Compose: M' = S @ M (mod 2) M = [ - [sum(S[i][k] * M[k][j] for k in range(n_offset_bits)) % 2 - for j in range(n_coord_bits)] + [ + sum(S[i][k] * M[k][j] for k in range(n_offset_bits)) % 2 + for j in range(n_coord_bits) + ] for i in range(n_offset_bits) ] diff --git a/src/tensor_layouts/atoms.py b/src/tensor_layouts/atoms.py index 729b634..eedd278 100644 --- a/src/tensor_layouts/atoms.py +++ b/src/tensor_layouts/atoms.py @@ -45,6 +45,7 @@ class MMAAtom: b_layout: (T, V) -> col-major offset in (N, K) c_layout: (T, V) -> col-major offset in (M, N) """ + name: str ptx: str shape_mnk: Tuple[int, int, int] @@ -71,6 +72,7 @@ class CopyAtom: src_layout_bits: (thr, val) -> bit offset for source dst_layout_bits: (thr, val) -> bit offset for destination """ + name: str ptx: str thr_id: Layout diff --git a/src/tensor_layouts/atoms_amd.py b/src/tensor_layouts/atoms_amd.py index 3f7d7ee..9d14df5 100644 --- a/src/tensor_layouts/atoms_amd.py +++ b/src/tensor_layouts/atoms_amd.py @@ -131,6 +131,7 @@ # Helper: construct CuTe layouts from MFMA structural parameters # ============================================================================= + def _mfma_c_layout( m: int, n: int, @@ -254,21 +255,23 @@ def make_mfma_atom( # Sanity checks matching the CK static_asserts if num_threads_per_blk != n: - raise ValueError( - f"num_threads_per_blk ({num_threads_per_blk}) != n ({n})") + raise ValueError(f"num_threads_per_blk ({num_threads_per_blk}) != n ({n})") if num_regs_per_blk * num_input_blks != m: raise ValueError( f"num_regs_per_blk * num_input_blks " - f"({num_regs_per_blk * num_input_blks}) != m ({m})") + f"({num_regs_per_blk * num_input_blks}) != m ({m})" + ) if num_regs_per_blk * wave_size != m * n: raise ValueError( f"num_regs_per_blk * wave_size " - f"({num_regs_per_blk * wave_size}) != m*n ({m * n})") + f"({num_regs_per_blk * wave_size}) != m*n ({m * n})" + ) if wave_size != num_input_blks * num_threads_per_blk: raise ValueError( f"wave_size ({wave_size}) != " f"num_input_blks * num_threads_per_blk " - f"({num_input_blks * num_threads_per_blk})") + f"({num_input_blks * num_threads_per_blk})" + ) # For k-reduction variants: K = k_per_blk * num_input_blks # For non-k-reduction: K = k_per_blk @@ -277,16 +280,28 @@ def make_mfma_atom( raise ValueError(f"total_k ({total_k}) != k ({k})") c_layout = _mfma_c_layout( - m, n, group_size, num_groups_per_blk, - num_threads_per_blk, num_input_blks, + m, + n, + group_size, + num_groups_per_blk, + num_threads_per_blk, + num_input_blks, ) a_layout = _mfma_input_layout( - m, k, num_threads_per_blk, num_input_blks, k_per_blk, + m, + k, + num_threads_per_blk, + num_input_blks, + k_per_blk, ) b_layout = _mfma_input_layout( - n, k, num_threads_per_blk, num_input_blks, k_per_blk, + n, + k, + num_threads_per_blk, + num_input_blks, + k_per_blk, ) return MMAAtom( @@ -313,11 +328,18 @@ def make_mfma_atom( CDNA_32x32x8_F32F16F16_MFMA = make_mfma_atom( name="CDNA_32x32x8_F32F16F16_MFMA", inst="v_mfma_f32_32x32x8f16", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x16f16: D[16x16] = C[16x16] + A[16x16]*B[16x16] @@ -327,11 +349,18 @@ def make_mfma_atom( CDNA_16x16x16_F32F16F16_MFMA = make_mfma_atom( name="CDNA_16x16x16_F32F16F16_MFMA", inst="v_mfma_f32_16x16x16f16", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_4x4x4f16: D[4x4] = C[4x4] + A[4x4]*B[4x4] @@ -346,11 +375,18 @@ def make_mfma_atom( CDNA_4x4x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_4x4x4_F32F16F16_MFMA", inst="v_mfma_f32_4x4x4f16", - m=4, n=64, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=64, num_input_blks=1, - num_output_blks=1, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=4, + n=64, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=64, + num_input_blks=1, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) # --- Non-k-reduction variants (larger K, multiple output blocks) --- @@ -362,11 +398,18 @@ def make_mfma_atom( CDNA_32x32x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_32x32x4_F32F16F16_MFMA", inst="v_mfma_f32_32x32x4f16", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=2, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=2, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x4f16: 4 output blocks (non-k-reduction) @@ -375,11 +418,18 @@ def make_mfma_atom( CDNA_16x16x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_16x16x4_F32F16F16_MFMA", inst="v_mfma_f32_16x16x4f16", - m=16, n=16, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=4, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=4, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) @@ -392,22 +442,36 @@ def make_mfma_atom( CDNA_32x32x8_F32BF16BF16_1K_MFMA = make_mfma_atom( name="CDNA_32x32x8_F32BF16BF16_1K_MFMA", inst="v_mfma_f32_32x32x8bf16_1k", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x16bf16_1k: identical layout to 16x16x16f16 CDNA_16x16x16_F32BF16BF16_1K_MFMA = make_mfma_atom( name="CDNA_16x16x16_F32BF16BF16_1K_MFMA", inst="v_mfma_f32_16x16x16bf16_1k", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -419,22 +483,36 @@ def make_mfma_atom( CDNA_32x32x4_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA_32x32x4_F32BF16BF16_MFMA", inst="v_mfma_f32_32x32x4bf16", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x8bf16 CDNA_16x16x8_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA_16x16x8_F32BF16BF16_MFMA", inst="v_mfma_f32_16x16x8bf16", - m=16, n=16, k=8, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=8, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -446,22 +524,36 @@ def make_mfma_atom( CDNA_32x32x8_I32I8I8_MFMA = make_mfma_atom( name="CDNA_32x32x8_I32I8I8_MFMA", inst="v_mfma_i32_32x32x8i8", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) # v_mfma_i32_16x16x16i8 CDNA_16x16x16_I32I8I8_MFMA = make_mfma_atom( name="CDNA_16x16x16_I32I8I8_MFMA", inst="v_mfma_i32_16x16x16i8", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) @@ -473,22 +565,36 @@ def make_mfma_atom( CDNA_32x32x2_F32F32F32_MFMA = make_mfma_atom( name="CDNA_32x32x2_F32F32F32_MFMA", inst="v_mfma_f32_32x32x2f32", - m=32, n=32, k=2, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=32, + n=32, + k=2, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) # v_mfma_f32_16x16x4f32 CDNA_16x16x4_F32F32F32_MFMA = make_mfma_atom( name="CDNA_16x16x4_F32F32F32_MFMA", inst="v_mfma_f32_16x16x4f32", - m=16, n=16, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=16, + n=16, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) @@ -500,11 +606,18 @@ def make_mfma_atom( CDNA_16x16x4_F64F64F64_MFMA = make_mfma_atom( name="CDNA_16x16x4_F64F64F64_MFMA", inst="v_mfma_f64_16x16x4f64", - m=16, n=16, k=4, - group_size=1, num_groups_per_blk=4, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=4, + group_size=1, + num_groups_per_blk=4, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -538,11 +651,18 @@ def make_mfma_atom( CDNA3_32x32x16_I32I8I8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_I32I8I8_MFMA", inst="v_mfma_i32_32x32x16i8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_16x16x32i8: 16x16 output, K=32 @@ -551,11 +671,18 @@ def make_mfma_atom( CDNA3_16x16x32_I32I8I8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_I32I8I8_MFMA", inst="v_mfma_i32_16x16x32i8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # --- XF32 (TF32-like, CDNA3) --- @@ -564,22 +691,36 @@ def make_mfma_atom( CDNA3_32x32x4_F32XF32XF32_MFMA = make_mfma_atom( name="CDNA3_32x32x4_F32XF32XF32_MFMA", inst="v_mfma_f32_32x32x4_xf32", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x8_xf32 CDNA3_16x16x8_F32XF32XF32_MFMA = make_mfma_atom( name="CDNA3_16x16x8_F32XF32XF32_MFMA", inst="v_mfma_f32_16x16x8_xf32", - m=16, n=16, k=8, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=8, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -591,85 +732,141 @@ def make_mfma_atom( CDNA3_32x32x16_F32F8F8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32F8F8_MFMA", inst="v_mfma_f32_32x32x16_fp8_fp8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_fp8_fp8 CDNA3_16x16x32_F32F8F8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32F8F8_MFMA", inst="v_mfma_f32_16x16x32_fp8_fp8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_32x32x16_bf8_bf8: same layout as fp8_fp8 32x32 CDNA3_32x32x16_F32BF8BF8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32BF8BF8_MFMA", inst="v_mfma_f32_32x32x16_bf8_bf8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_bf8_bf8 CDNA3_16x16x32_F32BF8BF8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32BF8BF8_MFMA", inst="v_mfma_f32_16x16x32_bf8_bf8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # Mixed FP8 variants (fp8 x bf8, bf8 x fp8) — same layouts CDNA3_32x32x16_F32F8BF8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32F8BF8_MFMA", inst="v_mfma_f32_32x32x16_fp8_bf8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) CDNA3_16x16x32_F32F8BF8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32F8BF8_MFMA", inst="v_mfma_f32_16x16x32_fp8_bf8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) CDNA3_32x32x16_F32BF8F8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32BF8F8_MFMA", inst="v_mfma_f32_32x32x16_bf8_fp8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) CDNA3_16x16x32_F32BF8F8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32BF8F8_MFMA", inst="v_mfma_f32_16x16x32_bf8_fp8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -683,11 +880,18 @@ def make_mfma_atom( CDNA3P_32x32x16_F32F16F16_MFMA = make_mfma_atom( name="CDNA3P_32x32x16_F32F16F16_MFMA", inst="v_mfma_f32_32x32x16_f16", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_f16 (gfx950 only): 2x K vs 16x16x16f16 @@ -696,11 +900,18 @@ def make_mfma_atom( CDNA3P_16x16x32_F32F16F16_MFMA = make_mfma_atom( name="CDNA3P_16x16x32_F32F16F16_MFMA", inst="v_mfma_f32_16x16x32_f16", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_32x32x16_bf16 (gfx950 only) @@ -709,44 +920,72 @@ def make_mfma_atom( CDNA3P_32x32x16_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA3P_32x32x16_F32BF16BF16_MFMA", inst="v_mfma_f32_32x32x16_bf16", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_bf16 (gfx950 only) CDNA3P_16x16x32_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA3P_16x16x32_F32BF16BF16_MFMA", inst="v_mfma_f32_16x16x32_bf16", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_32x32x32_i8 (gfx950 only) CDNA3P_32x32x32_I32I8I8_MFMA = make_mfma_atom( name="CDNA3P_32x32x32_I32I8I8_MFMA", inst="v_mfma_i32_32x32x32i8", - m=32, n=32, k=32, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=16, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=32, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=16, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_16x16x64_i8 (gfx950 only) CDNA3P_16x16x64_I32I8I8_MFMA = make_mfma_atom( name="CDNA3P_16x16x64_I32I8I8_MFMA", inst="v_mfma_i32_16x16x64i8", - m=16, n=16, k=64, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=16, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=64, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=16, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -920,7 +1159,7 @@ def make_wmma_atom( name=name, ptx=inst, shape_mnk=(m, n, k), - thr_id=None, # identity: lane_id = thread_idx % 32 + thr_id=None, # identity: lane_id = thread_idx % 32 a_layout=a_layout, b_layout=b_layout, c_layout=c_layout, @@ -935,40 +1174,52 @@ def make_wmma_atom( RDNA3_16x16x16_F32F16F16_WMMA = make_wmma_atom( name="RDNA3_16x16x16_F32F16F16_WMMA", inst="v_wmma_f32_16x16x16_f16", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) RDNA3_16x16x16_F16F16F16_WMMA = make_wmma_atom( name="RDNA3_16x16x16_F16F16F16_WMMA", inst="v_wmma_f16_16x16x16_f16", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) # --- BF16 --- RDNA3_16x16x16_F32BF16BF16_WMMA = make_wmma_atom( name="RDNA3_16x16x16_F32BF16BF16_WMMA", inst="v_wmma_f32_16x16x16_bf16", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) RDNA3_16x16x16_BF16BF16BF16_WMMA = make_wmma_atom( name="RDNA3_16x16x16_BF16BF16BF16_WMMA", inst="v_wmma_bf16_16x16x16_bf16", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) # --- INT8 --- RDNA3_16x16x16_I32I8I8_WMMA = make_wmma_atom( name="RDNA3_16x16x16_I32I8I8_WMMA", inst="v_wmma_i32_16x16x16_iu8", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) # --- INT4 --- RDNA3_16x16x16_I32I4I4_WMMA = make_wmma_atom( name="RDNA3_16x16x16_I32I4I4_WMMA", inst="v_wmma_i32_16x16x16_iu4", - m=16, n=16, k=16, + m=16, + n=16, + k=16, ) @@ -980,53 +1231,69 @@ def make_wmma_atom( RDNA4_16x16x32_F32F16F16_WMMA = make_wmma_atom( name="RDNA4_16x16x32_F32F16F16_WMMA", inst="v_wmma_f32_16x16x32_f16", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) RDNA4_16x16x32_F16F16F16_WMMA = make_wmma_atom( name="RDNA4_16x16x32_F16F16F16_WMMA", inst="v_wmma_f16_16x16x32_f16", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) # --- BF16 --- RDNA4_16x16x32_F32BF16BF16_WMMA = make_wmma_atom( name="RDNA4_16x16x32_F32BF16BF16_WMMA", inst="v_wmma_f32_16x16x32_bf16", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) RDNA4_16x16x32_BF16BF16BF16_WMMA = make_wmma_atom( name="RDNA4_16x16x32_BF16BF16BF16_WMMA", inst="v_wmma_bf16_16x16x32_bf16", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) # --- FP8 (new in RDNA4) --- RDNA4_16x16x32_F32F8F8_WMMA = make_wmma_atom( name="RDNA4_16x16x32_F32F8F8_WMMA", inst="v_wmma_f32_16x16x32_fp8_fp8", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) RDNA4_16x16x32_F32BF8BF8_WMMA = make_wmma_atom( name="RDNA4_16x16x32_F32BF8BF8_WMMA", inst="v_wmma_f32_16x16x32_bf8_bf8", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) # --- INT8 (doubled K) --- RDNA4_16x16x32_I32I8I8_WMMA = make_wmma_atom( name="RDNA4_16x16x32_I32I8I8_WMMA", inst="v_wmma_i32_16x16x32_iu8", - m=16, n=16, k=32, + m=16, + n=16, + k=32, ) # --- INT4 (quadrupled K) --- RDNA4_16x16x64_I32I4I4_WMMA = make_wmma_atom( name="RDNA4_16x16x64_I32I4I4_WMMA", inst="v_wmma_i32_16x16x64_iu4", - m=16, n=16, k=64, + m=16, + n=16, + k=64, ) diff --git a/src/tensor_layouts/atoms_amx.py b/src/tensor_layouts/atoms_amx.py index 53f5a2f..bff57ff 100644 --- a/src/tensor_layouts/atoms_amx.py +++ b/src/tensor_layouts/atoms_amx.py @@ -80,8 +80,8 @@ print(atom.c_layout) # (1, (16, 16)):(0, (1, 16)) """ -from .layouts import Layout from .atoms import MMAAtom +from .layouts import Layout # ============================================================================= @@ -102,58 +102,70 @@ AMX_16x16x32_F32BF16BF16F32 = MMAAtom( name="AMX_16x16x32_F32BF16BF16F32", ptx="tdpbf16ps", - shape_mnk=(16, 16, 32), thr_id=Layout(1), + shape_mnk=(16, 16, 32), + thr_id=Layout(1), # (T=1, V=512) -> col-major offset in (M=16, K=32) a_layout=Layout((1, (16, 32)), (0, (1, 16))), # (T=1, V=512) -> col-major offset in (N=16, K=32) b_layout=Layout((1, (16, 32)), (0, (1, 16))), # (T=1, V=256) -> col-major offset in (M=16, N=16) - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- FP16 -> FP32 ------------------------------------------------------------- AMX_16x16x32_F32F16F16F32 = MMAAtom( name="AMX_16x16x32_F32F16F16F32", ptx="tdpfp16ps", - shape_mnk=(16, 16, 32), thr_id=Layout(1), + shape_mnk=(16, 16, 32), + thr_id=Layout(1), a_layout=Layout((1, (16, 32)), (0, (1, 16))), b_layout=Layout((1, (16, 32)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- INT8 x INT8 -> INT32 (signed x signed) ----------------------------------- AMX_16x16x64_S32S8S8S32 = MMAAtom( name="AMX_16x16x64_S32S8S8S32", ptx="tdpbssd", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), # (T=1, V=1024) -> col-major offset in (M=16, K=64) a_layout=Layout((1, (16, 64)), (0, (1, 16))), # (T=1, V=1024) -> col-major offset in (N=16, K=64) b_layout=Layout((1, (16, 64)), (0, (1, 16))), # (T=1, V=256) -> col-major offset in (M=16, N=16) - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- INT8 x UINT8 -> INT32 (signed x unsigned) -------------------------------- AMX_16x16x64_S32S8U8S32 = MMAAtom( name="AMX_16x16x64_S32S8U8S32", ptx="tdpbsud", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- UINT8 x INT8 -> INT32 (unsigned x signed) -------------------------------- AMX_16x16x64_S32U8S8S32 = MMAAtom( name="AMX_16x16x64_S32U8S8S32", ptx="tdpbusd", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- UINT8 x UINT8 -> INT32 (unsigned x unsigned) ----------------------------- AMX_16x16x64_S32U8U8S32 = MMAAtom( name="AMX_16x16x64_S32U8U8S32", ptx="tdpbuud", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) diff --git a/src/tensor_layouts/atoms_nv.py b/src/tensor_layouts/atoms_nv.py index d8017a5..276bede 100644 --- a/src/tensor_layouts/atoms_nv.py +++ b/src/tensor_layouts/atoms_nv.py @@ -59,8 +59,8 @@ print(SM70_8x8x4_F32F16F16F32_NT.a_layout) """ +from .atoms import CopyAtom, MMAAtom from .layouts import Layout -from .atoms import MMAAtom, CopyAtom # ============================================================================= # SM61 Pascal DP MMA atoms — 1 "thread" (scalar) # Source: include/cute/atom/mma_traits_sm61.hpp @@ -70,18 +70,22 @@ SM61_1x1x4_S32S8S8S32 = MMAAtom( name="SM61_DP4A", ptx="dp4a.s32.s32", - shape_mnk=(1, 1, 4), thr_id=Layout(1), + shape_mnk=(1, 1, 4), + thr_id=Layout(1), a_layout=Layout((1, 4)), b_layout=Layout((1, 4)), - c_layout=Layout((1, 1))) + c_layout=Layout((1, 1)), +) SM61_1x1x2_S32S16S16S32 = MMAAtom( name="SM61_DP2A", ptx="dp2a.s32.s32", - shape_mnk=(1, 1, 2), thr_id=Layout(1), + shape_mnk=(1, 1, 2), + thr_id=Layout(1), a_layout=Layout((1, 2)), b_layout=Layout((1, 2)), - c_layout=Layout((1, 1))) + c_layout=Layout((1, 1)), +) # ============================================================================= @@ -90,28 +94,41 @@ # ============================================================================= # Logical thread id → warp lane index (quadpair: lanes 0-3 and 16-19) -SM70_QuadPair = Layout((4, 2), (1, 16)) # line 44 -SM70_8x4_Row = Layout((8, 4), (1, 8)) # line 47: (T8,V4) → (M8,K4) -SM70_8x4_Col = Layout(((4, 2), 4), # line 50: (T8,V4) → (M8,K4) - ((8, 4), 1)) -SM70_8x8_16b = Layout((8, 8), (1, 8)) # line 53: (T8,V8) → (M8,N8) fp16 accum -SM70_8x8_32b = Layout(((2, 2, 2), # line 56: (T8,V8) → (M8,N8) fp32 accum - (2, 2, 2)), - ((1, 16, 4), - (8, 2, 32))) +SM70_QuadPair = Layout((4, 2), (1, 16)) # line 44 +SM70_8x4_Row = Layout((8, 4), (1, 8)) # line 47: (T8,V4) → (M8,K4) +SM70_8x4_Col = Layout( + ((4, 2), 4), # line 50: (T8,V4) → (M8,K4) + ((8, 4), 1), +) +SM70_8x8_16b = Layout((8, 8), (1, 8)) # line 53: (T8,V8) → (M8,N8) fp16 accum +SM70_8x8_32b = Layout( + ( + (2, 2, 2), # line 56: (T8,V8) → (M8,N8) fp32 accum + (2, 2, 2), + ), + ((1, 16, 4), (8, 2, 32)), +) # ============================================================================= # From mma_traits_sm80.hpp (lines 41-55) # ============================================================================= -SM80_8x4 = Layout(((4, 8), 1), # line 42: (T32,V1) → (M8,N8) - ((8, 1), 0)) -SM80_8x8_Row = Layout(((4, 8), 2), # line 46: (T32,V2) → (M8,N8) - ((16, 1), 8)) -SM80_8x16_Row = Layout(((4, 8), 4), # line 50: (T32,V4) → (M8,N16) - ((32, 1), 8)) -SM80_16x8_Row = Layout(((4, 8), (2, 2)), # line 53: (T32,V4) → (M16,N8) - ((32, 1), (16, 8))) +SM80_8x4 = Layout( + ((4, 8), 1), # line 42: (T32,V1) → (M8,N8) + ((8, 1), 0), +) +SM80_8x8_Row = Layout( + ((4, 8), 2), # line 46: (T32,V2) → (M8,N8) + ((16, 1), 8), +) +SM80_8x16_Row = Layout( + ((4, 8), 4), # line 50: (T32,V4) → (M8,N16) + ((32, 1), 8), +) +SM80_16x8_Row = Layout( + ((4, 8), (2, 2)), # line 53: (T32,V4) → (M16,N8) + ((32, 1), (16, 8)), +) # ============================================================================= @@ -125,58 +142,90 @@ SM70_8x8x4_F16F16F16F16_TN = MMAAtom( name="SM70_8x8x4_F16F16F16F16_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_16b, +) # line 81 — fp16 accumulator, A=col-major, B=row-major SM70_8x8x4_F16F16F16F16_NT = MMAAtom( name="SM70_8x8x4_F16F16F16F16_NT", ptx="mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_16b, +) # line 98 SM70_8x8x4_F16F16F16F16_NN = MMAAtom( name="SM70_8x8x4_F16F16F16F16_NN", ptx="mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_16b, +) # line 115 SM70_8x8x4_F16F16F16F16_TT = MMAAtom( name="SM70_8x8x4_F16F16F16F16_TT", ptx="mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_16b, +) # line 132 — fp32 accumulator, A=row-major, B=col-major SM70_8x8x4_F32F16F16F32_TN = MMAAtom( name="SM70_8x8x4_F32F16F16F32_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_32b, +) # line 149 — fp32 accumulator, A=col-major, B=row-major # Reference image: media/images/cute/HMMA.8x8x4.NT_Atom.png SM70_8x8x4_F32F16F16F32_NT = MMAAtom( name="SM70_8x8x4_F32F16F16F32_NT", ptx="mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_32b, +) # line 166 SM70_8x8x4_F32F16F16F32_NN = MMAAtom( name="SM70_8x8x4_F32F16F16F32_NN", ptx="mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_32b, +) # line 183 SM70_8x8x4_F32F16F16F32_TT = MMAAtom( name="SM70_8x8x4_F32F16F16F32_TT", ptx="mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_32b, +) # ============================================================================= @@ -189,18 +238,22 @@ SM75_16x8x8_F32F16F16F32_TN = MMAAtom( name="SM75_16x8x8_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8))), b_layout=Layout(((4, 8), 2), ((16, 1), 8)), - c_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8)))) + c_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8))), +) SM75_8x8x16_S32S8S8S32_TN = MMAAtom( name="SM75_8x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(8, 8, 16), thr_id=None, + shape_mnk=(8, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), 4), ((32, 1), 8)), b_layout=Layout(((4, 8), 4), ((32, 1), 8)), - c_layout=Layout(((4, 8), 2), ((16, 1), 8))) + c_layout=Layout(((4, 8), 2), ((16, 1), 8)), +) # ============================================================================= @@ -220,70 +273,96 @@ SM80_16x8x8_F16F16F16F16_TN = MMAAtom( name="SM80_16x8x8_F16F16F16F16_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F16F16F16F16_TN = MMAAtom( name="SM80_16x8x16_F16F16F16F16_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- FP32 accumulator with FP16 inputs --- SM80_16x8x8_F32F16F16F32_TN = MMAAtom( name="SM80_16x8x8_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F32F16F16F32_TN = MMAAtom( name="SM80_16x8x16_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- BF16 (same layouts as FP16) --- SM80_16x8x8_F32BF16BF16F32_TN = MMAAtom( name="SM80_16x8x8_F32BF16BF16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F32BF16BF16F32_TN = MMAAtom( name="SM80_16x8x16_F32BF16BF16F32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- TF32 (TensorFloat-32) --- SM80_16x8x4_F32TF32TF32F32_TN = MMAAtom( name="SM80_16x8x4_F32TF32TF32F32_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=Layout(((4, 8), 2), ((16, 1), 8)), b_layout=SM80_8x4, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM80_16x8x8_F32TF32TF32F32_TN = MMAAtom( name="SM80_16x8x8_F32TF32TF32F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 2), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- FP64 --- SM80_8x8x4_F64F64F64F64_TN = MMAAtom( name="SM80_8x8x4_F64F64F64F64_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64", - shape_mnk=(8, 8, 4), thr_id=None, - a_layout=SM80_8x4, b_layout=SM80_8x4, c_layout=SM80_8x8_Row) + shape_mnk=(8, 8, 4), + thr_id=None, + a_layout=SM80_8x4, + b_layout=SM80_8x4, + c_layout=SM80_8x8_Row, +) # --- INT8 (s8×s8, s8×u8, u8×s8, u8×u8 all share layouts at same tile size) --- @@ -291,26 +370,34 @@ SM80_8x8x16_S32S8S8S32_TN = MMAAtom( name="SM80_8x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(8, 8, 16), thr_id=None, - a_layout=SM80_8x16_Row, b_layout=SM80_8x16_Row, c_layout=SM80_8x8_Row) + shape_mnk=(8, 8, 16), + thr_id=None, + a_layout=SM80_8x16_Row, + b_layout=SM80_8x16_Row, + c_layout=SM80_8x8_Row, +) # 16x8x16 SM80_16x8x16_S32S8S8S32_TN = MMAAtom( name="SM80_16x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (4, 2)), ((64, 1), (16, 8))), b_layout=SM80_8x16_Row, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # 16x8x32 SM80_16x8x32_S32S8S8S32_TN = MMAAtom( name="SM80_16x8x32_S32S8S8S32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- INT4 --- @@ -318,54 +405,66 @@ SM80_8x8x32_S32S4S4S32_TN = MMAAtom( name="SM80_8x8x32_S32S4S4S32_TN", ptx="mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32", - shape_mnk=(8, 8, 32), thr_id=None, + shape_mnk=(8, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), 8), ((64, 1), 8)), b_layout=Layout(((4, 8), 8), ((64, 1), 8)), - c_layout=SM80_8x8_Row) + c_layout=SM80_8x8_Row, +) # 16x8x32 SM80_16x8x32_S32S4S4S32_TN = MMAAtom( name="SM80_16x8x32_S32S4S4S32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (8, 2)), ((128, 1), (16, 8))), b_layout=Layout(((4, 8), 8), ((32, 1), 8)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # 16x8x64 SM80_16x8x64_S32S4S4S32_TN = MMAAtom( name="SM80_16x8x64_S32S4S4S32_TN", ptx="mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- Binary (U1) --- SM80_8x8x128_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_8x8x128_S32U1U1S32_TN_XORPOPC", ptx="mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(8, 8, 128), thr_id=None, + shape_mnk=(8, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), 32), ((256, 1), 8)), b_layout=Layout(((4, 8), 32), ((256, 1), 8)), - c_layout=SM80_8x8_Row) + c_layout=SM80_8x8_Row, +) SM80_16x8x128_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_16x8x128_S32U1U1S32_TN_XORPOPC", ptx="mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (32, 2)), ((512, 1), (16, 8))), b_layout=Layout(((4, 8), 32), ((256, 1), 8)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM80_16x8x256_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_16x8x256_S32U1U1S32_TN_XORPOPC", ptx="mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(16, 8, 256), thr_id=None, + shape_mnk=(16, 8, 256), + thr_id=None, a_layout=Layout(((4, 8), (32, 2, 2)), ((512, 1), (16, 8, 2048))), b_layout=Layout(((4, 8), (32, 2)), ((256, 1), (8, 1024))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # ============================================================================= @@ -381,67 +480,83 @@ SM89_16x8x32_F32E4M3E4M3F32_TN = MMAAtom( name="SM89_16x8x32_F32E4M3E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM89_16x8x32_F32E4M3E5M2F32_TN = MMAAtom( name="SM89_16x8x32_F32E4M3E5M2F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F32E5M2E5M2F32_TN = MMAAtom( name="SM89_16x8x32_F32E5M2E5M2F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F32E5M2E4M3F32_TN = MMAAtom( name="SM89_16x8x32_F32E5M2E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) # FP16 accumulator variants (same layouts) SM89_16x8x32_F16E4M3E4M3F16_TN = MMAAtom( name="SM89_16x8x32_F16E4M3E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E4M3E5M2F16_TN = MMAAtom( name="SM89_16x8x32_F16E4M3E5M2F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E5M2E5M2F16_TN = MMAAtom( name="SM89_16x8x32_F16E5M2E5M2F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E5M2E4M3F16_TN = MMAAtom( name="SM89_16x8x32_F16E5M2E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) # ============================================================================= @@ -455,54 +570,66 @@ SM90_16x8x4_F64F64F64F64_TN = MMAAtom( name="SM90_16x8x4_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=Layout(((4, 8), 2), ((16, 1), 8)), b_layout=SM80_8x4, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # line 67 SM90_16x8x8_F64F64F64F64_TN = MMAAtom( name="SM90_16x8x8_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 2), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # line 87 SM90_16x8x16_F64F64F64F64_TN = MMAAtom( name="SM90_16x8x16_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 4)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 4), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- Complex FP64 (same layouts as FP64, different value types) --- SM90_16x8x4_C64C64C64C64_TN = MMAAtom( name="SM90_16x8x4_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=SM90_16x8x4_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x4_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x4_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x4_F64F64F64F64_TN.c_layout, +) SM90_16x8x8_C64C64C64C64_TN = MMAAtom( name="SM90_16x8x8_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=SM90_16x8x8_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x8_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x8_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x8_F64F64F64F64_TN.c_layout, +) SM90_16x8x16_C64C64C64C64_TN = MMAAtom( name="SM90_16x8x16_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=SM90_16x8x16_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x16_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x16_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x16_F64F64F64F64_TN.c_layout, +) # ============================================================================= @@ -545,6 +672,7 @@ # 436-443). # ============================================================================= + def gmma_c_layout(n: int) -> Layout: """CLayout_64xN: accumulator layout for SM90 GMMA with N columns. @@ -553,8 +681,8 @@ def gmma_c_layout(n: int) -> Layout: See the section header above for the warpgroup-level meaning of T. Source: mma_traits_sm90_gmma.hpp line 432. """ - return Layout(((4, 8, 4), (2, 2, n // 8)), - ((128, 1, 16), (64, 8, 512))) + return Layout(((4, 8, 4), (2, 2, n // 8)), ((128, 1, 16), (64, 8, 512))) + def gmma_ab_layout(m: int, k: int) -> Layout: """ABLayout: shared-memory descriptor layout for an A or B tile. @@ -567,54 +695,67 @@ def gmma_ab_layout(m: int, k: int) -> Layout: """ return Layout((128, (m, k)), (0, (1, m))) + # line 657 — SM90_64x64x16_F16F16F16_SS SM90_64x8x16_F16F16F16_SS = MMAAtom( name="SM90_64x8x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16", - shape_mnk=(64, 8, 16), thr_id=None, + shape_mnk=(64, 8, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(8, 16), - c_layout=gmma_c_layout(8)) + c_layout=gmma_c_layout(8), +) SM90_64x16x16_F16F16F16_SS = MMAAtom( name="SM90_64x16x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16", - shape_mnk=(64, 16, 16), thr_id=None, + shape_mnk=(64, 16, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(16, 16), - c_layout=gmma_c_layout(16)) + c_layout=gmma_c_layout(16), +) SM90_64x32x16_F16F16F16_SS = MMAAtom( name="SM90_64x32x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16", - shape_mnk=(64, 32, 16), thr_id=None, + shape_mnk=(64, 32, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(32, 16), - c_layout=gmma_c_layout(32)) + c_layout=gmma_c_layout(32), +) SM90_64x64x16_F16F16F16_SS = MMAAtom( name="SM90_64x64x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16", - shape_mnk=(64, 64, 16), thr_id=None, + shape_mnk=(64, 64, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(64, 16), - c_layout=gmma_c_layout(64)) + c_layout=gmma_c_layout(64), +) SM90_64x128x16_F16F16F16_SS = MMAAtom( name="SM90_64x128x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16", - shape_mnk=(64, 128, 16), thr_id=None, + shape_mnk=(64, 128, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(128, 16), - c_layout=gmma_c_layout(128)) + c_layout=gmma_c_layout(128), +) SM90_64x256x16_F16F16F16_SS = MMAAtom( name="SM90_64x256x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16", - shape_mnk=(64, 256, 16), thr_id=None, + shape_mnk=(64, 256, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(256, 16), - c_layout=gmma_c_layout(256)) + c_layout=gmma_c_layout(256), +) # ============================================================================= @@ -630,8 +771,10 @@ def gmma_ab_layout(m: int, k: int) -> Layout: # provide a factory instead of enumerating hundreds of concrete atoms. # ============================================================================= -def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", - ab_type: str | None = None) -> MMAAtom: + +def make_gmma_atom_ss( + n: int, k: int = 16, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM90 GMMA SS atom for 64×N×K with the given data types. Args: @@ -648,10 +791,12 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", return MMAAtom( name=name, ptx=f"wgmma.mma_async.sync.aligned.m64n{n}k{k}", - shape_mnk=(64, n, k), thr_id=None, + shape_mnk=(64, n, k), + thr_id=None, a_layout=gmma_ab_layout(64, k), b_layout=gmma_ab_layout(n, k), - c_layout=gmma_c_layout(n)) + c_layout=gmma_c_layout(n), + ) # Representative ext atoms (N values not in the base set) @@ -687,19 +832,35 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", # FP8 E4M3 GMMA atoms (K=32 for 8-bit types) SM90_64x64x32_F32E4M3E4M3_SS = make_gmma_atom_ss(64, k=32, d_type="F32", ab_type="E4M3") -SM90_64x128x32_F32E4M3E4M3_SS = make_gmma_atom_ss(128, k=32, d_type="F32", ab_type="E4M3") -SM90_64x256x32_F32E4M3E4M3_SS = make_gmma_atom_ss(256, k=32, d_type="F32", ab_type="E4M3") +SM90_64x128x32_F32E4M3E4M3_SS = make_gmma_atom_ss( + 128, k=32, d_type="F32", ab_type="E4M3" +) +SM90_64x256x32_F32E4M3E4M3_SS = make_gmma_atom_ss( + 256, k=32, d_type="F32", ab_type="E4M3" +) SM90_64x64x32_F16E4M3E4M3_SS = make_gmma_atom_ss(64, k=32, d_type="F16", ab_type="E4M3") -SM90_64x128x32_F16E4M3E4M3_SS = make_gmma_atom_ss(128, k=32, d_type="F16", ab_type="E4M3") -SM90_64x256x32_F16E4M3E4M3_SS = make_gmma_atom_ss(256, k=32, d_type="F16", ab_type="E4M3") +SM90_64x128x32_F16E4M3E4M3_SS = make_gmma_atom_ss( + 128, k=32, d_type="F16", ab_type="E4M3" +) +SM90_64x256x32_F16E4M3E4M3_SS = make_gmma_atom_ss( + 256, k=32, d_type="F16", ab_type="E4M3" +) # FP8 E5M2 GMMA atoms SM90_64x64x32_F32E5M2E5M2_SS = make_gmma_atom_ss(64, k=32, d_type="F32", ab_type="E5M2") -SM90_64x128x32_F32E5M2E5M2_SS = make_gmma_atom_ss(128, k=32, d_type="F32", ab_type="E5M2") -SM90_64x256x32_F32E5M2E5M2_SS = make_gmma_atom_ss(256, k=32, d_type="F32", ab_type="E5M2") +SM90_64x128x32_F32E5M2E5M2_SS = make_gmma_atom_ss( + 128, k=32, d_type="F32", ab_type="E5M2" +) +SM90_64x256x32_F32E5M2E5M2_SS = make_gmma_atom_ss( + 256, k=32, d_type="F32", ab_type="E5M2" +) SM90_64x64x32_F16E5M2E5M2_SS = make_gmma_atom_ss(64, k=32, d_type="F16", ab_type="E5M2") -SM90_64x128x32_F16E5M2E5M2_SS = make_gmma_atom_ss(128, k=32, d_type="F16", ab_type="E5M2") -SM90_64x256x32_F16E5M2E5M2_SS = make_gmma_atom_ss(256, k=32, d_type="F16", ab_type="E5M2") +SM90_64x128x32_F16E5M2E5M2_SS = make_gmma_atom_ss( + 128, k=32, d_type="F16", ab_type="E5M2" +) +SM90_64x256x32_F16E5M2E5M2_SS = make_gmma_atom_ss( + 256, k=32, d_type="F16", ab_type="E5M2" +) # ============================================================================= @@ -712,8 +873,10 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", # K_sparse = 2 * K_dense (e.g. K=32 for F16 sparse vs K=16 for F16 dense). # ============================================================================= -def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", - ab_type: str | None = None) -> MMAAtom: + +def make_gmma_sparse_atom_ss( + n: int, k: int = 32, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM90 GMMA sparse SS atom for 64×N×K.""" if ab_type is None: ab_type = d_type @@ -723,10 +886,12 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", return MMAAtom( name=name, ptx=f"wgmma.mma_async.sp.sync.aligned.m64n{n}k{k}", - shape_mnk=(64, n, k), thr_id=None, + shape_mnk=(64, n, k), + thr_id=None, a_layout=gmma_ab_layout(64, k), b_layout=gmma_ab_layout(n, k), - c_layout=gmma_c_layout(n)) + c_layout=gmma_c_layout(n), + ) # F16 sparse (K=32, double the dense K=16) @@ -735,14 +900,26 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", SM90_64x256x32_F16F16F16_SS_SPARSE = make_gmma_sparse_atom_ss(256) # TF32 sparse (K=16, double the dense K=8) -SM90_64x64x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(64, k=16, d_type="F32", ab_type="TF32") -SM90_64x128x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(128, k=16, d_type="F32", ab_type="TF32") -SM90_64x256x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(256, k=16, d_type="F32", ab_type="TF32") +SM90_64x64x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 64, k=16, d_type="F32", ab_type="TF32" +) +SM90_64x128x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 128, k=16, d_type="F32", ab_type="TF32" +) +SM90_64x256x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 256, k=16, d_type="F32", ab_type="TF32" +) # INT8 sparse (K=64, double the dense K=32) -SM90_64x64x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(64, k=64, d_type="S32", ab_type="S8") -SM90_64x128x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(128, k=64, d_type="S32", ab_type="S8") -SM90_64x256x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(256, k=64, d_type="S32", ab_type="S8") +SM90_64x64x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 64, k=64, d_type="S32", ab_type="S8" +) +SM90_64x128x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 128, k=64, d_type="S32", ab_type="S8" +) +SM90_64x256x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 256, k=64, d_type="S32", ab_type="S8" +) # ============================================================================= @@ -764,83 +941,103 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", # M ∈ {64, 128}, N ∈ {8, 16, 24, ..., 256} (multiples of 8) # ============================================================================= + def umma_layout(rows: int, cols: int) -> Layout: """SM100 UMMA layout: (1, (rows, cols)) : (0, (1, rows)) — col-major.""" return Layout((1, (rows, cols)), (0, (1, rows))) + # --- F16/BF16 SS (both operands from shared memory) --- SM100_64x64x16_F16F16F16_SS = MMAAtom( name="SM100_64x64x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n64k16.f16.f16.f16", - shape_mnk=(64, 64, 16), thr_id=Layout(1), + shape_mnk=(64, 64, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(64, 16), - c_layout=umma_layout(64, 64)) + c_layout=umma_layout(64, 64), +) SM100_64x128x16_F16F16F16_SS = MMAAtom( name="SM100_64x128x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n128k16.f16.f16.f16", - shape_mnk=(64, 128, 16), thr_id=Layout(1), + shape_mnk=(64, 128, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(128, 16), - c_layout=umma_layout(64, 128)) + c_layout=umma_layout(64, 128), +) SM100_64x256x16_F16F16F16_SS = MMAAtom( name="SM100_64x256x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n256k16.f16.f16.f16", - shape_mnk=(64, 256, 16), thr_id=Layout(1), + shape_mnk=(64, 256, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(256, 16), - c_layout=umma_layout(64, 256)) + c_layout=umma_layout(64, 256), +) SM100_128x64x16_F16F16F16_SS = MMAAtom( name="SM100_128x64x16_F16F16F16_SS", ptx="tcgen05.mma ... m128n64k16.f16.f16.f16", - shape_mnk=(128, 64, 16), thr_id=Layout(1), + shape_mnk=(128, 64, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(64, 16), - c_layout=umma_layout(128, 64)) + c_layout=umma_layout(128, 64), +) SM100_128x128x16_F16F16F16_SS = MMAAtom( name="SM100_128x128x16_F16F16F16_SS", ptx="tcgen05.mma ... m128n128k16.f16.f16.f16", - shape_mnk=(128, 128, 16), thr_id=Layout(1), + shape_mnk=(128, 128, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(128, 16), - c_layout=umma_layout(128, 128)) + c_layout=umma_layout(128, 128), +) SM100_128x256x16_F16F16F16_SS = MMAAtom( name="SM100_128x256x16_F16F16F16_SS", ptx="tcgen05.mma ... m128n256k16.f16.f16.f16", - shape_mnk=(128, 256, 16), thr_id=Layout(1), + shape_mnk=(128, 256, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(256, 16), - c_layout=umma_layout(128, 256)) + c_layout=umma_layout(128, 256), +) # --- TF32 SS (K=8 because 256/32=8) --- SM100_64x64x8_F32TF32TF32F32_SS = MMAAtom( name="SM100_64x64x8_F32TF32TF32F32_SS", ptx="tcgen05.mma ... m64n64k8.f32.tf32.tf32.f32", - shape_mnk=(64, 64, 8), thr_id=Layout(1), + shape_mnk=(64, 64, 8), + thr_id=Layout(1), a_layout=umma_layout(64, 8), b_layout=umma_layout(64, 8), - c_layout=umma_layout(64, 64)) + c_layout=umma_layout(64, 64), +) SM100_128x128x8_F32TF32TF32F32_SS = MMAAtom( name="SM100_128x128x8_F32TF32TF32F32_SS", ptx="tcgen05.mma ... m128n128k8.f32.tf32.tf32.f32", - shape_mnk=(128, 128, 8), thr_id=Layout(1), + shape_mnk=(128, 128, 8), + thr_id=Layout(1), a_layout=umma_layout(128, 8), b_layout=umma_layout(128, 8), - c_layout=umma_layout(128, 128)) + c_layout=umma_layout(128, 128), +) # --- SM100 UMMA factory --- -def make_umma_atom_ss(m: int, n: int, k: int = 16, - d_type: str = "F16", ab_type: str | None = None) -> MMAAtom: + +def make_umma_atom_ss( + m: int, n: int, k: int = 16, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM100 UMMA SS atom for M×N×K with the given data types.""" if ab_type is None: ab_type = d_type @@ -848,10 +1045,13 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, return MMAAtom( name=name, ptx=f"tcgen05.mma ... m{m}n{n}k{k}", - shape_mnk=(m, n, k), thr_id=Layout(1), + shape_mnk=(m, n, k), + thr_id=Layout(1), a_layout=umma_layout(m, k), b_layout=umma_layout(n, k), - c_layout=umma_layout(m, n)) + c_layout=umma_layout(m, n), + ) + # F32-accumulator with F16 inputs SM100_64x64x16_F32F16F16_SS = make_umma_atom_ss(64, 64, d_type="F32", ab_type="F16") @@ -863,19 +1063,39 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, # F32-accumulator with BF16 inputs SM100_64x64x16_F32BF16BF16_SS = make_umma_atom_ss(64, 64, d_type="F32", ab_type="BF16") -SM100_64x128x16_F32BF16BF16_SS = make_umma_atom_ss(64, 128, d_type="F32", ab_type="BF16") -SM100_64x256x16_F32BF16BF16_SS = make_umma_atom_ss(64, 256, d_type="F32", ab_type="BF16") -SM100_128x64x16_F32BF16BF16_SS = make_umma_atom_ss(128, 64, d_type="F32", ab_type="BF16") -SM100_128x128x16_F32BF16BF16_SS = make_umma_atom_ss(128, 128, d_type="F32", ab_type="BF16") -SM100_128x256x16_F32BF16BF16_SS = make_umma_atom_ss(128, 256, d_type="F32", ab_type="BF16") +SM100_64x128x16_F32BF16BF16_SS = make_umma_atom_ss( + 64, 128, d_type="F32", ab_type="BF16" +) +SM100_64x256x16_F32BF16BF16_SS = make_umma_atom_ss( + 64, 256, d_type="F32", ab_type="BF16" +) +SM100_128x64x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 64, d_type="F32", ab_type="BF16" +) +SM100_128x128x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 128, d_type="F32", ab_type="BF16" +) +SM100_128x256x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 256, d_type="F32", ab_type="BF16" +) # F16-accumulator with BF16 inputs SM100_64x64x16_F16BF16BF16_SS = make_umma_atom_ss(64, 64, d_type="F16", ab_type="BF16") -SM100_64x128x16_F16BF16BF16_SS = make_umma_atom_ss(64, 128, d_type="F16", ab_type="BF16") -SM100_64x256x16_F16BF16BF16_SS = make_umma_atom_ss(64, 256, d_type="F16", ab_type="BF16") -SM100_128x64x16_F16BF16BF16_SS = make_umma_atom_ss(128, 64, d_type="F16", ab_type="BF16") -SM100_128x128x16_F16BF16BF16_SS = make_umma_atom_ss(128, 128, d_type="F16", ab_type="BF16") -SM100_128x256x16_F16BF16BF16_SS = make_umma_atom_ss(128, 256, d_type="F16", ab_type="BF16") +SM100_64x128x16_F16BF16BF16_SS = make_umma_atom_ss( + 64, 128, d_type="F16", ab_type="BF16" +) +SM100_64x256x16_F16BF16BF16_SS = make_umma_atom_ss( + 64, 256, d_type="F16", ab_type="BF16" +) +SM100_128x64x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 64, d_type="F16", ab_type="BF16" +) +SM100_128x128x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 128, d_type="F16", ab_type="BF16" +) +SM100_128x256x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 256, d_type="F16", ab_type="BF16" +) # ============================================================================= @@ -890,19 +1110,23 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, SM120_16x8x32_F32E4M3E4M3F32_TN = MMAAtom( name="SM120_16x8x32_F32E4M3E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # SM120 block-scaled MXF8F6F4 16x8x64 SM120_16x8x64_F32E4M3E4M3F32_TN = MMAAtom( name="SM120_16x8x64_F32E4M3E4M3F32_TN", ptx="mma.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- SM120 Sparse (structured 2:4 sparsity) --- # Source: include/cute/atom/mma_traits_sm120_sparse.hpp @@ -911,53 +1135,65 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, SM120_16x8x64_F32E4M3E4M3F32_TN_SPARSE = MMAAtom( name="SM120_16x8x64_F32E4M3E4M3F32_TN_SPARSE", ptx="mma.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 (sparse)", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (4, 4)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # SM120 sparse block-scaled 16x8x128 (FP4, 2:4 sparsity) SM120_16x8x128_F32E4M3E4M3F32_TN_SPARSE = MMAAtom( name="SM120_16x8x128_F32E4M3E4M3F32_TN_SPARSE", ptx="mma.sync.aligned.m16n8k128.row.col.f32.e4m3.e4m3.f32 (sparse)", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (16, 2, 2)), ((256, 1), (16, 8, 1024))), b_layout=Layout(((4, 8), (8, 4)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- SM120 F16-accumulator variants (same layouts as F32, different register width) --- SM120_16x8x32_F16E4M3E4M3F16_TN = MMAAtom( name="SM120_16x8x32_F16E4M3E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x64_F16E4M3E4M3F16_TN = MMAAtom( name="SM120_16x8x64_F16E4M3E4M3F16_TN", ptx="mma.sync.aligned.m16n8k64.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x64_F16E4M3E4M3F16_TN_SPARSE = MMAAtom( name="SM120_16x8x64_F16E4M3E4M3F16_TN_SPARSE", ptx="mma.sync.aligned.m16n8k64.row.col.f16.e4m3.e4m3.f16 (sparse)", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (4, 4)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x128_F16E4M3E4M3F16_TN_SPARSE = MMAAtom( name="SM120_16x8x128_F16E4M3E4M3F16_TN_SPARSE", ptx="mma.sync.aligned.m16n8k128.row.col.f16.e4m3.e4m3.f16 (sparse)", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (16, 2, 2)), ((256, 1), (16, 8, 1024))), b_layout=Layout(((4, 8), (8, 4)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # ============================================================================= @@ -971,14 +1207,16 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="shfl.sync.bfly (XOR1 2x2 transpose)", thr_id=Layout(32), src_layout_bits=Layout((32, 64), (64, 1)), - dst_layout_bits=Layout(((2, 16), (32, 2)), ((32, 128), (1, 64)))) + dst_layout_bits=Layout(((2, 16), (32, 2)), ((32, 128), (1, 64))), +) SM50_Shuffle_U32_2x2Trans_XOR4 = CopyAtom( name="SM50_Shuffle_U32_2x2Trans_XOR4", ptx="shfl.sync.bfly (XOR4 2x2 transpose)", thr_id=Layout(32), src_layout_bits=Layout((32, 64), (64, 1)), - dst_layout_bits=Layout(((4, 2, 4), (32, 2)), ((64, 32, 512), (1, 256)))) + dst_layout_bits=Layout(((4, 2, 4), (32, 2)), ((64, 32, 512), (1, 256))), +) # ============================================================================= @@ -993,42 +1231,48 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="ldmatrix.sync.aligned.x1.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((8, 4), 128), ((128, 0), 1)), - dst_layout_bits=Layout((32, 32), (32, 1))) + dst_layout_bits=Layout((32, 32), (32, 1)), +) SM75_U32x2_LDSM_N = CopyAtom( name="SM75_U32x2_LDSM_N", ptx="ldmatrix.sync.aligned.x2.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((16, 2), 128), ((128, 0), 1)), - dst_layout_bits=Layout((32, (32, 2)), (32, (1, 1024)))) + dst_layout_bits=Layout((32, (32, 2)), (32, (1, 1024))), +) SM75_U32x4_LDSM_N = CopyAtom( name="SM75_U32x4_LDSM_N", ptx="ldmatrix.sync.aligned.x4.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout((32, 128), (128, 1)), - dst_layout_bits=Layout((32, (32, 4)), (32, (1, 1024)))) + dst_layout_bits=Layout((32, (32, 4)), (32, (1, 1024))), +) SM75_U16x2_LDSM_T = CopyAtom( name="SM75_U16x2_LDSM_T", ptx="ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((8, 4), 128), ((128, 0), 1)), - dst_layout_bits=Layout(((4, 8), (16, 2)), ((256, 16), (1, 128)))) + dst_layout_bits=Layout(((4, 8), (16, 2)), ((256, 16), (1, 128))), +) SM75_U16x4_LDSM_T = CopyAtom( name="SM75_U16x4_LDSM_T", ptx="ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((16, 2), 128), ((128, 0), 1)), - dst_layout_bits=Layout(((4, 8), (16, 2, 2)), ((256, 16), (1, 128, 1024)))) + dst_layout_bits=Layout(((4, 8), (16, 2, 2)), ((256, 16), (1, 128, 1024))), +) SM75_U16x8_LDSM_T = CopyAtom( name="SM75_U16x8_LDSM_T", ptx="ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout((32, 128), (128, 1)), - dst_layout_bits=Layout(((4, 8), (16, 2, 4)), ((256, 16), (1, 128, 1024)))) + dst_layout_bits=Layout(((4, 8), (16, 2, 4)), ((256, 16), (1, 128, 1024))), +) # ============================================================================= @@ -1047,14 +1291,16 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="cp.async.ca.shared.global [16B]", thr_id=Layout(1), src_layout_bits=Layout((1, 128)), - dst_layout_bits=Layout((1, 128))) + dst_layout_bits=Layout((1, 128)), +) SM80_CP_ASYNC_CACHEGLOBAL_16B = CopyAtom( name="SM80_CP_ASYNC_CACHEGLOBAL_16B", ptx="cp.async.cg.shared.global [16B]", thr_id=Layout(1), src_layout_bits=Layout((1, 128)), - dst_layout_bits=Layout((1, 128))) + dst_layout_bits=Layout((1, 128)), +) # ============================================================================= @@ -1069,42 +1315,48 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="stmatrix.sync.aligned.x1.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x1_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x1_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x1_LDSM_N.src_layout_bits, +) SM90_U32x2_STSM_N = CopyAtom( name="SM90_U32x2_STSM_N", ptx="stmatrix.sync.aligned.x2.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x2_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x2_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x2_LDSM_N.src_layout_bits, +) SM90_U32x4_STSM_N = CopyAtom( name="SM90_U32x4_STSM_N", ptx="stmatrix.sync.aligned.x4.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x4_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x4_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x4_LDSM_N.src_layout_bits, +) SM90_U16x2_STSM_T = CopyAtom( name="SM90_U16x2_STSM_T", ptx="stmatrix.sync.aligned.x1.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x2_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x2_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x2_LDSM_T.src_layout_bits, +) SM90_U16x4_STSM_T = CopyAtom( name="SM90_U16x4_STSM_T", ptx="stmatrix.sync.aligned.x2.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x4_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x4_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x4_LDSM_T.src_layout_bits, +) SM90_U16x8_STSM_T = CopyAtom( name="SM90_U16x8_STSM_T", ptx="stmatrix.sync.aligned.x4.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x8_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x8_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x8_LDSM_T.src_layout_bits, +) # ============================================================================= @@ -1117,10 +1369,14 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] MMA_ATOMS_SM70 = [ - SM70_8x8x4_F16F16F16F16_TN, SM70_8x8x4_F16F16F16F16_NT, - SM70_8x8x4_F16F16F16F16_NN, SM70_8x8x4_F16F16F16F16_TT, - SM70_8x8x4_F32F16F16F32_TN, SM70_8x8x4_F32F16F16F32_NT, - SM70_8x8x4_F32F16F16F32_NN, SM70_8x8x4_F32F16F16F32_TT, + SM70_8x8x4_F16F16F16F16_TN, + SM70_8x8x4_F16F16F16F16_NT, + SM70_8x8x4_F16F16F16F16_NN, + SM70_8x8x4_F16F16F16F16_TT, + SM70_8x8x4_F32F16F16F32_TN, + SM70_8x8x4_F32F16F16F32_NT, + SM70_8x8x4_F32F16F16F32_NN, + SM70_8x8x4_F32F16F16F32_TT, ] MMA_ATOMS_SM75 = [ @@ -1129,14 +1385,20 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] MMA_ATOMS_SM80 = [ - SM80_16x8x8_F16F16F16F16_TN, SM80_16x8x16_F16F16F16F16_TN, - SM80_16x8x8_F32F16F16F32_TN, SM80_16x8x16_F32F16F16F32_TN, - SM80_16x8x8_F32BF16BF16F32_TN, SM80_16x8x16_F32BF16BF16F32_TN, - SM80_16x8x4_F32TF32TF32F32_TN, SM80_16x8x8_F32TF32TF32F32_TN, + SM80_16x8x8_F16F16F16F16_TN, + SM80_16x8x16_F16F16F16F16_TN, + SM80_16x8x8_F32F16F16F32_TN, + SM80_16x8x16_F32F16F16F32_TN, + SM80_16x8x8_F32BF16BF16F32_TN, + SM80_16x8x16_F32BF16BF16F32_TN, + SM80_16x8x4_F32TF32TF32F32_TN, + SM80_16x8x8_F32TF32TF32F32_TN, SM80_8x8x4_F64F64F64F64_TN, - SM80_8x8x16_S32S8S8S32_TN, SM80_16x8x16_S32S8S8S32_TN, + SM80_8x8x16_S32S8S8S32_TN, + SM80_16x8x16_S32S8S8S32_TN, SM80_16x8x32_S32S8S8S32_TN, - SM80_8x8x32_S32S4S4S32_TN, SM80_16x8x32_S32S4S4S32_TN, + SM80_8x8x32_S32S4S4S32_TN, + SM80_16x8x32_S32S4S4S32_TN, SM80_16x8x64_S32S4S4S32_TN, SM80_8x8x128_S32U1U1S32_TN_XORPOPC, SM80_16x8x128_S32U1U1S32_TN_XORPOPC, @@ -1264,8 +1526,12 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] COPY_ATOMS_SM75 = [ - SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, SM75_U32x4_LDSM_N, - SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, SM75_U16x8_LDSM_T, + SM75_U32x1_LDSM_N, + SM75_U32x2_LDSM_N, + SM75_U32x4_LDSM_N, + SM75_U16x2_LDSM_T, + SM75_U16x4_LDSM_T, + SM75_U16x8_LDSM_T, ] COPY_ATOMS_SM80 = [ @@ -1274,6 +1540,10 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] COPY_ATOMS_SM90 = [ - SM90_U32x1_STSM_N, SM90_U32x2_STSM_N, SM90_U32x4_STSM_N, - SM90_U16x2_STSM_T, SM90_U16x4_STSM_T, SM90_U16x8_STSM_T, + SM90_U32x1_STSM_N, + SM90_U32x2_STSM_N, + SM90_U32x4_STSM_N, + SM90_U16x2_STSM_T, + SM90_U16x4_STSM_T, + SM90_U16x8_STSM_T, ] diff --git a/src/tensor_layouts/atoms_xe.py b/src/tensor_layouts/atoms_xe.py index 04c7b28..9abcfb1 100644 --- a/src/tensor_layouts/atoms_xe.py +++ b/src/tensor_layouts/atoms_xe.py @@ -96,6 +96,7 @@ # Helper: construct CuTe layouts from DPAS structural parameters # ============================================================================= + def _dpas_c_layout(m: int, n: int) -> Layout: """Build the (T_n, V_m) -> col-major(M, N) accumulator layout. @@ -154,7 +155,7 @@ def make_dpas_atom( name=name, ptx=inst, shape_mnk=(m, n, k), - thr_id=None, # identity: lane_id = thread_idx % subgroup_size + thr_id=None, # identity: lane_id = thread_idx % subgroup_size a_layout=a_layout, b_layout=b_layout, c_layout=c_layout, @@ -170,28 +171,36 @@ def make_dpas_atom( XeHPC_8x8x8_F32F16F16_DPAS = make_dpas_atom( name="XeHPC_8x8x8_F32F16F16_DPAS", inst="dpas.8x8 (exec_size=8, FP16)", - m=8, n=8, k=8, + m=8, + n=8, + k=8, ) # --- BF16 input, FP32 accumulator --- XeHPC_8x8x8_F32BF16BF16_DPAS = make_dpas_atom( name="XeHPC_8x8x8_F32BF16BF16_DPAS", inst="dpas.8x8 (exec_size=8, BF16)", - m=8, n=8, k=8, + m=8, + n=8, + k=8, ) # --- TF32 input, FP32 accumulator --- XeHPC_8x8x8_F32TF32TF32_DPAS = make_dpas_atom( name="XeHPC_8x8x8_F32TF32TF32_DPAS", inst="dpas.8x8 (exec_size=8, TF32)", - m=8, n=8, k=8, + m=8, + n=8, + k=8, ) # --- INT8 input, INT32 accumulator --- XeHPC_8x8x8_I32I8I8_DPAS = make_dpas_atom( name="XeHPC_8x8x8_I32I8I8_DPAS", inst="dpas.8x8 (exec_size=8, INT8)", - m=8, n=8, k=8, + m=8, + n=8, + k=8, ) @@ -204,21 +213,27 @@ def make_dpas_atom( XeHPG_8x16x8_F32F16F16_DPAS = make_dpas_atom( name="XeHPG_8x16x8_F32F16F16_DPAS", inst="dpas.8x8 (exec_size=16, FP16)", - m=8, n=16, k=8, + m=8, + n=16, + k=8, ) # --- BF16 input, FP32 accumulator --- XeHPG_8x16x8_F32BF16BF16_DPAS = make_dpas_atom( name="XeHPG_8x16x8_F32BF16BF16_DPAS", inst="dpas.8x8 (exec_size=16, BF16)", - m=8, n=16, k=8, + m=8, + n=16, + k=8, ) # --- INT8 input, INT32 accumulator --- XeHPG_8x16x8_I32I8I8_DPAS = make_dpas_atom( name="XeHPG_8x16x8_I32I8I8_DPAS", inst="dpas.8x8 (exec_size=16, INT8)", - m=8, n=16, k=8, + m=8, + n=16, + k=8, ) diff --git a/src/tensor_layouts/layout_utils.py b/src/tensor_layouts/layout_utils.py index 609bec9..f148ddd 100644 --- a/src/tensor_layouts/layout_utils.py +++ b/src/tensor_layouts/layout_utils.py @@ -127,9 +127,7 @@ def tile_to_shape(layout: Layout, target_shape, order: tuple = None) -> Layout: f"layout.shape; got {len(target_shape)} target modes for block shape {block_shape}" ) - product_shape = tuple( - (t + b - 1) // b for t, b in zip(target_shape, block_shape) - ) + product_shape = tuple((t + b - 1) // b for t, b in zip(target_shape, block_shape)) replication = make_ordered_layout(product_shape, order) @@ -190,7 +188,7 @@ def _exact_tile_factor(requested: int, natural: int, *, axis: str, tile_mnk) -> return requested // natural -def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): +def tile_mma_grid(atom, atom_layout, matrix="C", tile_mnk=None): """Compute the tiled MMA grid by replicating an atom across quadpairs. Mirrors the C++ make_tiled_mma(atom, atom_layout, Tile) function. @@ -233,17 +231,17 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): qp_offset = n_thr_per_atom // 2 if thr_id is not None else n_thr_per_atom # Select atom layout and tile dimensions based on matrix - if matrix == 'C': + if matrix == "C": atom_lyt = atom.c_layout row_atoms = n_atoms_m col_atoms = n_atoms_n atom_rows, atom_cols = M_atom, N_atom - elif matrix == 'A': + elif matrix == "A": atom_lyt = atom.a_layout row_atoms = n_atoms_m col_atoms = 1 atom_rows, atom_cols = M_atom, K_atom - elif matrix == 'B': + elif matrix == "B": atom_lyt = atom.b_layout row_atoms = n_atoms_n col_atoms = 1 @@ -260,9 +258,9 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): for am in range(row_atoms): for an in range(col_atoms): # Determine atom index from the atom_layout - if matrix == 'C': + if matrix == "C": atom_idx = atom_layout((am, an)) if not is_int(atom_shape) else am - elif matrix == 'A': + elif matrix == "A": # A tiles along M only; use first N-column atom atom_idx = atom_layout((am, 0)) if not is_int(atom_shape) else am else: # B @@ -306,9 +304,9 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): rep_N = _exact_tile_factor(tile_N, nat_N, axis="N", tile_mnk=tile_mnk) rep_K = _exact_tile_factor(tile_K, K_atom, axis="K", tile_mnk=tile_mnk) - if matrix == 'C': + if matrix == "C": rep_rows, rep_cols = rep_M, rep_N - elif matrix == 'A': + elif matrix == "A": rep_rows, rep_cols = rep_M, rep_K else: # B rep_rows, rep_cols = rep_N, rep_K diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index f0de2df..f5dda2e 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -457,14 +457,18 @@ def _validate_shape_type(x, name: str) -> None: for elem in x: _validate_shape_type(elem, name) return - raise TypeError(f"Layout {name} must be int or tuple of ints, got {type(x).__name__}") + raise TypeError( + f"Layout {name} must be int or tuple of ints, got {type(x).__name__}" + ) def _validate_nonnegative_shape(shape: Any) -> None: """Validate that every shape extent is nonnegative.""" if is_int(shape): if shape < 0: - raise ValueError(f"Layout shape must contain only nonnegative extents, got {shape}") + raise ValueError( + f"Layout shape must contain only nonnegative extents, got {shape}" + ) return for elem in shape: _validate_nonnegative_shape(elem) @@ -554,7 +558,9 @@ def __init__(self, *args, swizzle: "Swizzle | None" = None): ) if not congruent(self._shape, self._stride): - raise ValueError(f"Shape {self._shape} and Stride {self._stride} are not congruent") + raise ValueError( + f"Shape {self._shape} and Stride {self._stride} are not congruent" + ) def __eq__(self, other): if self is other: @@ -576,7 +582,9 @@ def __hash__(self): def __repr__(self): """Return an eval-safe constructor string: Layout((4, 2), (1, 4)).""" if self._swizzle is not None: - return f"Layout({self._shape!r}, {self._stride!r}, swizzle={self._swizzle!r})" + return ( + f"Layout({self._shape!r}, {self._stride!r}, swizzle={self._swizzle!r})" + ) return f"Layout({self._shape!r}, {self._stride!r})" def __str__(self): @@ -602,7 +610,9 @@ def swizzle(self) -> "Swizzle | None": @staticmethod def _calculate_max_offset(shape: Any, stride: Any) -> int: if is_tuple(shape): - return sum(Layout._calculate_max_offset(s, d) for s, d in zip(shape, stride)) + return sum( + Layout._calculate_max_offset(s, d) for s, d in zip(shape, stride) + ) return (shape - 1) * abs(stride) def __call__(self, *args): @@ -792,11 +802,15 @@ def _forward_layout_domain(layout, transform): transformed inner result is affine. ComposedLayout always stays composed. """ if isinstance(layout, ComposedLayout): - return ComposedLayout(layout.outer, transform(layout.inner), preoffset=layout.preoffset) + return ComposedLayout( + layout.outer, transform(layout.inner), preoffset=layout.preoffset + ) if isinstance(layout, Layout) and layout.swizzle is not None: inner_result = transform(_affine_inner(layout)) if isinstance(inner_result, Layout) and inner_result.swizzle is None: - return Layout(inner_result.shape, inner_result.stride, swizzle=layout.swizzle) + return Layout( + inner_result.shape, inner_result.stride, swizzle=layout.swizzle + ) return ComposedLayout(layout.swizzle, inner_result) return _NO_FORWARD @@ -946,9 +960,12 @@ def concat(t1: Any, t2: Any): return t1 + t2 if isinstance(t1, Layout) and isinstance(t2, Layout): return Layout( - as_tuple(t1.shape) + as_tuple(t2.shape), as_tuple(t1.stride) + as_tuple(t2.stride) + as_tuple(t1.shape) + as_tuple(t2.shape), + as_tuple(t1.stride) + as_tuple(t2.stride), ) - raise TypeError(f"Cannot concatenate objects of {type(t1).__name__} and {type(t2).__name__}") + raise TypeError( + f"Cannot concatenate objects of {type(t1).__name__} and {type(t2).__name__}" + ) def congruent(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: @@ -989,7 +1006,9 @@ def weakly_congruent(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: if isinstance(a, int): return True if is_tuple(a) and is_tuple(b): - return len(a) == len(b) and all(weakly_congruent(sa, sb) for sa, sb in zip(a, b)) + return len(a) == len(b) and all( + weakly_congruent(sa, sb) for sa, sb in zip(a, b) + ) return False @@ -1031,7 +1050,10 @@ def _can_group_a_into_b(a_modes: list, b) -> bool: return acc_size == target_size if is_tuple(b): - return all(_can_group_a_into_b(a_modes, sub_b) for sub_b in b) and len(a_modes) == 0 + return ( + all(_can_group_a_into_b(a_modes, sub_b) for sub_b in b) + and len(a_modes) == 0 + ) return False @@ -1102,7 +1124,9 @@ def replace(layout: LayoutExpr, idx: int, new_layout: Layout) -> LayoutExpr: replace((3,4,(3,4)):(1,3,(1,3)), 2, 4:3) -> (3,4,4):(1,3,3) """ - forwarded = _forward_layout_domain(layout, lambda inner: replace(inner, idx, new_layout)) + forwarded = _forward_layout_domain( + layout, lambda inner: replace(inner, idx, new_layout) + ) if forwarded is not _NO_FORWARD: return forwarded shapes = as_list(layout.shape) @@ -1440,7 +1464,9 @@ def inner_product(a: Any, b: Any) -> int: return sum(inner_product(x, y) for x, y in zip(a, b)) else: if not isinstance(a, int) or not isinstance(b, int): - raise TypeError(f"Expected int, got {type(a).__name__} and {type(b).__name__}") + raise TypeError( + f"Expected int, got {type(a).__name__} and {type(b).__name__}" + ) return a * b @@ -1592,7 +1618,9 @@ def _coalesce_by_mode(layout: Layout, profile: tuple) -> Layout: result_s.append(1) result_d.append(0) else: - coalesced = _coalesce_flat(Layout(mode(layout.shape, i), mode(layout.stride, i))) + coalesced = _coalesce_flat( + Layout(mode(layout.shape, i), mode(layout.stride, i)) + ) result_s.append(coalesced.shape) result_d.append(coalesced.stride) return Layout(as_shape(result_s), as_shape(result_d)) @@ -1714,7 +1742,9 @@ def _step_mode(current_stride, stride, shape): flat_shapes = as_list(flat.shape) flat_strides = as_list(flat.stride) - modes = sorted(((d, s) for s, d in zip(flat_shapes, flat_strides) if s != 1 and d != 0)) + modes = sorted( + ((d, s) for s, d in zip(flat_shapes, flat_strides) if s != 1 and d != 0) + ) # Fold _step_mode over sorted modes, collecting gap-fills result_shapes = [] @@ -1738,7 +1768,9 @@ def _step_mode(current_stride, stride, shape): # instead of collapsing eagerly to size(...), matching CuTe C++. if is_tuple(cosize_bound): remaining = _coalesce_shape(_shape_ceil_div(cosize_bound, current_stride)) - remaining_stride = elem_scale(current_stride, compute_col_major_strides(remaining)) + remaining_stride = elem_scale( + current_stride, compute_col_major_strides(remaining) + ) else: remaining = _ceil_div(cosize_bound, current_stride) remaining_stride = current_stride @@ -2116,7 +2148,12 @@ def _slice_for_composition(crd, layout: LayoutExpr): """ if isinstance(layout, ComposedLayout): inner_slice, delta = _slice_for_composition(crd, layout.inner) - return (ComposedLayout(layout.outer, inner_slice, preoffset=layout.preoffset + delta), 0) + return ( + ComposedLayout( + layout.outer, inner_slice, preoffset=layout.preoffset + delta + ), + 0, + ) sliced_shape = slice_modes(crd, layout.shape) sliced_stride = slice_modes(crd, layout.stride) @@ -2183,7 +2220,9 @@ def idx2crd(coord: Any, shape: Any) -> Any: # We map the modes of the coordinate to the modes of the shape if is_tuple(coord): if len(coord) != len(shape): - raise ValueError(f"Coordinate rank {len(coord)} mismatch with Shape rank {len(shape)}") + raise ValueError( + f"Coordinate rank {len(coord)} mismatch with Shape rank {len(shape)}" + ) return zip_transform(coord, shape, idx2crd) @@ -2269,7 +2308,9 @@ def crd2offset(coord, shape, stride) -> int: if not is_tuple(coord): raise TypeError(f"Coordinate must be int or tuple, got {type(coord).__name__}") if len(coord) != len(shape): - raise ValueError(f"Coordinate rank {len(coord)} does not match layout rank {len(shape)}") + raise ValueError( + f"Coordinate rank {len(coord)} does not match layout rank {len(shape)}" + ) offset = 0 for c, s, d in zip(coord, shape, stride): if c is None: @@ -2321,12 +2362,16 @@ def crd2crd(crd: Any, dst_shape: Any, src_shape: Any = None) -> Any: f"Rank mismatch: crd has {len(crd)} elements, dst_shape has {len(dst_shape)}" ) if src_shape is not None and is_tuple(src_shape): - return tuple(crd2crd(c, d, s) for c, d, s in zip(crd, dst_shape, src_shape)) + return tuple( + crd2crd(c, d, s) for c, d, s in zip(crd, dst_shape, src_shape) + ) return zip_transform(crd, dst_shape, crd2crd) else: # crd is tuple, dst_shape is scalar: flatten using src_shape if src_shape is None: - raise ValueError("src_shape required to flatten tuple coordinate to scalar") + raise ValueError( + "src_shape required to flatten tuple coordinate to scalar" + ) return crd2flat(crd, src_shape) else: if is_tuple(dst_shape): @@ -2359,7 +2404,9 @@ def slice_modes(crd, trg): if is_tuple(crd): if is_tuple(trg): if len(crd) != len(trg): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}") + raise ValueError( + f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}" + ) # Process each top-level mode independently, preserving hierarchy result = [] for c, s in zip(crd, trg): @@ -2494,7 +2541,9 @@ def __new__(cls, *layouts): """ for i, layout in enumerate(layouts): if not isinstance(layout, Layout): - raise TypeError(f"Tile element {i} must be a Layout, got {type(layout).__name__}") + raise TypeError( + f"Tile element {i} must be a Layout, got {type(layout).__name__}" + ) return super().__new__(cls, layouts) def __repr__(self): @@ -2654,7 +2703,9 @@ def _upcast_leaf(s, d): def _apply(shape, stride): if is_tuple(shape): if not is_tuple(stride) or len(shape) != len(stride): - raise ValueError(f"Shape/stride structure mismatch: {shape} vs {stride}") + raise ValueError( + f"Shape/stride structure mismatch: {shape} vs {stride}" + ) pairs = [_apply(s, d) for s, d in zip(shape, stride)] new_s = tuple(p[0] for p in pairs) new_d = tuple(p[1] for p in pairs) @@ -2688,7 +2739,9 @@ def _downcast_leaf(s, d): def _apply(shape, stride): if is_tuple(shape): if not is_tuple(stride) or len(shape) != len(stride): - raise ValueError(f"Shape/stride structure mismatch: {shape} vs {stride}") + raise ValueError( + f"Shape/stride structure mismatch: {shape} vs {stride}" + ) pairs = [_apply(s, d) for s, d in zip(shape, stride)] new_s = tuple(p[0] for p in pairs) new_d = tuple(p[1] for p in pairs) @@ -2745,7 +2798,9 @@ def _composition_1d(layout_a: "Layout", b_shape: int, b_stride: int) -> "Layout" negative_stride = remaining_stride < 0 divisible = curr_shape % abs_stride == 0 or abs_stride % curr_shape == 0 - fits_in_mode = remaining_shape > 1 and (remaining_shape - 1) * abs_stride < curr_shape + fits_in_mode = ( + remaining_shape > 1 and (remaining_shape - 1) * abs_stride < curr_shape + ) if not divisible and not fits_in_mode: raise ValueError( f"compose: shape {curr_shape} and stride {remaining_stride} are not divisible" @@ -2795,8 +2850,12 @@ def _compose_layouts(layout_a: Layout, layout_b: Layout) -> Layout: def compose_element(b_shape, b_stride): """Recursively compose A with one element of B's shape/stride.""" if is_tuple(b_shape): - results = [compose_element(b_shape[i], b_stride[i]) for i in range(len(b_shape))] - return Layout(tuple(r.shape for r in results), tuple(r.stride for r in results)) + results = [ + compose_element(b_shape[i], b_stride[i]) for i in range(len(b_shape)) + ] + return Layout( + tuple(r.shape for r in results), tuple(r.stride for r in results) + ) return _composition_1d(layout_a, b_shape, b_stride) if is_tuple(layout_b.shape): @@ -2858,7 +2917,9 @@ def _normalize_compose_tiler_element(elem): raise TypeError(f"Invalid tiler element: {type(elem)}") -def _compose_into_composed_lhs(layout_a: "ComposedLayout", layout_b: Any) -> "ComposedLayout": +def _compose_into_composed_lhs( + layout_a: "ComposedLayout", layout_b: Any +) -> "ComposedLayout": """compose(ComposedLayout(outer, inner, preoffset), B). The outer/preoffset are external to the data domain, so composition @@ -2900,7 +2961,9 @@ def _compose_swizzled_layout_lhs(layout_a: "Layout", layout_b: Any) -> Any: """ inner_composed = compose(_affine_inner(layout_a), layout_b) if isinstance(inner_composed, Layout) and inner_composed.swizzle is None: - return Layout(inner_composed.shape, inner_composed.stride, swizzle=layout_a.swizzle) + return Layout( + inner_composed.shape, inner_composed.stride, swizzle=layout_a.swizzle + ) return ComposedLayout(layout_a.swizzle, inner_composed) @@ -3128,7 +3191,9 @@ def logical_divide(layout: LayoutExpr, tiler: Any) -> LayoutExpr: logical_divide(Layout(16), 4) -> Layout((4, 4), (1, 4)) logical_divide(Layout((4,2,3), (2,1,8)), Layout(4, 2)) -> ((2,2),(2,3)):((4,1),(2,8)) """ - forwarded = _forward_layout_domain(layout, lambda inner: logical_divide(inner, tiler)) + forwarded = _forward_layout_domain( + layout, lambda inner: logical_divide(inner, tiler) + ) if forwarded is not _NO_FORWARD: return forwarded if isinstance(tiler, Layout): @@ -3260,8 +3325,16 @@ def _logical_divide_by_shape(layout: Layout, tiler_shape: Any) -> Layout: result_strides.append(divided.stride) else: tile_part = compose(Layout(s, d), Layout(tile_size, 1)) - tile_s = unwrap(tile_part.shape) if is_tuple(tile_part.shape) else tile_part.shape - tile_d = unwrap(tile_part.stride) if is_tuple(tile_part.stride) else tile_part.stride + tile_s = ( + unwrap(tile_part.shape) + if is_tuple(tile_part.shape) + else tile_part.shape + ) + tile_d = ( + unwrap(tile_part.stride) + if is_tuple(tile_part.stride) + else tile_part.stride + ) result_shapes.append((tile_s, 1)) result_strides.append((tile_d, 0)) @@ -3285,7 +3358,9 @@ def _split_divided_modes(layout: Layout, tiler: Any): elif is_tuple(tiler): tiler_shape = tiler else: - raise TypeError(f"_split_divided_modes expects an int or tuple tiler, got {type(tiler)}") + raise TypeError( + f"_split_divided_modes expects an int or tuple tiler, got {type(tiler)}" + ) divided = logical_divide(layout, tiler_shape) @@ -3355,7 +3430,9 @@ def zipped_divide(layout: LayoutExpr, tiler: Any) -> LayoutExpr: Examples: zipped_divide(Layout((4,8)), (2,4)) -> Layout(((2,4),(2,2)), ((1,4),(2,16))) """ - forwarded = _forward_layout_domain(layout, lambda inner: zipped_divide(inner, tiler)) + forwarded = _forward_layout_domain( + layout, lambda inner: zipped_divide(inner, tiler) + ) if forwarded is not _NO_FORWARD: return forwarded # True Layout tilers are terminals in CuTe's tile_unzip(). Preserve their @@ -3363,7 +3440,9 @@ def zipped_divide(layout: LayoutExpr, tiler: Any) -> LayoutExpr: if isinstance(tiler, Layout): return logical_divide(layout, tiler) - tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes(layout, tiler) + tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes( + layout, tiler + ) tiles_shape = as_shape(tile_shapes) tiles_stride = as_shape(tile_strides) @@ -3462,7 +3541,9 @@ def zipped_product(layout_a: LayoutExpr, layout_b) -> LayoutExpr: Returns: A rank-2 Layout with ((A-modes), (product-modes)) structure """ - forwarded = _forward_layout_domain(layout_a, lambda inner: zipped_product(inner, layout_b)) + forwarded = _forward_layout_domain( + layout_a, lambda inner: zipped_product(inner, layout_b) + ) if forwarded is not _NO_FORWARD: return forwarded return hier_unzip(logical_product, layout_a, layout_b) @@ -3480,7 +3561,9 @@ def tiled_product(layout_a: LayoutExpr, layout_b) -> LayoutExpr: Returns: A Layout with ((A-modes), rest0, rest1, ...) structure """ - forwarded = _forward_layout_domain(layout_a, lambda inner: tiled_product(inner, layout_b)) + forwarded = _forward_layout_domain( + layout_a, lambda inner: tiled_product(inner, layout_b) + ) if forwarded is not _NO_FORWARD: return forwarded result = zipped_product(layout_a, layout_b) @@ -3525,10 +3608,13 @@ def hier_unzip(splitter, layout_a: Layout, layout_b) -> Layout: if is_tuple(layout_b) and not isinstance(layout_b, Layout): if rank(layout_a) < len(layout_b): - raise ValueError(f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})") + raise ValueError( + f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})" + ) splits = [ - hier_unzip(splitter, mode(layout_a, i), layout_b[i]) for i in range(len(layout_b)) + hier_unzip(splitter, mode(layout_a, i), layout_b[i]) + for i in range(len(layout_b)) ] first_shapes = [mode(s, 0).shape for s in splits] @@ -3572,7 +3658,9 @@ def logical_product(layout_a: LayoutExpr, layout_b: Layout) -> LayoutExpr: Examples: logical_product(Layout(4,1), Layout(3,1)) -> Layout((4,3), (1,4)) """ - forwarded = _forward_layout_domain(layout_a, lambda inner: logical_product(inner, layout_b)) + forwarded = _forward_layout_domain( + layout_a, lambda inner: logical_product(inner, layout_b) + ) if forwarded is not _NO_FORWARD: return forwarded if layout_b is None: @@ -3583,7 +3671,9 @@ def logical_product(layout_a: LayoutExpr, layout_b: Layout) -> LayoutExpr: # For tuple tilers, apply mode-by-mode if is_tuple(layout_b) and not isinstance(layout_b, Layout): if rank(layout_a) < len(layout_b): - raise ValueError(f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})") + raise ValueError( + f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})" + ) result_modes = [] for i in range(len(layout_b)): result_modes.append(logical_product(mode(layout_a, i), layout_b[i])) @@ -3673,7 +3763,9 @@ def blocked_product(layout_a: LayoutExpr, layout_b: Layout) -> LayoutExpr: Examples: blocked_product((2,2):(1,2), (2,2):(1,2)) -> ((2,2),(2,2)):((1,4),(2,8)) """ - forwarded = _forward_layout_domain(layout_a, lambda inner: blocked_product(inner, layout_b)) + forwarded = _forward_layout_domain( + layout_a, lambda inner: blocked_product(inner, layout_b) + ) if forwarded is not _NO_FORWARD: return forwarded a_cosize_val = cosize(layout_a) @@ -3753,7 +3845,9 @@ def _zip_layouts(layout_a: Layout, layout_b: Layout) -> Layout: # Handle scalar layouts by treating them as rank-1 if a_rank == 0 and b_rank == 0: # Both scalar: create a single mode with paired shapes/strides - return Layout((layout_a.shape, layout_b.shape), (layout_a.stride, layout_b.stride)) + return Layout( + (layout_a.shape, layout_b.shape), (layout_a.stride, layout_b.stride) + ) if a_rank != b_rank: raise ValueError(f"Rank mismatch in zip: {a_rank} vs {b_rank}") @@ -3928,7 +4022,11 @@ def __eq__(self, other: object) -> bool: return True if not isinstance(other, Swizzle): return False - return self.bits == other.bits and self.base == other.base and self.shift == other.shift + return ( + self.bits == other.bits + and self.base == other.base + and self.shift == other.shift + ) def __hash__(self) -> int: return hash((self.bits, self.base, self.shift)) diff --git a/src/tensor_layouts/tensor.py b/src/tensor_layouts/tensor.py index 6409ef4..5bef117 100644 --- a/src/tensor_layouts/tensor.py +++ b/src/tensor_layouts/tensor.py @@ -337,7 +337,9 @@ def __setitem__(self, key, value): ) if isinstance(key, tuple): if len(key) != rank(self._layout): - raise IndexError(f"Expected {rank(self._layout)} indices, got {len(key)}") + raise IndexError( + f"Expected {rank(self._layout)} indices, got {len(key)}" + ) offset = _tensor_address(self._offset, self._layout, key) elif isinstance(key, int): offset = _tensor_address(self._offset, self._layout, key) diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index b1defdb..0a30031 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -72,6 +72,7 @@ from matplotlib.textpath import TextToPath from .layouts import * + __all__ = [ # draw_* (save to file or display inline) "draw_layout", @@ -177,7 +178,7 @@ def _make_rainbow_palette(n: int, interleave: bool = False) -> list: for i in range(n): hue = i / n r, g, b = colorsys.hsv_to_rgb(hue, sat, val) - monotonic.append(f"#{int(r*255):02X}{int(g*255):02X}{int(b*255):02X}") + monotonic.append(f"#{int(r * 255):02X}{int(g * 255):02X}{int(b * 255):02X}") order = _max_contrast_order(n) return [monotonic[k] for k in order] @@ -257,7 +258,9 @@ def is_hierarchical(self) -> bool: def _normalize_display_layout(layout): if isinstance(layout, Layout): - return Layout(unwrap(layout.shape), unwrap(layout.stride), swizzle=layout.swizzle) + return Layout( + unwrap(layout.shape), unwrap(layout.stride), swizzle=layout.swizzle + ) return layout @@ -280,8 +283,12 @@ def _layout_expr_with_offset(layout, offset: int): return layout if isinstance(layout, Layout): if layout.swizzle is None: - return ComposedLayout(_IDENTITY_LAYOUT, Layout(layout.shape, layout.stride), preoffset=offset) - return ComposedLayout(layout.swizzle, Layout(layout.shape, layout.stride), preoffset=offset) + return ComposedLayout( + _IDENTITY_LAYOUT, Layout(layout.shape, layout.stride), preoffset=offset + ) + return ComposedLayout( + layout.swizzle, Layout(layout.shape, layout.stride), preoffset=offset + ) return ComposedLayout(_IDENTITY_LAYOUT, layout, preoffset=offset) @@ -571,11 +578,17 @@ def _draw_cells( color_idx = idx % len(colors) base_facecolor = colors[color_idx] final_facecolors[i, j] = highlight_facecolor if is_hl else base_facecolor - ax.add_patch(patches.Rectangle( - (j, i), cell_size, cell_size, - facecolor=base_facecolor, edgecolor="black", - linewidth=1, zorder=1, - )) + ax.add_patch( + patches.Rectangle( + (j, i), + cell_size, + cell_size, + facecolor=base_facecolor, + edgecolor="black", + linewidth=1, + zorder=1, + ) + ) if is_hl: highlighted_cells.append((i, j)) return final_facecolors, highlighted_cells @@ -596,11 +609,17 @@ def _draw_cell_highlights( thicker border. """ for i, j in highlighted_cells: - ax.add_patch(patches.Rectangle( - (j, i), cell_size, cell_size, - facecolor=highlight_facecolor, edgecolor=highlight_edgecolor, - linewidth=2, zorder=6, - )) + ax.add_patch( + patches.Rectangle( + (j, i), + cell_size, + cell_size, + facecolor=highlight_facecolor, + edgecolor=highlight_edgecolor, + linewidth=2, + zorder=6, + ) + ) def _draw_cell_value_labels( @@ -633,9 +652,14 @@ def _draw_cell_value_labels( else: label = str(idx) ax.text( - j + 0.5, i + 0.5, label, - ha="center", va="center", - fontsize=cell_fontsize, color=text_color, zorder=7, + j + 0.5, + i + 0.5, + label, + ha="center", + va="center", + fontsize=cell_fontsize, + color=text_color, + zorder=7, ) @@ -649,15 +673,25 @@ def _draw_axis_index_labels( """Draw the row indices on the left and column indices on top.""" for i in range(rows): ax.text( - -0.3, i + 0.5, str(i), - ha="center", va="center", - fontsize=label_fontsize, color=label_color, zorder=8, + -0.3, + i + 0.5, + str(i), + ha="center", + va="center", + fontsize=label_fontsize, + color=label_color, + zorder=8, ) for j in range(cols): ax.text( - j + 0.5, -0.3, str(j), - ha="center", va="center", - fontsize=label_fontsize, color=label_color, zorder=8, + j + 0.5, + -0.3, + str(j), + ha="center", + va="center", + fontsize=label_fontsize, + color=label_color, + zorder=8, ) @@ -726,26 +760,40 @@ def _draw_grid( _setup_axes(ax, (-0.5, cols + 0.5), (-0.5, rows + 0.5), title=title) final_facecolors, highlighted_cells = _draw_cells( - ax, indices, colors, color_indices, - highlight_mask, highlight_facecolor, cell_size, + ax, + indices, + colors, + color_indices, + highlight_mask, + highlight_facecolor, + cell_size, ) if hierarchy_shapes is not None: row_shape, col_shape = hierarchy_shapes _draw_hierarchy_boundary_lines( - ax, rows, cols, + ax, + rows, + cols, _level_block_sizes(row_shape), _level_block_sizes(col_shape), zorder_base=4, ) _draw_cell_highlights( - ax, highlighted_cells, cell_size, - highlight_facecolor, highlight_edgecolor, + ax, + highlighted_cells, + cell_size, + highlight_facecolor, + highlight_edgecolor, ) _draw_cell_value_labels( - ax, indices, final_facecolors, - cell_fontsize, cell_labels, precision, + ax, + indices, + final_facecolors, + cell_fontsize, + cell_labels, + precision, ) if show_labels: @@ -824,7 +872,7 @@ def _build_composite_figure( for p in panels: lay = p[0] if isinstance(p, tuple) else p opts = p[1] if isinstance(p, tuple) else {} - if hasattr(lay, 'layout'): + if hasattr(lay, "layout"): lay = lay.layout lay = as_layout_expr(lay) # Per-panel grid overrides take priority, then defaults @@ -841,8 +889,10 @@ def _build_composite_figure( pr, pc = 1, size(lay) max_rows = max(max_rows, pr) max_cols = max(max_cols, pc) - panel_size = (max_cols * cell_scale + padding_w, - max_rows * cell_scale + padding_h) + panel_size = ( + max_cols * cell_scale + padding_w, + max_rows * cell_scale + padding_h, + ) # Parse arrangement if arrangement == "horizontal": @@ -898,7 +948,7 @@ def _build_composite_figure( # class-identity mismatches after editable-install reloads) eval_fn = None tensor = None - if hasattr(layout, 'layout') and hasattr(layout, 'data') and callable(layout): + if hasattr(layout, "layout") and hasattr(layout, "data") and callable(layout): tensor = layout eval_fn = tensor.__call__ layout = tensor.layout @@ -944,9 +994,13 @@ def _build_composite_figure( else: # Check if this panel should use hierarchical rendering r = rank(layout) - is_hier = r == 2 and not panel_flatten and ( - isinstance(mode(layout.shape, 0), tuple) - or isinstance(mode(layout.shape, 1), tuple) + is_hier = ( + r == 2 + and not panel_flatten + and ( + isinstance(mode(layout.shape, 0), tuple) + or isinstance(mode(layout.shape, 1), tuple) + ) ) grid = _prepare_offset_grid( layout, @@ -1009,7 +1063,7 @@ def _build_composite_figure( def _unwrap_tensor(panel): """Duck-type unwrap a Layout or Tensor into (layout, eval_fn, cell_labels).""" - if hasattr(panel, 'layout') and hasattr(panel, 'data') and callable(panel): + if hasattr(panel, "layout") and hasattr(panel, "data") and callable(panel): tensor = panel layout = tensor.layout eval_fn = tensor.__call__ @@ -1019,7 +1073,9 @@ def _unwrap_tensor(panel): def _build_gemm_figure( - A, B, C, + A, + B, + C, main_title: Optional[str] = None, **defaults, ): @@ -1070,7 +1126,8 @@ def _fmt(s): fig = plt.figure(figsize=(fig_w, fig_h), layout="constrained") gs = fig.add_gridspec( - 2, 2, + 2, + 2, width_ratios=[K, N], height_ratios=[K, M], wspace=0.3, @@ -1097,11 +1154,14 @@ def _render(ax, layout, eval_fn, auto_labels, title): if cell_labels is True and isinstance(auto_labels, list): cell_labels = auto_labels grid = _prepare_offset_grid( - layout, eval_fn=eval_fn, + layout, + eval_fn=eval_fn, color_layout=defaults.get("color_layout"), ) _draw_grid( - ax, grid.indices, title=title, + ax, + grid.indices, + title=title, colorize=defaults.get("colorize", False), color_indices=grid.color_indices, num_colors=defaults.get("num_colors", 8), @@ -1122,7 +1182,9 @@ def _render(ax, layout, eval_fn, auto_labels, title): def draw_gemm( - A, B, C, + A, + B, + C, filename: str = None, main_title: Optional[str] = None, dpi: int = 150, @@ -1886,7 +1948,7 @@ def _build_layout_figure( # (duck-typed to avoid class-identity mismatches after editable-install reloads) eval_fn = None tensor = None - if hasattr(layout, 'layout') and hasattr(layout, 'data') and callable(layout): + if hasattr(layout, "layout") and hasattr(layout, "data") and callable(layout): tensor = layout eval_fn = tensor.__call__ layout = tensor.layout @@ -1956,8 +2018,12 @@ def fn(*args): sub_evals.append(_make_eval(offset, sub)) if color_layout is not None and rank(color_layout) == r: - color_sub, color_offset = slice_and_offset(slice_spec, as_layout_expr(color_layout)) - sub_color_layouts.append(_layout_expr_with_offset(color_sub, color_offset)) + color_sub, color_offset = slice_and_offset( + slice_spec, as_layout_expr(color_layout) + ) + sub_color_layouts.append( + _layout_expr_with_offset(color_sub, color_offset) + ) else: sub_color_layouts.append(color_layout) @@ -1984,7 +2050,9 @@ def fn(*args): for idx in range(n_panels): grid = _prepare_offset_grid( - sub_layouts[idx], color_layout=sub_color_layouts[idx], eval_fn=sub_evals[idx] + sub_layouts[idx], + color_layout=sub_color_layouts[idx], + eval_fn=sub_evals[idx], ) _draw_grid( axes[idx], @@ -2023,7 +2091,9 @@ def fn(*args): if transpose and rank(layout) <= 1: grid = OffsetGrid( indices=grid.indices.T, - color_indices=grid.color_indices.T if grid.color_indices is not None else None, + color_indices=grid.color_indices.T + if grid.color_indices is not None + else None, cell_coords=grid.cell_coords, row_shape=grid.row_shape, col_shape=grid.col_shape, @@ -3152,8 +3222,12 @@ def draw_swizzle( (stacked). Use "vertical" for wide layouts like 8×128. """ fig = _build_swizzle_figure( - base_layout, swizzle, figsize=figsize, colorize=colorize, - num_colors=num_colors, arrangement=arrangement, + base_layout, + swizzle, + figsize=figsize, + colorize=colorize, + num_colors=num_colors, + arrangement=arrangement, ) return _save_figure(fig, filename, dpi) @@ -3183,7 +3257,9 @@ def _expand_hier_slice(spec, shape): elif is_tuple(spec): if is_tuple(shape): if len(spec) != len(shape): - raise ValueError(f"Rank mismatch: spec has {len(spec)} elements, shape has {len(shape)}") + raise ValueError( + f"Rank mismatch: spec has {len(spec)} elements, shape has {len(shape)}" + ) sub_iters = [_expand_hier_slice(s, sh) for s, sh in zip(spec, shape)] for combo in itertools.product(*sub_iters): yield combo @@ -3336,7 +3412,9 @@ def _build_slice_figure( sub, offset = slice_and_offset(slice_spec, as_layout_expr(layout)) if isinstance(sub, Layout): display_sub = _normalize_display_layout(sub) - title = str(display_sub) if offset == 0 else f"{{{offset}}}\u2218{display_sub}" + title = ( + str(display_sub) if offset == 0 else f"{{{offset}}}\u2218{display_sub}" + ) else: display_sub = _layout_expr_with_offset(sub, offset) title = str(display_sub) diff --git a/tests/analysis.py b/tests/analysis.py index 80212ce..215c7c8 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -62,23 +62,23 @@ def test_offset_table_strided(): def test_aliasing_profile_contiguous(): """Contiguous layout has no aliasing.""" profile = aliasing_profile(Layout((2, 4), (1, 2))) - assert profile['has_aliasing'] is False - assert profile['max_alias_ways'] == 1 - assert profile['aliased_offset_count'] == 0 - assert profile['duplicate_elements'] == 0 - assert profile['reuse_histogram'] == {1: 8} - assert profile['aliased_offsets'] == [] + assert profile["has_aliasing"] is False + assert profile["max_alias_ways"] == 1 + assert profile["aliased_offset_count"] == 0 + assert profile["duplicate_elements"] == 0 + assert profile["reuse_histogram"] == {1: 8} + assert profile["aliased_offsets"] == [] def test_aliasing_profile_broadcast(): """Broadcast layout aliases four logical values onto each offset.""" profile = aliasing_profile(Layout((4, 2), (0, 1))) - assert profile['has_aliasing'] is True - assert profile['max_alias_ways'] == 4 - assert profile['aliased_offset_count'] == 2 - assert profile['duplicate_elements'] == 6 - assert profile['reuse_histogram'] == {4: 2} - assert profile['aliased_offsets'] == [0, 1] + assert profile["has_aliasing"] is True + assert profile["max_alias_ways"] == 4 + assert profile["aliased_offset_count"] == 2 + assert profile["duplicate_elements"] == 6 + assert profile["reuse_histogram"] == {4: 2} + assert profile["aliased_offsets"] == [0, 1] def test_aliasing_profile_matches_offset_table_1d_invariants(): @@ -97,21 +97,21 @@ def test_aliasing_profile_matches_offset_table_1d_invariants(): for ways in ways_per_offset.values(): reuse_histogram[ways] = reuse_histogram.get(ways, 0) + 1 - assert profile['has_aliasing'] == bool(aliased_offsets) - assert profile['max_alias_ways'] == max(ways_per_offset.values(), default=0) - assert profile['aliased_offset_count'] == len(aliased_offsets) - assert profile['duplicate_elements'] == size(lyt) - len(ways_per_offset) - assert profile['reuse_histogram'] == dict(sorted(reuse_histogram.items())) - assert profile['aliased_offsets'] == aliased_offsets + assert profile["has_aliasing"] == bool(aliased_offsets) + assert profile["max_alias_ways"] == max(ways_per_offset.values(), default=0) + assert profile["aliased_offset_count"] == len(aliased_offsets) + assert profile["duplicate_elements"] == size(lyt) - len(ways_per_offset) + assert profile["reuse_histogram"] == dict(sorted(reuse_histogram.items())) + assert profile["aliased_offsets"] == aliased_offsets def test_aliasing_profile_tensor_like_input(): """Tensor inputs are accepted via as_layout conversion.""" t = Tensor(Layout((4, 2), (0, 1))) profile = aliasing_profile(t) - assert profile['has_aliasing'] is True - assert profile['max_alias_ways'] == 4 - assert profile['duplicate_elements'] == 6 + assert profile["has_aliasing"] is True + assert profile["max_alias_ways"] == 4 + assert profile["duplicate_elements"] == 6 ## footprint @@ -120,32 +120,32 @@ def test_aliasing_profile_tensor_like_input(): def test_footprint_contiguous(): """Contiguous layout: no holes, no reuse.""" result = footprint(Layout(8, 1)) - assert result['min_offset'] == 0 - assert result['max_offset'] == 7 - assert result['span'] == 8 - assert result['unique_offsets'] == 8 - assert result['total_elements'] == 8 - assert result['reuse_factor'] == 1.0 - assert result['holes'] == 0 + assert result["min_offset"] == 0 + assert result["max_offset"] == 7 + assert result["span"] == 8 + assert result["unique_offsets"] == 8 + assert result["total_elements"] == 8 + assert result["reuse_factor"] == 1.0 + assert result["holes"] == 0 def test_footprint_strided(): """Strided layout: holes between offsets.""" result = footprint(Layout(4, 2)) - assert result['min_offset'] == 0 - assert result['max_offset'] == 6 - assert result['span'] == 7 - assert result['unique_offsets'] == 4 - assert result['holes'] == 3 + assert result["min_offset"] == 0 + assert result["max_offset"] == 6 + assert result["span"] == 7 + assert result["unique_offsets"] == 4 + assert result["holes"] == 3 def test_footprint_broadcast(): """Broadcast: high reuse factor.""" result = footprint(Layout((4, 2), (0, 1))) - assert result['unique_offsets'] == 2 - assert result['total_elements'] == 8 - assert result['reuse_factor'] == 4.0 - assert result['holes'] == 0 + assert result["unique_offsets"] == 2 + assert result["total_elements"] == 8 + assert result["reuse_factor"] == 4.0 + assert result["holes"] == 0 ## gap_profile @@ -154,57 +154,57 @@ def test_footprint_broadcast(): def test_gap_profile_contiguous(): """Contiguous layout has one run and no interior gaps.""" result = gap_profile(Layout(8, 1)) - assert result['runs'] == [(0, 7)] - assert result['gap_sizes'] == [] - assert result['max_gap'] == 0 - assert result['avg_gap'] == 0.0 - assert result['run_count'] == 1 - assert result['isolated_offsets'] == 0 + assert result["runs"] == [(0, 7)] + assert result["gap_sizes"] == [] + assert result["max_gap"] == 0 + assert result["avg_gap"] == 0.0 + assert result["run_count"] == 1 + assert result["isolated_offsets"] == 0 def test_gap_profile_strided(): """Stride-2 layout has single-element runs separated by gap=1.""" result = gap_profile(Layout(4, 2)) - assert result['runs'] == [(0, 0), (2, 2), (4, 4), (6, 6)] - assert result['gap_sizes'] == [1, 1, 1] - assert result['max_gap'] == 1 - assert result['avg_gap'] == pytest.approx(1.0) - assert result['run_count'] == 4 - assert result['isolated_offsets'] == 4 + assert result["runs"] == [(0, 0), (2, 2), (4, 4), (6, 6)] + assert result["gap_sizes"] == [1, 1, 1] + assert result["max_gap"] == 1 + assert result["avg_gap"] == pytest.approx(1.0) + assert result["run_count"] == 4 + assert result["isolated_offsets"] == 4 def test_gap_profile_negative_stride(): """Negative strides still produce sorted runs in offset space.""" result = gap_profile(Layout(4, -2)) - assert result['runs'] == [(-6, -6), (-4, -4), (-2, -2), (0, 0)] - assert result['gap_sizes'] == [1, 1, 1] - assert result['max_gap'] == 1 - assert result['avg_gap'] == pytest.approx(1.0) - assert result['run_count'] == 4 - assert result['isolated_offsets'] == 4 + assert result["runs"] == [(-6, -6), (-4, -4), (-2, -2), (0, 0)] + assert result["gap_sizes"] == [1, 1, 1] + assert result["max_gap"] == 1 + assert result["avg_gap"] == pytest.approx(1.0) + assert result["run_count"] == 4 + assert result["isolated_offsets"] == 4 def test_gap_profile_broadcast_aliasing_has_no_gaps(): """Aliasing does not imply holes: broadcast is dense in offset space.""" result = gap_profile(Layout((4, 2), (0, 1))) - assert result['runs'] == [(0, 1)] - assert result['gap_sizes'] == [] - assert result['max_gap'] == 0 - assert result['avg_gap'] == 0.0 - assert result['run_count'] == 1 - assert result['isolated_offsets'] == 0 + assert result["runs"] == [(0, 1)] + assert result["gap_sizes"] == [] + assert result["max_gap"] == 0 + assert result["avg_gap"] == 0.0 + assert result["run_count"] == 1 + assert result["isolated_offsets"] == 0 def test_gap_profile_zero_size_layout(): """Empty domains produce an empty run/gap profile.""" result = gap_profile(Layout(0, 1)) assert result == { - 'runs': [], - 'gap_sizes': [], - 'max_gap': 0, - 'avg_gap': 0.0, - 'run_count': 0, - 'isolated_offsets': 0, + "runs": [], + "gap_sizes": [], + "max_gap": 0, + "avg_gap": 0.0, + "run_count": 0, + "isolated_offsets": 0, } @@ -217,22 +217,22 @@ def test_gap_profile_invariants_match_footprint(): fp = footprint(lyt) # Reconstruct unique-offset count from run lengths. - run_lengths = [end - start + 1 for start, end in gp['runs']] - assert sum(run_lengths) == fp['unique_offsets'] + run_lengths = [end - start + 1 for start, end in gp["runs"]] + assert sum(run_lengths) == fp["unique_offsets"] # Holes are exactly the interior gaps between runs. - assert sum(gp['gap_sizes']) == fp['holes'] - assert gp['max_gap'] == max(gp['gap_sizes'], default=0) - assert gp['run_count'] == len(gp['runs']) - assert gp['isolated_offsets'] == sum(1 for n in run_lengths if n == 1) + assert sum(gp["gap_sizes"]) == fp["holes"] + assert gp["max_gap"] == max(gp["gap_sizes"], default=0) + assert gp["run_count"] == len(gp["runs"]) + assert gp["isolated_offsets"] == sum(1 for n in run_lengths if n == 1) # Empty/single-run edge cases. - if gp['gap_sizes']: - assert gp['avg_gap'] == pytest.approx( - sum(gp['gap_sizes']) / len(gp['gap_sizes']) + if gp["gap_sizes"]: + assert gp["avg_gap"] == pytest.approx( + sum(gp["gap_sizes"]) / len(gp["gap_sizes"]) ) else: - assert gp['avg_gap'] == 0.0 + assert gp["avg_gap"] == 0.0 ## bank_conflicts @@ -241,14 +241,14 @@ def test_gap_profile_invariants_match_footprint(): def test_bank_conflicts_linear(): """Linear stride-1 access: no conflicts.""" result = bank_conflicts(Layout(32, 1), element_bytes=2) - assert result['conflict_free'] - assert result['max_ways'] == 1 + assert result["conflict_free"] + assert result["max_ways"] == 1 def test_bank_conflicts_broadcast(): """All threads access same address: broadcast, not a conflict.""" result = bank_conflicts(Layout(32, 0), element_bytes=2) - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_stride_32(): @@ -257,7 +257,7 @@ def test_bank_conflicts_stride_32(): # Actually: thread t -> offset 32*t, byte_addr = 64*t, bank = (64t/4) % 32 = 16t % 32 # This causes 2-way conflicts (threads 0,2,4,... hit bank 0; threads 1,3,5,... hit bank 16) result = bank_conflicts(Layout(32, 32), element_bytes=2) - assert not result['conflict_free'] + assert not result["conflict_free"] def test_bank_conflicts_swizzled(): @@ -283,13 +283,13 @@ def test_bank_conflicts_swizzled(): element_bytes=4, # treat each offset as a 4-byte word ) # stride-1, 8 consecutive elements with 4-byte words: 8 different banks - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_fp32(): """4-byte elements: bank width matches element width.""" result = bank_conflicts(Layout(32, 1), element_bytes=4) - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_group_size(): @@ -298,11 +298,11 @@ def test_bank_conflicts_group_size(): r32 = bank_conflicts(Layout(32, 32), element_bytes=2) r64_default = bank_conflicts(Layout(64, 32), element_bytes=2) # Default group_size=32 limits analysis to first warp - assert r64_default['max_ways'] == r32['max_ways'] + assert r64_default["max_ways"] == r32["max_ways"] # Explicitly analyzing all 64 threads gives a larger conflict factor r64_full = bank_conflicts(Layout(64, 32), element_bytes=2, group_size=64) - assert r64_full['max_ways'] > r32['max_ways'] + assert r64_full["max_ways"] > r32["max_ways"] def test_bank_conflicts_group_size_validation(): @@ -318,8 +318,8 @@ def test_bank_conflicts_tv_layout(): # 32 threads, 2 values: stride-1 threads, stride-32 values tv = Layout((32, 2), (1, 32)) r = bank_conflicts(tv, element_bytes=2) - assert r['conflict_free'] - assert len(r['bank_to_threads']) == 32 # all banks accessed + assert r["conflict_free"] + assert len(r["bank_to_threads"]) == 32 # all banks accessed ## coalescing_efficiency @@ -328,40 +328,40 @@ def test_bank_conflicts_tv_layout(): def test_coalescing_contiguous_fp16(): """32 threads, stride 1, fp16: one cache line (64B of 128B).""" result = coalescing_efficiency(Layout(32, 1), element_bytes=2) - assert result['transactions'] == 1 - assert result['efficiency'] == pytest.approx(0.5) + assert result["transactions"] == 1 + assert result["efficiency"] == pytest.approx(0.5) def test_coalescing_contiguous_fp32(): """32 threads, stride 1, fp32: one cache line (128B of 128B).""" result = coalescing_efficiency(Layout(32, 1), element_bytes=4) - assert result['transactions'] == 1 - assert result['efficiency'] == pytest.approx(1.0) + assert result["transactions"] == 1 + assert result["efficiency"] == pytest.approx(1.0) def test_coalescing_strided(): """Stride-2 access doubles the cache lines touched.""" result = coalescing_efficiency(Layout(32, 2), element_bytes=2) - assert result['transactions'] == 1 # 32*2*2=128 bytes, still fits in 1 line + assert result["transactions"] == 1 # 32*2*2=128 bytes, still fits in 1 line # Actually: offsets 0,2,4,...,62. byte addrs 0,4,8,...,124. All in line 0. - assert result['efficiency'] == pytest.approx(0.5) + assert result["efficiency"] == pytest.approx(0.5) def test_coalescing_large_stride(): """Large stride: each thread touches a different cache line.""" # stride 64 elements * 2 bytes = 128 bytes = 1 cache line apart result = coalescing_efficiency(Layout(32, 64), element_bytes=2) - assert result['transactions'] == 32 + assert result["transactions"] == 32 # 32 threads * 2 bytes = 64 useful bytes, 32 * 128 = 4096 transferred - assert result['efficiency'] == pytest.approx(64.0 / (32 * 128)) + assert result["efficiency"] == pytest.approx(64.0 / (32 * 128)) def test_coalescing_broadcast(): """All threads access same element: single transaction, minimal useful bytes.""" result = coalescing_efficiency(Layout(32, 0), element_bytes=2) - assert result['transactions'] == 1 + assert result["transactions"] == 1 # Only 1 unique offset: 1 * 2 bytes useful out of 128 transferred - assert result['efficiency'] == pytest.approx(2.0 / 128) + assert result["efficiency"] == pytest.approx(2.0 / 128) def test_coalescing_tv_layout(): @@ -370,16 +370,16 @@ def test_coalescing_tv_layout(): tv = Layout((32, 4), (4, 1)) result = coalescing_efficiency(tv, element_bytes=2) # 128 unique offsets * 2B = 256B -> cache lines 0, 1 - assert result['transactions'] == 2 - assert result['efficiency'] == pytest.approx(1.0) + assert result["transactions"] == 2 + assert result["efficiency"] == pytest.approx(1.0) def test_coalescing_negative_stride_rebases_footprint(): """Dense reverse layouts should analyze like translated dense runs.""" result = coalescing_efficiency(Layout(4, -1), element_bytes=4) - assert result['transactions'] == 1 - assert result['efficiency'] == pytest.approx(16.0 / 128) - assert result['cache_lines'] == [0] + assert result["transactions"] == 1 + assert result["efficiency"] == pytest.approx(16.0 / 128) + assert result["cache_lines"] == [0] ## segment_analysis @@ -389,29 +389,29 @@ def test_segment_analysis_contiguous_fp16(): """32 threads, stride 1, fp16: 2 segments, 1 cache line.""" result = segment_analysis(Layout(32, 1), element_bytes=2) # 32 * 2B = 64B -> 2 segments of 32B, 1 cache line of 128B - assert result['segments'] == 2 - assert result['cache_lines'] == 1 - assert result['unique_bytes'] == 64 - assert result['requested_bytes'] == 64 - assert result['transferred_bytes'] == 64 # 2 * 32 - assert result['segment_efficiency'] == pytest.approx(1.0) - assert result['first_alignment'] == 0 + assert result["segments"] == 2 + assert result["cache_lines"] == 1 + assert result["unique_bytes"] == 64 + assert result["requested_bytes"] == 64 + assert result["transferred_bytes"] == 64 # 2 * 32 + assert result["segment_efficiency"] == pytest.approx(1.0) + assert result["first_alignment"] == 0 def test_segment_analysis_strided(): """Stride-2 touches more segments than contiguous.""" result = segment_analysis(Layout(32, 2), element_bytes=2) # offsets 0,2,4,...,62 -> byte addrs 0,4,8,...,124 -> 4 segments - assert result['segments'] == 4 - assert result['cache_lines'] == 1 + assert result["segments"] == 4 + assert result["cache_lines"] == 1 def test_segment_analysis_broadcast(): """Broadcast: 1 segment, minimal unique bytes.""" result = segment_analysis(Layout(32, 0), element_bytes=2) - assert result['segments'] == 1 - assert result['unique_bytes'] == 2 - assert result['requested_bytes'] == 64 + assert result["segments"] == 1 + assert result["unique_bytes"] == 2 + assert result["requested_bytes"] == 64 def test_segment_analysis_tv_layout(): @@ -419,21 +419,21 @@ def test_segment_analysis_tv_layout(): tv = Layout((32, 4), (4, 1)) result = segment_analysis(tv, element_bytes=2) # 128 elements * 2B = 256B -> 8 segments, 2 cache lines - assert result['segments'] == 8 - assert result['cache_lines'] == 2 - assert result['requested_bytes'] == 256 # 32 * 4 * 2 + assert result["segments"] == 8 + assert result["cache_lines"] == 2 + assert result["requested_bytes"] == 256 # 32 * 4 * 2 def test_segment_analysis_negative_stride_rebases_footprint(): """Segment analysis should use the addressed footprint, not signed origin.""" result = segment_analysis(Layout((4, 2), (-1, 4)), element_bytes=4) - assert result['segments'] == 1 - assert result['cache_lines'] == 1 - assert result['unique_bytes'] == 32 - assert result['transferred_bytes'] == 32 - assert result['segment_efficiency'] == pytest.approx(1.0) - assert result['first_byte_addr'] == 12 - assert result['first_alignment'] == 12 + assert result["segments"] == 1 + assert result["cache_lines"] == 1 + assert result["unique_bytes"] == 32 + assert result["transferred_bytes"] == 32 + assert result["segment_efficiency"] == pytest.approx(1.0) + assert result["first_byte_addr"] == 12 + assert result["first_alignment"] == 12 ## per-group analysis @@ -443,11 +443,11 @@ def test_per_group_bank_conflicts(): """Per-group analysis matches single-group result for each warp.""" r_single = bank_conflicts(Layout(32, 32), element_bytes=2) r_per = per_group_bank_conflicts(Layout(64, 32), element_bytes=2) - assert len(r_per['groups']) == 2 + assert len(r_per["groups"]) == 2 # Each group should match the single-warp result - for g in r_per['groups']: - assert g['max_ways'] == r_single['max_ways'] - assert r_per['worst_max_ways'] == r_single['max_ways'] + for g in r_per["groups"]: + assert g["max_ways"] == r_single["max_ways"] + assert r_per["worst_max_ways"] == r_single["max_ways"] def test_per_group_bank_conflicts_tv_layout(): @@ -455,16 +455,16 @@ def test_per_group_bank_conflicts_tv_layout(): # 32 threads, 4 values each: should be 1 group (not 4) tv = Layout((32, 4), (1, 32)) result = per_group_bank_conflicts(tv, element_bytes=2, group_size=32) - assert len(result['groups']) == 1 + assert len(result["groups"]) == 1 def test_per_group_coalescing(): """Per-group coalescing for a uniform layout gives identical per-warp results.""" r_per = per_group_coalescing(Layout(64, 1), element_bytes=2) - assert len(r_per['groups']) == 2 - for g in r_per['groups']: - assert g['efficiency'] == pytest.approx(0.5) - assert g['transactions'] == 1 + assert len(r_per["groups"]) == 2 + for g in r_per["groups"]: + assert g["efficiency"] == pytest.approx(0.5) + assert g["transactions"] == 1 def test_per_group_coalescing_tv_layout(): @@ -472,19 +472,19 @@ def test_per_group_coalescing_tv_layout(): # 32 threads, 4 values each (contiguous within each thread's block) tv = Layout((32, 4), (4, 1)) result = per_group_coalescing(tv, element_bytes=2, group_size=32) - assert len(result['groups']) == 1 + assert len(result["groups"]) == 1 # 32 threads * 4 values = 128 elements * 2B = 256B -> 2 cache lines - assert result['groups'][0]['transactions'] == 2 + assert result["groups"][0]["transactions"] == 2 def test_per_group_coalescing_negative_stride_rebases_each_group(): """Each group should analyze its own dense reversed footprint.""" result = per_group_coalescing(Layout(64, -1), element_bytes=4, group_size=32) - assert len(result['groups']) == 2 - for group in result['groups']: - assert group['transactions'] == 1 - assert group['efficiency'] == pytest.approx(1.0) - assert group['cache_lines'] == [0] + assert len(result["groups"]) == 2 + for group in result["groups"]: + assert group["transactions"] == 1 + assert group["efficiency"] == pytest.approx(1.0) + assert group["cache_lines"] == [0] ## cycles @@ -665,43 +665,48 @@ def test_slice_contiguity_col_major(): def test_atom_summary_nv_sm80(): """SM80 16x8x16 F16 atom summary.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = atom_summary(SM80_16x8x16_F16F16F16F16_TN) - assert result['shape_mnk'] == (16, 8, 16) - assert result['threads'] == 32 - assert result['values_c'] > 0 - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (16, 8, 16) + assert result["threads"] == 32 + assert result["values_c"] > 0 + assert result["c_coverage_ok"] def test_atom_summary_nv_sm80_f32(): """SM80 16x8x8 F32 accumulator atom.""" from tensor_layouts.atoms_nv import SM80_16x8x8_F32F16F16F32_TN + result = atom_summary(SM80_16x8x8_F32F16F16F32_TN) - assert result['shape_mnk'] == (16, 8, 8) - assert result['threads'] == 32 - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (16, 8, 8) + assert result["threads"] == 32 + assert result["c_coverage_ok"] def test_atom_summary_amd_cdna(): """AMD CDNA 32x32x8 MFMA atom summary.""" from tensor_layouts.atoms_amd import CDNA_32x32x8_F32F16F16_MFMA + result = atom_summary(CDNA_32x32x8_F32F16F16_MFMA) - assert result['shape_mnk'] == (32, 32, 8) - assert result['threads'] == 64 # AMD wavefront - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (32, 32, 8) + assert result["threads"] == 64 # AMD wavefront + assert result["c_coverage_ok"] def test_atom_summary_text_output(): """atom_summary returns a readable text summary.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = atom_summary(SM80_16x8x16_F16F16F16F16_TN) - assert 'SM80' in result['text'] - assert '16 x 8 x 16' in result['text'] - assert 'Threads' in result['text'] + assert "SM80" in result["text"] + assert "16 x 8 x 16" in result["text"] + assert "Threads" in result["text"] def test_atom_summary_rejects_wrong_c_offsets(): """c_coverage_ok must check exact offset set, not just cardinality.""" from tensor_layouts.atoms import MMAAtom + # Build a 2x2 atom where C layout produces offsets {0, 1, 2, 5} # instead of the expected {0, 1, 2, 3}. Cardinality is 4 = M*N, # but the set is wrong. @@ -710,10 +715,10 @@ def test_atom_summary_rejects_wrong_c_offsets(): ptx="test", shape_mnk=(2, 2, 1), thr_id=Layout(4), - a_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test - b_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test + a_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test + b_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test # C layout: 4 threads, 1 value each -> offsets 0, 1, 2, 5 - c_layout=Layout((4, 1), (1, 0)), # placeholder, override below + c_layout=Layout((4, 1), (1, 0)), # placeholder, override below ) # Manually construct a C layout that maps t -> {0, 1, 2, 5} # Layout((4, 1), (1, 0)) maps t -> t, giving {0, 1, 2, 3} — that's correct. @@ -721,16 +726,19 @@ def test_atom_summary_rejects_wrong_c_offsets(): # Layout with shape (2, 2) stride (1, 2) gives 0,1,2,3 — still correct. # Use a non-standard construction: ((2, 2), 1) : ((1, 4), 0) -> 0,1,4,5 import dataclasses + bad_c = Layout(((2, 2), 1), ((1, 4), 0)) bad_atom = dataclasses.replace(bad_atom, c_layout=bad_c) result = atom_summary(bad_atom) - assert not result['c_coverage_ok'] + assert not result["c_coverage_ok"] def test_atom_summary_rejects_duplicate_c_coverage(): """c_coverage_ok must be False when C layout produces duplicate offsets.""" - from tensor_layouts.atoms import MMAAtom import dataclasses + + from tensor_layouts.atoms import MMAAtom + # Build a 2x2x1 atom where C layout has shape (4, 2) stride (1, 0). # This maps (t, v) pairs to offsets [0,0,1,1,2,2,3,3] — correct set # but each offset appears twice. @@ -746,25 +754,27 @@ def test_atom_summary_rejects_duplicate_c_coverage(): dup_c = Layout((4, 2), (1, 0)) # 8 accesses, offsets 0..3 each twice bad_atom = dataclasses.replace(base, c_layout=dup_c) result = atom_summary(bad_atom) - assert not result['c_coverage_ok'] + assert not result["c_coverage_ok"] def test_operand_analysis_sm80(): """operand_analysis on a well-formed atom reports full coverage.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = operand_analysis(SM80_16x8x16_F16F16F16F16_TN) - for op in ['a', 'b', 'c']: - assert result[op]['coverage_ok'] - assert result[op]['duplicates'] == 0 - assert result[op]['thread_utilization'] == pytest.approx(1.0) - assert result['a']['domain_size'] == 16 * 16 # M * K - assert result['b']['domain_size'] == 8 * 16 # N * K - assert result['c']['domain_size'] == 16 * 8 # M * N + for op in ["a", "b", "c"]: + assert result[op]["coverage_ok"] + assert result[op]["duplicates"] == 0 + assert result[op]["thread_utilization"] == pytest.approx(1.0) + assert result["a"]["domain_size"] == 16 * 16 # M * K + assert result["b"]["domain_size"] == 8 * 16 # N * K + assert result["c"]["domain_size"] == 16 * 8 # M * N def test_operand_analysis_bad_coverage(): """operand_analysis detects malformed operand coverage.""" from tensor_layouts.atoms import MMAAtom + base = MMAAtom( name="test_bad_operand", ptx="test", @@ -775,9 +785,9 @@ def test_operand_analysis_bad_coverage(): c_layout=Layout(((2, 2), 1), ((1, 4), 0)), # offsets {0,1,4,5}, not {0,1,2,3} ) result = operand_analysis(base) - assert not result['c']['coverage_ok'] - assert len(result['c']['missing']) > 0 - assert len(result['c']['extra']) > 0 + assert not result["c"]["coverage_ok"] + assert len(result["c"]["missing"]) > 0 + assert len(result["c"]["extra"]) > 0 ## explain @@ -786,36 +796,36 @@ def test_operand_analysis_bad_coverage(): def test_explain_logical_divide(): """explain shows step-by-step logical_divide computation.""" text = explain(logical_divide, Layout(16, 1), 4) - assert 'logical_divide' in text - assert 'complement' in text - assert 'compose' in text - assert '(4, 4) : (1, 4)' in text + assert "logical_divide" in text + assert "complement" in text + assert "compose" in text + assert "(4, 4) : (1, 4)" in text def test_explain_logical_product(): """explain shows step-by-step logical_product computation.""" text = explain(logical_product, Layout(4, 1), Layout(3, 1)) - assert 'logical_product' in text - assert 'complement' in text - assert '(4, 3) : (1, 4)' in text + assert "logical_product" in text + assert "complement" in text + assert "(4, 3) : (1, 4)" in text def test_explain_logical_product_layout_tiler_uses_cosize_bound(): """Layout-tiler explanations should match CuTe's size(A) * cosize(B).""" text = explain(logical_product, Layout(4, 1), Layout(3, 2)) - assert 'size(A) = 4' in text - assert 'cosize(B) = 5' in text - assert 'size(A) * cosize(B) = 20' in text - assert 'size(A) * size(B)' not in text + assert "size(A) = 4" in text + assert "cosize(B) = 5" in text + assert "size(A) * cosize(B) = 20" in text + assert "size(A) * size(B)" not in text assert str(logical_product(Layout(4, 1), Layout(3, 2))) in text def test_explain_logical_product_tuple_tiler(): """explain handles logical_product with tuple tiler without crashing.""" text = explain(logical_product, Layout((4, 4), (1, 4)), (2, 2)) - assert 'logical_product' in text - assert 'mode 0' in text - assert 'mode 1' in text + assert "logical_product" in text + assert "mode 0" in text + assert "mode 1" in text expected = logical_product(Layout((4, 4), (1, 4)), (2, 2)) assert str(expected) in text @@ -823,16 +833,16 @@ def test_explain_logical_product_tuple_tiler(): def test_explain_complement(): """explain shows complement with image and codomain.""" text = explain(complement, Layout(4, 2), 16) - assert 'image' in text - assert 'codomain' in text - assert '[0, 16)' in text + assert "image" in text + assert "codomain" in text + assert "[0, 16)" in text def test_explain_compose(): """explain shows compose with per-element trace.""" text = explain(compose, Layout(8, 2), Layout(4, 1)) - assert 'C(i) = A(B(i))' in text - assert 'i=0' in text + assert "C(i) = A(B(i))" in text + assert "i=0" in text def test_explain_compose_tuple_tiler(): @@ -840,72 +850,72 @@ def test_explain_compose_tuple_tiler(): A = Layout((4, 8), (8, 1)) B = (2, 4) text = explain(compose, A, B) - assert 'For tuple tilers, composition is applied mode-by-mode.' in text - assert 'mode 0: compose(4 : 8, 2 : 1) = 2 : 8' in text - assert 'mode 1: compose(8 : 1, 4 : 1) = 4 : 1' in text - assert 'coord=(0, 0): result((0, 0))=0' in text + assert "For tuple tilers, composition is applied mode-by-mode." in text + assert "mode 0: compose(4 : 8, 2 : 1) = 2 : 8" in text + assert "mode 1: compose(8 : 1, 4 : 1) = 4 : 1" in text + assert "coord=(0, 0): result((0, 0))=0" in text assert str(compose(A, B)) in text def test_explain_right_inverse(): """explain shows right_inverse with verification.""" text = explain(right_inverse, Layout(4, 2)) - assert 'R such that L(R(i)) == i' in text - assert 'Verification' in text + assert "R such that L(R(i)) == i" in text + assert "Verification" in text def test_explain_left_inverse(): """explain shows left_inverse with verification.""" text = explain(left_inverse, Layout(4, 2)) - assert 'R such that R(L(i)) == i' in text - assert 'Verification' in text + assert "R such that R(L(i)) == i" in text + assert "Verification" in text def test_explain_unsupported(): """explain gracefully handles unsupported functions.""" text = explain(size, Layout(4, 1)) - assert 'does not support' in text + assert "does not support" in text def test_explain_blocked_product(): """explain shows blocked_product as interleaved logical_product.""" text = explain(blocked_product, Layout((2, 3), (1, 2)), Layout((4, 2), (1, 4))) - assert 'blocked_product' in text - assert 'logical_product' in text - assert 'A varies fastest' in text + assert "blocked_product" in text + assert "logical_product" in text + assert "A varies fastest" in text def test_explain_raked_product(): """explain shows raked_product with comparison to blocked.""" text = explain(raked_product, Layout(4, 1), Layout(3, 1)) - assert 'raked_product' in text - assert 'B varies fastest' in text - assert 'blocked' in text - assert 'raked' in text + assert "raked_product" in text + assert "B varies fastest" in text + assert "blocked" in text + assert "raked" in text def test_explain_zipped_divide(): """explain shows zipped_divide as rearranged logical_divide.""" text = explain(zipped_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'zipped_divide' in text - assert 'logical_divide' in text - assert '((tiles), (rests))' in text + assert "zipped_divide" in text + assert "logical_divide" in text + assert "((tiles), (rests))" in text def test_explain_tiled_divide(): """explain shows tiled_divide structure.""" text = explain(tiled_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'tiled_divide' in text - assert 'logical_divide' in text - assert '((tiles), rest0, rest1, ...)' in text + assert "tiled_divide" in text + assert "logical_divide" in text + assert "((tiles), rest0, rest1, ...)" in text def test_explain_flat_divide(): """explain shows flat_divide structure.""" text = explain(flat_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'flat_divide' in text - assert 'logical_divide' in text - assert '(tile0, tile1, ..., rest0, rest1, ...)' in text + assert "flat_divide" in text + assert "logical_divide" in text + assert "(tile0, tile1, ..., rest0, rest1, ...)" in text ## MMAAtom and CopyAtom __str__ @@ -999,11 +1009,13 @@ def test_F2_matrix_row_major(): def test_F2_matrix_col_major(): """Column-major produces identity (already in natural bit order).""" M = to_F2_matrix(Layout((4, 8), (1, 4))) - assert M == [[1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 1]] + assert M == [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] def test_F2_matrix_swizzle(): @@ -1088,6 +1100,7 @@ def test_F2_matrix_sm80_mma_c_accumulator(): n1..n2 = T0..T1 (thread pairs select column groups) """ from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + c = SM80_16x8x16_F16F16F16F16_TN.c_layout M = to_F2_matrix(c) # T0 T1 T2 T3 T4 V0 V1 @@ -1175,9 +1188,14 @@ def test_F2_matrix_sm80_c_vs_triton_MMAv2(): # (matching CuTe's flattened thread → value coord bit order) triton_bases = [ # lane (thread) bases — from LinearLayoutConversionsTest.cpp:438 - (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), + (0, 2), + (0, 4), + (1, 0), + (2, 0), + (4, 0), # register (value) bases — from LinearLayoutConversionsTest.cpp:438 - (0, 1), (8, 0), + (0, 1), + (8, 0), ] expected = _triton_bases_to_F2_matrix(triton_bases, tile_M=16, tile_N=8) @@ -1195,13 +1213,13 @@ def test_F2_matrix_sm80_c_all_atoms_share_layout(): Source: mma_traits_sm80.hpp line 53 (SM80_16x8_Row definition). """ from tensor_layouts.atoms_nv import ( - SM80_16x8x8_F16F16F16F16_TN, - SM80_16x8x16_F32F16F16F32_TN, + SM120_16x8x32_F32E4M3E4M3F32_TN, SM80_16x8x16_F32BF16BF16F32_TN, + SM80_16x8x16_F32F16F16F32_TN, SM80_16x8x32_S32S8S8S32_TN, SM80_16x8x64_S32S4S4S32_TN, + SM80_16x8x8_F16F16F16F16_TN, SM89_16x8x32_F32E4M3E4M3F32_TN, - SM120_16x8x32_F32E4M3E4M3F32_TN, ) ref = to_F2_matrix(SM80_16x8x8_F16F16F16F16_TN.c_layout) @@ -1240,11 +1258,18 @@ def test_F2_matrix_sm90_gmma_c_64x16_vs_triton_MMAv3(): # (matching CuTe's flattened thread → value coord bit order) triton_bases = [ # lane bases — LinearLayoutConversionsTest.cpp:531 - (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), + (0, 2), + (0, 4), + (1, 0), + (2, 0), + (4, 0), # warp bases — LinearLayoutConversionsTest.cpp:532 - (16, 0), (32, 0), + (16, 0), + (32, 0), # register bases — LinearLayoutConversionsTest.cpp:530 - (0, 1), (8, 0), (0, 8), + (0, 1), + (8, 0), + (0, 8), ] expected = _triton_bases_to_F2_matrix(triton_bases, tile_M=64, tile_N=16) @@ -1280,8 +1305,8 @@ def test_F2_matrix_sm80_a_operand(): a = SM80_16x8x16_F16F16F16F16_TN.a_layout M = to_F2_matrix(a) - assert len(M) == 8 # 8 offset bits (cosize = 256 = 16×16) - assert len(M[0]) == 8 # 8 coord bits (5 thread + 3 value) + assert len(M) == 8 # 8 offset bits (cosize = 256 = 16×16) + assert len(M[0]) == 8 # 8 coord bits (5 thread + 3 value) _verify_F2_matrix(a) @@ -1296,8 +1321,8 @@ def test_F2_matrix_sm80_b_operand(): b = SM80_16x8x16_F16F16F16F16_TN.b_layout M = to_F2_matrix(b) - assert len(M) == 7 # 7 offset bits (cosize = 128 = 8×16) - assert len(M[0]) == 7 # 7 coord bits (5 thread + 2 value) + assert len(M) == 7 # 7 offset bits (cosize = 128 = 8×16) + assert len(M[0]) == 7 # 7 coord bits (5 thread + 2 value) _verify_F2_matrix(b) @@ -1328,11 +1353,19 @@ def test_F2_matrix_sm90_gmma_c_64x32_vs_triton_MMAv3(): # Triton basis vectors — non-zero bases only, lane→warp→register order triton_bases = [ # lane bases — LinearLayoutConversionsTest.cpp:577 - (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), + (0, 2), + (0, 4), + (1, 0), + (2, 0), + (4, 0), # warp bases (skipping W2=(0,0) broadcast) — line 578 - (16, 0), (32, 0), + (16, 0), + (32, 0), # register bases — LinearLayoutConversionsTest.cpp:576 - (0, 1), (8, 0), (0, 8), (0, 16), + (0, 1), + (8, 0), + (0, 8), + (0, 16), ] expected = _triton_bases_to_F2_matrix(triton_bases, tile_M=64, tile_N=32) @@ -1350,15 +1383,20 @@ def test_F2_matrix_sm90_warp_c_vs_triton_MMAv3(): The atom-level bases (first 2 register + 5 lane) match MMAv2. """ from tensor_layouts.atoms_nv import ( - SM90_16x8x4_F64F64F64F64_TN, - SM90_16x8x16_F64F64F64F64_TN, SM90_16x8x16_C64C64C64C64_TN, + SM90_16x8x16_F64F64F64F64_TN, + SM90_16x8x4_F64F64F64F64_TN, ) # Same expected matrix as SM80 MMAv2 C accumulator triton_bases = [ - (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), # lane - (0, 1), (8, 0), # register + (0, 2), + (0, 4), + (1, 0), + (2, 0), + (4, 0), # lane + (0, 1), + (8, 0), # register ] expected = _triton_bases_to_F2_matrix(triton_bases, tile_M=16, tile_N=8) @@ -1407,8 +1445,7 @@ def test_F2_matrix_sm100_umma_c_parametric(): """ from tensor_layouts.atoms_nv import umma_layout - for m, n in [(64, 64), (64, 128), (64, 256), - (128, 64), (128, 128), (128, 256)]: + for m, n in [(64, 64), (64, 128), (64, 256), (128, 64), (128, 128), (128, 256)]: c = umma_layout(m, n) M = to_F2_matrix(c) n_bits = len(M) @@ -1633,53 +1670,53 @@ def test_functionally_equal_row_col_major(): def test_thread_stride_profile_linear_layout(): """Stride-1 layout has a uniform +1 thread stride.""" p = thread_stride_profile(Layout(8, 1)) - assert p['thread_count'] == 8 - assert p['value_count'] == 1 - assert p['per_value_deltas'] == [[1, 1, 1, 1, 1, 1, 1]] - assert p['global_unique_strides'] == [1] - assert p['is_uniform'] is True - assert p['has_broadcast_lane'] is False + assert p["thread_count"] == 8 + assert p["value_count"] == 1 + assert p["per_value_deltas"] == [[1, 1, 1, 1, 1, 1, 1]] + assert p["global_unique_strides"] == [1] + assert p["is_uniform"] is True + assert p["has_broadcast_lane"] is False def test_thread_stride_profile_broadcast_lane(): """Stride-0 thread mode is detected as a broadcast lane.""" p = thread_stride_profile(Layout((8, 2), (0, 8))) - assert p['thread_count'] == 8 - assert p['value_count'] == 2 - assert p['per_value_constant_stride'] == [0, 0] - assert p['global_unique_strides'] == [0] - assert p['is_uniform'] is True - assert p['has_broadcast_lane'] is True + assert p["thread_count"] == 8 + assert p["value_count"] == 2 + assert p["per_value_constant_stride"] == [0, 0] + assert p["global_unique_strides"] == [0] + assert p["is_uniform"] is True + assert p["has_broadcast_lane"] is True def test_thread_stride_profile_nonuniform_lane(): """Composed layout can produce non-constant adjacent-thread deltas.""" p = thread_stride_profile(compose(Swizzle(2, 0, 2), Layout(8, 1))) - assert p['thread_count'] == 8 - assert p['value_count'] == 1 - assert p['per_value_is_constant'] == [False] - assert p['per_value_constant_stride'] == [None] - assert p['global_unique_strides'] == [-1, 1, 2, 3] - assert p['is_uniform'] is False + assert p["thread_count"] == 8 + assert p["value_count"] == 1 + assert p["per_value_is_constant"] == [False] + assert p["per_value_constant_stride"] == [None] + assert p["global_unique_strides"] == [-1, 1, 2, 3] + assert p["is_uniform"] is False def test_thread_stride_profile_multimode_value_shape(): """Value modes beyond rank-2 are flattened lane-wise correctly.""" tv = Layout((4, (2, 2)), (1, (8, 16))) p = thread_stride_profile(tv) - assert p['thread_count'] == 4 - assert p['value_count'] == 4 - assert p['per_value_constant_stride'] == [1, 1, 1, 1] - assert p['global_unique_strides'] == [1] - assert p['is_uniform'] is True + assert p["thread_count"] == 4 + assert p["value_count"] == 4 + assert p["per_value_constant_stride"] == [1, 1, 1, 1] + assert p["global_unique_strides"] == [1] + assert p["is_uniform"] is True def test_thread_stride_profile_single_thread_edge_case(): """Single-thread layouts have no adjacent-thread deltas.""" p = thread_stride_profile(Layout((1, 4), (0, 1))) - assert p['thread_count'] == 1 - assert p['value_count'] == 4 - assert p['per_value_deltas'] == [[], [], [], []] - assert p['global_unique_strides'] == [] - assert p['is_uniform'] is True - assert p['has_broadcast_lane'] is False + assert p["thread_count"] == 1 + assert p["value_count"] == 4 + assert p["per_value_deltas"] == [[], [], [], []] + assert p["global_unique_strides"] == [] + assert p["is_uniform"] is True + assert p["has_broadcast_lane"] is False diff --git a/tests/composed.py b/tests/composed.py index 13a0bdf..211e848 100644 --- a/tests/composed.py +++ b/tests/composed.py @@ -203,7 +203,9 @@ def test_compose_swizzled_layout_outer_preserves_exactness(): def test_logical_divide_forwards_through_composed_layout(): composed = compose(Layout(16, 2), compose(Swizzle(2, 0, 2), Layout(16, 1))) result = logical_divide(composed, 4) - expected = ComposedLayout(composed.outer, logical_divide(composed.inner, 4), preoffset=0) + expected = ComposedLayout( + composed.outer, logical_divide(composed.inner, 4), preoffset=0 + ) assert isinstance(result, ComposedLayout) assert result.outer == composed.outer @@ -214,7 +216,9 @@ def test_logical_divide_forwards_through_composed_layout(): def test_logical_product_forwards_through_composed_layout(): composed = compose(Layout(8, 2), compose(Swizzle(2, 0, 2), Layout(8, 1))) result = logical_product(composed, Layout(3, 1)) - expected = ComposedLayout(composed.outer, logical_product(composed.inner, Layout(3, 1))) + expected = ComposedLayout( + composed.outer, logical_product(composed.inner, Layout(3, 1)) + ) assert isinstance(result, ComposedLayout) _assert_pointwise_equal(result, expected) @@ -401,6 +405,7 @@ def test_generative_compose_with_preoffsets(): # Divide / product cascade on composed inputs # --------------------------------------------------------------------------- + def test_zipped_divide_forwards_through_composed_layout(): composed = compose( Layout((4, 4), (4, 1)), @@ -435,7 +440,9 @@ def test_zipped_product_forwards_through_composed_layout(): composed = compose(Layout(8, 2), compose(Swizzle(2, 0, 2), Layout(8, 1))) result = zipped_product(composed, Layout(3, 1)) assert isinstance(result, ComposedLayout) - expected = ComposedLayout(composed.outer, zipped_product(composed.inner, Layout(3, 1))) + expected = ComposedLayout( + composed.outer, zipped_product(composed.inner, Layout(3, 1)) + ) _assert_pointwise_equal(result, expected) @@ -443,7 +450,9 @@ def test_tiled_product_forwards_through_composed_layout(): composed = compose(Layout(8, 2), compose(Swizzle(2, 0, 2), Layout(8, 1))) result = tiled_product(composed, Layout(3, 1)) assert isinstance(result, ComposedLayout) - expected = ComposedLayout(composed.outer, tiled_product(composed.inner, Layout(3, 1))) + expected = ComposedLayout( + composed.outer, tiled_product(composed.inner, Layout(3, 1)) + ) _assert_pointwise_equal(result, expected) @@ -451,7 +460,9 @@ def test_flat_product_forwards_through_composed_layout(): composed = compose(Layout(8, 2), compose(Swizzle(2, 0, 2), Layout(8, 1))) result = flat_product(composed, Layout(3, 1)) assert isinstance(result, ComposedLayout) - expected = ComposedLayout(composed.outer, flat_product(composed.inner, Layout(3, 1))) + expected = ComposedLayout( + composed.outer, flat_product(composed.inner, Layout(3, 1)) + ) _assert_pointwise_equal(result, expected) @@ -459,6 +470,7 @@ def test_flat_product_forwards_through_composed_layout(): # Recursive composition and push-through # --------------------------------------------------------------------------- + def test_compose_composed_layout_as_outer_pushes_through(): """compose(ComposedLayout, Layout) pushes into the inner.""" base_inner = compose(Swizzle(2, 0, 2), Layout(16, 1)) @@ -515,6 +527,7 @@ def test_logical_divide_on_hierarchical_composed(): # Full-slice and multi-mode # --------------------------------------------------------------------------- + def test_full_slice_on_composed_layout_preserves_identity(): """Slicing with all-None returns the same composed layout with offset 0.""" composed = compose( @@ -542,6 +555,7 @@ def test_mode_on_composed_layout_mode1(): # Tensor.view() with ComposedLayout # --------------------------------------------------------------------------- + def test_tensor_view_with_composed_layout(): """Tensor.view() accepts a ComposedLayout.""" composed = compose(Layout(16, 2), compose(Swizzle(2, 0, 2), Layout(16, 1))) @@ -556,6 +570,7 @@ def test_tensor_view_with_composed_layout(): # Generic analysis functions with ComposedLayout # --------------------------------------------------------------------------- + def test_image_on_composed_layout(): composed = compose(Layout(16, 2), compose(Swizzle(2, 0, 2), Layout(16, 1))) img = image(composed) diff --git a/tests/conftest.py b/tests/conftest.py index 2041265..aadfe86 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,10 @@ # conftest.py — shared fixtures for tests/ + def pytest_addoption(parser): - parser.addoption("--draw", action="store_true", default=False, - help="Generate paper figures into tests/figures/") + parser.addoption( + "--draw", + action="store_true", + default=False, + help="Generate paper figures into tests/figures/", + ) diff --git a/tests/external.py b/tests/external.py index 1f9d010..3b07581 100644 --- a/tests/external.py +++ b/tests/external.py @@ -52,34 +52,34 @@ def _test_complement_properties(layout, cotarget=None): completed = Layout(layout, result) # Property 1: Lower-bound on codomain size of layout ++ complement - assert ( - cosize(completed) >= cotarget_size - ), f"cosize(completed)={cosize(completed)} < size(cotarget)={cotarget_size}" + assert cosize(completed) >= cotarget_size, ( + f"cosize(completed)={cosize(completed)} < size(cotarget)={cotarget_size}" + ) # Property 2: Upper-bound on codomain size of complement # Always use cosize(layout), regardless of rank layout_cosize = cosize(layout) - assert cosize(result) <= round_up( - cotarget_size, layout_cosize - ), f"cosize(result)={cosize(result)} > round_up({cotarget_size}, {layout_cosize})={round_up(cotarget_size, layout_cosize)}" + assert cosize(result) <= round_up(cotarget_size, layout_cosize), ( + f"cosize(result)={cosize(result)} > round_up({cotarget_size}, {layout_cosize})={round_up(cotarget_size, layout_cosize)}" + ) # Property 3: Result is ordered (CuTe starts at i=1) for i in range(1, size(result)): - assert result(i - 1) < result( - i - ), f"result is not ordered: result({i-1})={result(i-1)} >= result({i})={result(i)}" + assert result(i - 1) < result(i), ( + f"result is not ordered: result({i - 1})={result(i - 1)} >= result({i})={result(i)}" + ) # Property 4: Result is disjoint from layout (CuTe starts at i=1) for i in range(1, size(result)): for j in range(size(layout)): - assert result(i) != layout( - j - ), f"result and layout overlap: result({i})={result(i)} == layout({j})={layout(j)}" + assert result(i) != layout(j), ( + f"result and layout overlap: result({i})={result(i)} == layout({j})={layout(j)}" + ) # Other observations from CuTe - assert size(result) <= cosize( - result - ), f"size(result)={size(result)} > cosize(result)={cosize(result)}" + assert size(result) <= cosize(result), ( + f"size(result)={size(result)} > cosize(result)={cosize(result)}" + ) def test_complement_layout_1_0(): @@ -226,17 +226,20 @@ def _test_coalesce_properties(layout): coalesce_layout = coalesce(layout) # Property 1: Result depth is at most 1 (flattened) - assert depth(coalesce_layout) <= 1, \ + assert depth(coalesce_layout) <= 1, ( f"depth(coalesce_layout)={depth(coalesce_layout)} > 1" + ) # Property 2: Size is preserved - assert size(coalesce_layout) == size(layout), \ + assert size(coalesce_layout) == size(layout), ( f"size(coalesce_layout)={size(coalesce_layout)} != size(layout)={size(layout)}" + ) # Property 3: All indices map to the same offsets for i in range(size(layout)): - assert coalesce_layout(i) == layout(i), \ + assert coalesce_layout(i) == layout(i), ( f"coalesce_layout({i})={coalesce_layout(i)} != layout({i})={layout(i)}" + ) def test_coalesce_simple(): @@ -329,16 +332,18 @@ def _test_composition_properties(layout_a, layout_b): layout_r = compose(layout_a, layout_b) # Property 1: Layout B is compatible with layout R - assert compatible(layout_b.shape, layout_r.shape), \ + assert compatible(layout_b.shape, layout_r.shape), ( f"layoutB.shape={layout_b.shape} not compatible with layoutR.shape={layout_r.shape}" + ) # Property 2: R(c) = A(B(c)) for coordinates within A's domain a_size = size(layout_a) for c in range(size(layout_b)): bc = layout_b(c) if bc < a_size: - assert layout_r(c) == layout_a(bc), \ + assert layout_r(c) == layout_a(bc), ( f"layoutR({c})={layout_r(c)} != layoutA(layoutB({c}))={layout_a(bc)}" + ) def test_composition_simple(): @@ -436,21 +441,16 @@ def test_composition_multidimensional(): def test_composition_nested(): # Layout((8, 8)) o Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) _test_composition_properties( - Layout((8, 8)), - Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + Layout((8, 8)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) ) # Layout((8, 8), (8, 1)) o Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) _test_composition_properties( - Layout((8, 8), (8, 1)), - Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) ) # Layout(((4, 2),), ((1, 16),)) o Layout((4, 2), (2, 1)) - _test_composition_properties( - Layout(((4, 2),), ((1, 16),)), - Layout((4, 2), (2, 1)) - ) + _test_composition_properties(Layout(((4, 2),), ((1, 16),)), Layout((4, 2), (2, 1))) # Layout((2, 2), (2, 1)) o Layout((2, 2), (2, 1)) _test_composition_properties(Layout((2, 2), (2, 1)), Layout((2, 2), (2, 1))) @@ -460,14 +460,12 @@ def test_composition_nested(): # Layout((4, 8, 2), (2, 8, 1)) o Layout((2, 2, 2), (1, 8, 2)) _test_composition_properties( - Layout((4, 8, 2), (2, 8, 1)), - Layout((2, 2, 2), (1, 8, 2)) + Layout((4, 8, 2), (2, 8, 1)), Layout((2, 2, 2), (1, 8, 2)) ) # Layout((4, 8, 2), (2, 8, 1)) o Layout((4, 2, 2), (2, 8, 1)) _test_composition_properties( - Layout((4, 8, 2), (2, 8, 1)), - Layout((4, 2, 2), (2, 8, 1)) + Layout((4, 8, 2), (2, 8, 1)), Layout((4, 2, 2), (2, 8, 1)) ) @@ -484,16 +482,10 @@ def test_composition_dynamic(): _test_composition_properties(Layout(16, 2), Layout(4, 2)) # Layout((128, 24, 5), (1, 128, 3072)) o Layout(64, 2) - _test_composition_properties( - Layout((128, 24, 5), (1, 128, 3072)), - Layout(64, 2) - ) + _test_composition_properties(Layout((128, 24, 5), (1, 128, 3072)), Layout(64, 2)) # Layout((128, 24, 5), (1, 128, 3072)) o Layout(480, 32) - _test_composition_properties( - Layout((128, 24, 5), (1, 128, 3072)), - Layout(480, 32) - ) + _test_composition_properties(Layout((128, 24, 5), (1, 128, 3072)), Layout(480, 32)) def test_composition_cosize_larger(): @@ -582,12 +574,15 @@ def _test_logical_divide_properties(layout, tile): # For Layout tilers, verify the result rank is 2 (tile, rest) if isinstance(tile, Layout): # CuTe formula produces rank-2 result: (Tile, Rest) - assert rank(result) == 2, f"Expected rank 2 for Layout tiler, got {rank(result)}" + assert rank(result) == 2, ( + f"Expected rank 2 for Layout tiler, got {rank(result)}" + ) # The tile part (mode 0) should have size equal to size(tiler) tile_part = mode(result, 0) - assert size(tile_part) == size(tile_layout), \ + assert size(tile_part) == size(tile_layout), ( f"Tile part size {size(tile_part)} != tiler size {size(tile_layout)}" + ) def test_logical_divide_simple(): @@ -707,6 +702,7 @@ def _test_swizzle_2d(sw_layout): This tests that slicing a tensor with swizzled layout preserves correct indexing. """ from tensor_layouts import Tensor + tensor = Tensor(sw_layout) # Get dimensions @@ -752,7 +748,7 @@ def test_swizzle_3_0_3(): """ sw_layout = compose( Swizzle(3, 0, 3), - Layout((8, 8), (8, 1)) # 8x8 row-major + Layout((8, 8), (8, 1)), # 8x8 row-major ) _test_swizzle_2d(sw_layout) @@ -767,7 +763,7 @@ def test_swizzle_3_0_neg3(): """ sw_layout = compose( Swizzle(3, 0, -3), - Layout((8, 8), (8, 1)) # 8x8 row-major + Layout((8, 8), (8, 1)), # 8x8 row-major ) _test_swizzle_2d(sw_layout) @@ -788,11 +784,7 @@ def test_swizzle_2_1_3(): - shift=3: masks are 3 positions apart (bits [1,3) and [4,6)) """ sw_layout = compose( - Swizzle(2, 1, 3), - Layout( - ((2, 2, 2), (2, 2, 2)), - ((32, 2, 8), (4, 1, 16)) - ) + Swizzle(2, 1, 3), Layout(((2, 2, 2), (2, 2, 2)), ((32, 2, 8), (4, 1, 16))) ) _test_swizzle_2d(sw_layout) @@ -870,10 +862,7 @@ def test_swizzle_with_base(): def test_composed_layout_repr(): """Test swizzled Layout string representation.""" - sw_layout = compose( - Swizzle(3, 0, 3), - Layout((8, 8), (8, 1)) - ) + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) repr_str = repr(sw_layout) assert "Swizzle(3, 0, 3)" in repr_str @@ -903,10 +892,7 @@ def test_inner_product(): """Test inner_product (pycute test_int_tuple.py::test_inner_product).""" assert inner_product(2, 3) == 6 assert inner_product((1, 2), (3, 2)) == 7 - assert inner_product( - ((2, 3), 4), - ((2, 1), 2) - ) == 15 + assert inner_product(((2, 3), 4), ((2, 1), 2)) == 15 def test_prefix_product(): @@ -915,9 +901,11 @@ def test_prefix_product(): assert prefix_product((3, 2)) == (1, 3) assert prefix_product((3, 2, 4)) == (1, 3, 6) assert prefix_product(((2, 3), 4)) == ((1, 2), 6) - assert prefix_product( - ((2, 3), (2, 1, 2), (5, 2, 1)) - ) == ((1, 2), (6, 12, 12), (24, 120, 240)) + assert prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))) == ( + (1, 2), + (6, 12, 12), + (24, 120, 240), + ) def test_shape_div_pycute(): @@ -976,6 +964,7 @@ def test_coalesce_pycute(): Uses the pycute helper: verify size and functional equivalence. """ + def _check(layout): layoutR = coalesce(layout) assert size(layoutR) == size(layout) @@ -1007,6 +996,7 @@ def test_composition_pycute(): Uses the pycute helper: R(i) == A(B(i)) for all i. """ + def _check(A, B): R = compose(A, B) for i in range(size(R)): @@ -1044,7 +1034,9 @@ def _check(A, B): _check(Layout((4, 3), (3, 1)), Layout(6, 2)) _check(Layout((4, 3), (3, 1)), Layout((6, 2), (2, 1))) _check(Layout((8, 8)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32)))) - _check(Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32)))) + _check( + Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + ) # Layout applied from right with stride (from pycute, not in C++ tests) _check(Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), Layout(8, 4)) _check(Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1))) @@ -1067,8 +1059,7 @@ def _test_right_inverse(layout): inv_layout = right_inverse(layout) for i in range(size(inv_layout)): assert layout(inv_layout(i)) == i, ( - f"right_inverse({layout}): L(R({i})) = " - f"{layout(inv_layout(i))} != {i}" + f"right_inverse({layout}): L(R({i})) = {layout(inv_layout(i))} != {i}" ) @@ -1110,8 +1101,7 @@ def _test_left_inverse(layout): inv_layout = left_inverse(layout) for i in range(size(layout)): assert inv_layout(layout(i)) == i, ( - f"left_inverse({layout}): R(L({i})) = " - f"{inv_layout(layout(i))} != {i}" + f"left_inverse({layout}): R(L({i})) = {inv_layout(layout(i))} != {i}" ) @@ -1171,8 +1161,7 @@ def _test_logical_product_properties(layout_a, layout_b): # Property 2: First mode of R equals A R0 = mode(R, 0) assert R0.shape == layout_a.shape and R0.stride == layout_a.stride, ( - f"logical_product({layout_a}, {layout_b}): " - f"mode(R,0)={R0} != A={layout_a}" + f"logical_product({layout_a}, {layout_b}): mode(R,0)={R0} != A={layout_a}" ) # Property 3: B is compatible with second mode of R @@ -1213,13 +1202,9 @@ def test_logical_product_multidim_tile(): # Layout((2,4)) x Layout(3) _test_logical_product_properties(Layout((2, 4)), Layout(3)) # Layout((8,(2,2))) x Layout(4,2) - _test_logical_product_properties( - Layout((8, (2, 2)), (1, (8, 16))), Layout(4, 2) - ) + _test_logical_product_properties(Layout((8, (2, 2)), (1, (8, 16))), Layout(4, 2)) # Layout((2,2)) x Layout((3,3),(3,1)) - _test_logical_product_properties( - Layout((2, 2), (1, 2)), Layout((3, 3), (3, 1)) - ) + _test_logical_product_properties(Layout((2, 2), (1, 2)), Layout((3, 3), (3, 1))) def test_logical_product_large_stride(): @@ -1231,13 +1216,9 @@ def test_logical_product_large_stride(): def test_logical_product_nested(): """Logical product with nested/hierarchical layouts (C++ lines 175-213).""" # Layout(((4,2)),((1,16))) x Layout((4,4)) - _test_logical_product_properties( - Layout((4, 2), (1, 16)), Layout((4, 4)) - ) + _test_logical_product_properties(Layout((4, 2), (1, 16)), Layout((4, 4))) # Layout(((4,2)),((1,16))) x Layout((4,2),(2,1)) - _test_logical_product_properties( - Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1)) - ) + _test_logical_product_properties(Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1))) # Layout(((2,2),(2,2)),((1,4),(8,32))) x Layout((2,2),(1,2)) _test_logical_product_properties( Layout(((2, 2), (2, 2)), ((1, 4), (8, 32))), @@ -1249,9 +1230,7 @@ def test_logical_product_nested(): Layout((2, 2), (2, 1)), ) # Layout(((4,6)),((1,6))) x Layout(3,1) - _test_logical_product_properties( - Layout((4, 6), (1, 6)), Layout(3, 1) - ) + _test_logical_product_properties(Layout((4, 6), (1, 6)), Layout(3, 1)) ## Left Inverse edge cases (C++ inverse_left.cpp) @@ -1281,9 +1260,7 @@ def _test_left_inverse_cpp(layout): # Fast path: stride-0 modes make injectivity impossible if _has_broadcast(layout): - assert size(inv_layout) >= 1, ( - f"left_inverse({layout}): empty result" - ) + assert size(inv_layout) >= 1, f"left_inverse({layout}): empty result" return # No broadcast modes — check injectivity and contiguity via enumeration @@ -1293,13 +1270,10 @@ def _test_left_inverse_cpp(layout): ili = inv_layout(li) lili = layout(ili) assert lili == li, ( - f"left_inverse({layout}): " - f"L(inv(L({i})))={lili} != L({i})={li}" + f"left_inverse({layout}): L(inv(L({i})))={lili} != L({i})={li}" ) else: - assert size(inv_layout) >= 1, ( - f"left_inverse({layout}): empty result" - ) + assert size(inv_layout) >= 1, f"left_inverse({layout}): empty result" def test_left_inverse_cpp_broadcast(): @@ -1339,10 +1313,12 @@ def test_left_inverse_cpp_deep_nested(): Shape: (((( 32, 4), 1), ( 32, 2)), 4), 1, (2, 2), 2) Stride: ((((262144, 4), 0), ( 0, 1)), 8388608), 0, (2, 16), 32) """ - _test_left_inverse_cpp(Layout( - ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), - ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), - )) + _test_left_inverse_cpp( + Layout( + ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), + ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), + ) + ) ## Right Inverse edge cases (C++ inverse_right.cpp) @@ -1357,10 +1333,7 @@ def _test_right_inverse_cpp(layout): inv_layout = right_inverse(layout) for i in range(size(inv_layout)): li = layout(inv_layout(i)) - assert li == i, ( - f"right_inverse({layout}): " - f"L(R({i}))={li} != {i}" - ) + assert li == i, f"right_inverse({layout}): L(R({i}))={li} != {i}" def test_right_inverse_cpp_4d(): @@ -1389,10 +1362,12 @@ def test_right_inverse_cpp_broadcast_middle(): def test_right_inverse_cpp_deep_nested(): """Right-inverse of deeply nested layout (C++ inverse_right.cpp line 210).""" - _test_right_inverse_cpp(Layout( - ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), - ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), - )) + _test_right_inverse_cpp( + Layout( + ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), + ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), + ) + ) ## Composition edge case (C++ composition.cpp line 241-246) @@ -1403,9 +1378,7 @@ def test_composition_transposed_strides(): Layout((4,3)) o Layout((4,3),(3,1)) -- col-major transposed. """ - _test_composition_properties( - Layout((4, 3)), Layout((4, 3), (3, 1)) - ) + _test_composition_properties(Layout((4, 3)), Layout((4, 3), (3, 1))) ## Complement edge case (pycute Python test_complement.py) @@ -1566,7 +1539,9 @@ def test_slice_and_offset(): # Verify: sublayout(i) + offset == original(fixed_coord, i) for i in range(size(sub)): - assert sub(i) + offset == layout(2, i), f"i={i}: {sub(i) + offset} != {layout(2, i)}" + assert sub(i) + offset == layout(2, i), ( + f"i={i}: {sub(i) + offset} != {layout(2, i)}" + ) ## zipped_product diff --git a/tests/layouts.py b/tests/layouts.py index c0e02c8..8fce80f 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -24,13 +24,13 @@ from tensor_layouts import * from tensor_layouts.analysis import functionally_equal +from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN from tensor_layouts.layout_utils import ( make_layout_like, make_ordered_layout, tile_mma_grid, tile_to_shape, ) -from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN # These tests roughly follow: @@ -984,7 +984,9 @@ def test_compose_two_2d(): # B(0,0)=0, B(1,0)=1, B(0,1)=2, B(1,1)=3 # A(0)=0, A(1)=1, A(2)=2, A(3)=3 # So compose gives same result as B indexing into first 4 elements of A - assert compose(Layout((4, 4), (1, 4)), Layout((2, 2), (1, 2))) == Layout((2, 2), (1, 2)) + assert compose(Layout((4, 4), (1, 4)), Layout((2, 2), (1, 2))) == Layout( + (2, 2), (1, 2) + ) def test_compose_functional_equivalence(): @@ -1111,7 +1113,9 @@ def test_logical_divide_nested_tuple_tiler_recurses_mode_by_mode(): assert result == expected assert result == Layout((((2, 1), (3, 1)), (4, 2)), (((1, 0), (2, 0)), (6, 24))) - assert sorted(result(i) for i in range(size(result))) == sorted(a(i) for i in range(size(a))) + assert sorted(result(i) for i in range(size(result))) == sorted( + a(i) for i in range(size(a)) + ) assert functionally_equal(result, expected) @@ -1919,6 +1923,7 @@ def test_is_layout(): assert is_affine(ComposedLayout(Layout(4, 1), Layout(4, 1))) is False # Dispatches through .layout (Tensor) from tensor_layouts import Tensor + assert is_affine(Tensor(Layout(4, 1))) is True assert is_affine(Tensor(ComposedLayout(Layout(4, 1), Layout(4, 1)))) is False assert is_layout(4) is False @@ -2038,7 +2043,9 @@ def test_make_ordered_layout_scalar_rejects_invalid_order(): def test_tile_mma_grid_exact_multiple_expands_c_panel(): atom = SM80_16x8x16_F16F16F16F16_TN - grid, tile_shape = tile_mma_grid(atom, Layout(1, 1), matrix="C", tile_mnk=(32, 16, 16)) + grid, tile_shape = tile_mma_grid( + atom, Layout(1, 1), matrix="C", tile_mnk=(32, 16, 16) + ) assert tile_shape == (32, 16, 16) assert max(r for r, _ in grid) + 1 == 32 @@ -2047,7 +2054,9 @@ def test_tile_mma_grid_exact_multiple_expands_c_panel(): def test_tile_mma_grid_exact_multiple_expands_a_panel_along_k(): atom = SM80_16x8x16_F16F16F16F16_TN - grid, tile_shape = tile_mma_grid(atom, Layout(1, 1), matrix="A", tile_mnk=(16, 8, 32)) + grid, tile_shape = tile_mma_grid( + atom, Layout(1, 1), matrix="A", tile_mnk=(16, 8, 32) + ) assert tile_shape == (16, 8, 32) assert max(r for r, _ in grid) + 1 == 16 @@ -2432,11 +2441,11 @@ def test_upcast_known_copy_atoms(): derived from the CUTLASS C++ copy_traits_sm75.hpp source. """ from tensor_layouts.atoms_nv import ( - SM75_U32x1_LDSM_N, - SM75_U32x4_LDSM_N, SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, SM75_U16x8_LDSM_T, + SM75_U32x1_LDSM_N, + SM75_U32x4_LDSM_N, ) cases = [ diff --git a/tests/oracle_amd.py b/tests/oracle_amd.py index dfbc7d0..2cb7530 100644 --- a/tests/oracle_amd.py +++ b/tests/oracle_amd.py @@ -40,46 +40,51 @@ """ import tempfile -import pytest -from tensor_layouts import Layout, size, rank, depth, mode, cosize +import pytest +from tensor_layouts import cosize, depth, Layout, mode, rank, size +from tensor_layouts.layout_utils import make_ordered_layout, product_each, tile_to_shape from tensor_layouts.layouts import ( - compose, complement, flatten, coalesce, - logical_divide, logical_product, - left_inverse, right_inverse, - idx2crd, crd2idx, -) -from tensor_layouts.layout_utils import ( - make_ordered_layout, tile_to_shape, product_each, + coalesce, + complement, + compose, + crd2idx, + flatten, + idx2crd, + left_inverse, + logical_divide, + logical_product, + right_inverse, ) from tensor_layouts.atoms_amd import * # Try to import the AMD matrix instruction calculator. try: from amd_matrix_instruction_calculator import matrix_calculator + HAS_CALCULATOR = True except ImportError: try: import matrix_calculator + HAS_CALCULATOR = True except ImportError: HAS_CALCULATOR = False requires_calculator = pytest.mark.skipif( - not HAS_CALCULATOR, - reason="amd_matrix_instruction_calculator not available" + not HAS_CALCULATOR, reason="amd_matrix_instruction_calculator not available" ) # Try to import visualization module (requires matplotlib). try: - from tensor_layouts.viz import draw_tv_layout, draw_mma_layout, _compute_tv_mapping + from tensor_layouts.viz import _compute_tv_mapping, draw_mma_layout, draw_tv_layout + HAS_VIZ = True except ImportError: HAS_VIZ = False requires_viz = pytest.mark.skipif( - not HAS_VIZ, - reason="tensor_layouts.viz not available (needs matplotlib)" + not HAS_VIZ, reason="tensor_layouts.viz not available (needs matplotlib)" ) @@ -133,9 +138,12 @@ # Helpers # ============================================================================= + def _num_threads(layout): """Number of threads from a TV layout's thread dimension.""" - return size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + return ( + size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + ) def _num_values(layout): @@ -166,10 +174,10 @@ def get_calculator_d_mapping(arch: str, instruction: str, m: int, n: int): try: info = matrix_calculator.get_instruction_info(arch, instruction) mapping = {} - num_vgprs = info['num_output_regs'] + num_vgprs = info["num_output_regs"] for lane in range(64): for vgpr in range(num_vgprs): - row, col = info['get_output'](lane, vgpr) + row, col = info["get_output"](lane, vgpr) mapping[(lane, vgpr)] = (row, col) return mapping except (AttributeError, KeyError): @@ -206,9 +214,8 @@ def validate_c_layout(atom, arch: str): f"ref=({ref_row},{ref_col})" ) - assert not errors, ( - f"{atom.name}: {len(errors)} mismatches:\n" + - "\n".join(errors[:20]) + assert not errors, f"{atom.name}: {len(errors)} mismatches:\n" + "\n".join( + errors[:20] ) @@ -216,14 +223,17 @@ def validate_c_layout(atom, arch: str): # CDNA1 (gfx908) FP16 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_f32f16f16(): validate_c_layout(CDNA_32x32x8_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x16_f32f16f16(): validate_c_layout(CDNA_16x16x16_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_4x4x4_f32f16f16(): validate_c_layout(CDNA_4x4x4_F32F16F16_MFMA, "cdna1") @@ -233,10 +243,12 @@ def test_oracle_cdna_4x4x4_f32f16f16(): # CDNA1 non-k-reduction variants # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x4_f32f16f16(): validate_c_layout(CDNA_32x32x4_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x4_f32f16f16(): validate_c_layout(CDNA_16x16x4_F32F16F16_MFMA, "cdna1") @@ -246,10 +258,12 @@ def test_oracle_cdna_16x16x4_f32f16f16(): # CDNA2 (gfx90a) BF16_1K atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_f32bf16bf16_1k(): validate_c_layout(CDNA_32x32x8_F32BF16BF16_1K_MFMA, "cdna2") + @requires_calculator def test_oracle_cdna_16x16x16_f32bf16bf16_1k(): validate_c_layout(CDNA_16x16x16_F32BF16BF16_1K_MFMA, "cdna2") @@ -259,10 +273,12 @@ def test_oracle_cdna_16x16x16_f32bf16bf16_1k(): # CDNA1/2 BF16 (original, non-1K) atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x4_f32bf16bf16(): validate_c_layout(CDNA_32x32x4_F32BF16BF16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x8_f32bf16bf16(): validate_c_layout(CDNA_16x16x8_F32BF16BF16_MFMA, "cdna1") @@ -272,10 +288,12 @@ def test_oracle_cdna_16x16x8_f32bf16bf16(): # CDNA1/2 INT8 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_i32i8i8(): validate_c_layout(CDNA_32x32x8_I32I8I8_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x16_i32i8i8(): validate_c_layout(CDNA_16x16x16_I32I8I8_MFMA, "cdna1") @@ -285,10 +303,12 @@ def test_oracle_cdna_16x16x16_i32i8i8(): # CDNA1/2 FP32 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x2_f32f32f32(): validate_c_layout(CDNA_32x32x2_F32F32F32_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x4_f32f32f32(): validate_c_layout(CDNA_16x16x4_F32F32F32_MFMA, "cdna1") @@ -298,6 +318,7 @@ def test_oracle_cdna_16x16x4_f32f32f32(): # CDNA2/3 FP64 atom # ============================================================================= + @requires_calculator def test_oracle_cdna_16x16x4_f64f64f64(): validate_c_layout(CDNA_16x16x4_F64F64F64_MFMA, "cdna2") @@ -307,18 +328,22 @@ def test_oracle_cdna_16x16x4_f64f64f64(): # CDNA3 (gfx942) enhanced atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3_32x32x16_i32i8i8(): validate_c_layout(CDNA3_32x32x16_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_i32i8i8(): validate_c_layout(CDNA3_16x16x32_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x4_f32xf32xf32(): validate_c_layout(CDNA3_32x32x4_F32XF32XF32_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x8_f32xf32xf32(): validate_c_layout(CDNA3_16x16x8_F32XF32XF32_MFMA, "cdna3") @@ -328,34 +353,42 @@ def test_oracle_cdna3_16x16x8_f32xf32xf32(): # CDNA3 FP8 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3_32x32x16_f32f8f8(): validate_c_layout(CDNA3_32x32x16_F32F8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32f8f8(): validate_c_layout(CDNA3_16x16x32_F32F8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32bf8bf8(): validate_c_layout(CDNA3_32x32x16_F32BF8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32bf8bf8(): validate_c_layout(CDNA3_16x16x32_F32BF8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32f8bf8(): validate_c_layout(CDNA3_32x32x16_F32F8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32f8bf8(): validate_c_layout(CDNA3_16x16x32_F32F8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32bf8f8(): validate_c_layout(CDNA3_32x32x16_F32BF8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32bf8f8(): validate_c_layout(CDNA3_16x16x32_F32BF8F8_MFMA, "cdna3") @@ -365,26 +398,32 @@ def test_oracle_cdna3_16x16x32_f32bf8f8(): # CDNA3+ (gfx950) double-rate atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3p_32x32x16_f32f16f16(): validate_c_layout(CDNA3P_32x32x16_F32F16F16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x32_f32f16f16(): validate_c_layout(CDNA3P_16x16x32_F32F16F16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_32x32x16_f32bf16bf16(): validate_c_layout(CDNA3P_32x32x16_F32BF16BF16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x32_f32bf16bf16(): validate_c_layout(CDNA3P_16x16x32_F32BF16BF16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_32x32x32_i32i8i8(): validate_c_layout(CDNA3P_32x32x32_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x64_i32i8i8(): validate_c_layout(CDNA3P_16x16x64_I32I8I8_MFMA, "cdna3") @@ -397,6 +436,7 @@ def test_oracle_cdna3p_16x16x64_i32i8i8(): # These tests verify algebraic properties of the layouts themselves, # independent of the AMD calculator. They always run. + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestMFMAStructural: """Structural invariants that must hold for any valid MFMA atom.""" @@ -412,118 +452,127 @@ def test_c_layout_covers_all_elements(self, atom): for t in range(num_t): for v in range(num_v): offset = c(t, v) - assert 0 <= offset < m * n, \ - f"{atom.name}: offset {offset} out of range [0, {m*n})" - assert offset not in seen, \ + assert 0 <= offset < m * n, ( + f"{atom.name}: offset {offset} out of range [0, {m * n})" + ) + assert offset not in seen, ( f"{atom.name}: duplicate offset {offset} at t={t}, v={v}" + ) seen.add(offset) - assert len(seen) == m * n, \ - f"{atom.name}: covers {len(seen)} elements, expected {m*n}" + assert len(seen) == m * n, ( + f"{atom.name}: covers {len(seen)} elements, expected {m * n}" + ) def test_c_layout_thread_count(self, atom): """Thread dimension has exactly 64 elements (one wavefront).""" c = atom.c_layout - assert _num_threads(c) == 64, \ + assert _num_threads(c) == 64, ( f"{atom.name}: {_num_threads(c)} threads, expected 64" + ) def test_a_layout_thread_count(self, atom): """A layout thread dimension has exactly 64 elements.""" a = atom.a_layout - assert _num_threads(a) == 64, \ + assert _num_threads(a) == 64, ( f"{atom.name}: A has {_num_threads(a)} threads, expected 64" + ) def test_b_layout_thread_count(self, atom): """B layout thread dimension has exactly 64 elements.""" b = atom.b_layout - assert _num_threads(b) == 64, \ + assert _num_threads(b) == 64, ( f"{atom.name}: B has {_num_threads(b)} threads, expected 64" + ) def test_a_layout_broadcast(self, atom): """A layout broadcasts across blocks (stride-0 in block dimension).""" a = atom.a_layout if isinstance(a.stride, tuple) and isinstance(a.stride[0], tuple): blk_stride = a.stride[0][0] - assert blk_stride == 0, \ + assert blk_stride == 0, ( f"{atom.name}: A layout block stride is {blk_stride}, expected 0" + ) def test_b_layout_broadcast(self, atom): """B layout broadcasts across blocks (stride-0 in block dimension).""" b = atom.b_layout if isinstance(b.stride, tuple) and isinstance(b.stride[0], tuple): blk_stride = b.stride[0][0] - assert blk_stride == 0, \ + assert blk_stride == 0, ( f"{atom.name}: B layout block stride is {blk_stride}, expected 0" + ) def test_a_layout_cosize_bounded(self, atom): """A layout codomain is bounded by thread_count * values_per_thread.""" a = atom.a_layout # cosize is max_offset + 1; for broadcast layouts this can exceed M*K # but must be bounded by the underlying coordinate space - assert cosize(a) >= 1, \ - f"{atom.name}: A cosize must be positive" + assert cosize(a) >= 1, f"{atom.name}: A cosize must be positive" def test_b_layout_cosize_bounded(self, atom): """B layout codomain is bounded by thread_count * values_per_thread.""" b = atom.b_layout - assert cosize(b) >= 1, \ - f"{atom.name}: B cosize must be positive" + assert cosize(b) >= 1, f"{atom.name}: B cosize must be positive" def test_c_layout_cosize_equals_mn(self, atom): """C layout codomain spans exactly M x N (since it's a bijection).""" m, n, k = atom.shape_mnk c = atom.c_layout - assert cosize(c) == m * n, \ - f"{atom.name}: C cosize {cosize(c)} != M*N={m*n}" + assert cosize(c) == m * n, f"{atom.name}: C cosize {cosize(c)} != M*N={m * n}" def test_thr_id_is_none(self, atom): """AMD MFMA atoms use identity thread mapping (thr_id is None).""" - assert atom.thr_id is None, \ + assert atom.thr_id is None, ( f"{atom.name}: thr_id should be None, got {atom.thr_id}" + ) def test_c_layout_rank_is_2(self, atom): """C layout is rank-2: (thread, value).""" c = atom.c_layout - assert rank(c) == 2, \ - f"{atom.name}: C rank {rank(c)}, expected 2" + assert rank(c) == 2, f"{atom.name}: C rank {rank(c)}, expected 2" def test_a_layout_rank_is_2(self, atom): """A layout is rank-2: (thread, value).""" a = atom.a_layout - assert rank(a) == 2, \ - f"{atom.name}: A rank {rank(a)}, expected 2" + assert rank(a) == 2, f"{atom.name}: A rank {rank(a)}, expected 2" def test_b_layout_rank_is_2(self, atom): """B layout is rank-2: (thread, value).""" b = atom.b_layout - assert rank(b) == 2, \ - f"{atom.name}: B rank {rank(b)}, expected 2" + assert rank(b) == 2, f"{atom.name}: B rank {rank(b)}, expected 2" def test_layout_sizes_match_shape_mnk(self, atom): """Layout domain sizes are consistent with M, N, K.""" m, n, k = atom.shape_mnk a, b, c = atom.a_layout, atom.b_layout, atom.c_layout - assert size(c) == m * n, \ - f"{atom.name}: C size {size(c)} != M*N={m*n}" + assert size(c) == m * n, f"{atom.name}: C size {size(c)} != M*N={m * n}" # A and B sizes include the broadcast dimension, so size >= M*K / N*K # but since broadcast replicates the same data, size == 64 * values_per_thread - assert size(a) == 64 * _num_values(a), \ + assert size(a) == 64 * _num_values(a), ( f"{atom.name}: A size {size(a)} != 64 * {_num_values(a)}" - assert size(b) == 64 * _num_values(b), \ + ) + assert size(b) == 64 * _num_values(b), ( f"{atom.name}: B size {size(b)} != 64 * {_num_values(b)}" + ) # ============================================================================= # Layout algebra tests (run without the calculator) # ============================================================================= + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestLayoutAlgebra: """Test layout algebra operations on real AMD atom layouts.""" def test_size_rank_depth_mode(self, atom): """Exercise size(), rank(), depth(), mode() on all three layouts.""" - for layout_name, layout in [("C", atom.c_layout), ("A", atom.a_layout), ("B", atom.b_layout)]: + for layout_name, layout in [ + ("C", atom.c_layout), + ("A", atom.a_layout), + ("B", atom.b_layout), + ]: s = size(layout) r = rank(layout) d = depth(layout) @@ -534,8 +583,9 @@ def test_size_rank_depth_mode(self, atom): # mode(layout, 0) is the thread dimension thr_mode = mode(layout, 0) val_mode = mode(layout, 1) - assert size(thr_mode) * size(val_mode) == s, \ + assert size(thr_mode) * size(val_mode) == s, ( f"{atom.name} {layout_name}: mode sizes don't multiply to total" + ) def test_flatten_preserves_mapping(self, atom): """flatten(c_layout) produces the same offsets for all flat indices.""" @@ -543,16 +593,18 @@ def test_flatten_preserves_mapping(self, atom): c_flat = flatten(c) # Flattened layout should produce same offsets when indexed linearly for i in range(size(c)): - assert c_flat(i) == c(i), \ + assert c_flat(i) == c(i), ( f"{atom.name}: flatten mismatch at {i}: {c_flat(i)} != {c(i)}" + ) def test_coalesce_preserves_mapping(self, atom): """coalesce(c_layout) produces the same offsets.""" c = atom.c_layout c_coal = coalesce(c) for i in range(size(c)): - assert c_coal(i) == c(i), \ + assert c_coal(i) == c(i), ( f"{atom.name}: coalesce mismatch at {i}: {c_coal(i)} != {c(i)}" + ) def test_compose_with_identity(self, atom): """compose(L, identity) == L for all indices.""" @@ -560,8 +612,7 @@ def test_compose_with_identity(self, atom): identity = Layout(size(c)) # col-major identity composed = compose(c, identity) for i in range(size(c)): - assert composed(i) == c(i), \ - f"{atom.name}: compose(C, id) mismatch at {i}" + assert composed(i) == c(i), f"{atom.name}: compose(C, id) mismatch at {i}" def test_complement_c_layout(self, atom): """complement of flattened C layout produces valid ordered disjoint layout.""" @@ -571,13 +622,15 @@ def test_complement_c_layout(self, atom): comp = complement(c_flat) # complement must be ordered: comp(i-1) < comp(i) for i >= 1 for i in range(1, size(comp)): - assert comp(i - 1) < comp(i), \ - f"{atom.name}: complement not ordered at {i}: {comp(i-1)} >= {comp(i)}" + assert comp(i - 1) < comp(i), ( + f"{atom.name}: complement not ordered at {i}: {comp(i - 1)} >= {comp(i)}" + ) # complement must be disjoint from layout for i >= 1 c_offsets = {c_flat(j) for j in range(size(c_flat))} for i in range(1, size(comp)): - assert comp(i) not in c_offsets, \ + assert comp(i) not in c_offsets, ( f"{atom.name}: complement({i})={comp(i)} overlaps with layout" + ) def test_left_inverse_c_layout(self, atom): """left_inverse(C) composed with C gives identity for flat indices.""" @@ -588,8 +641,9 @@ def test_left_inverse_c_layout(self, atom): for i in range(size(c_flat)): offset = c_flat(i) roundtrip = linv(offset) - assert roundtrip == i, \ + assert roundtrip == i, ( f"{atom.name}: left_inverse roundtrip at {i}: {roundtrip} != {i}" + ) def test_right_inverse_c_layout(self, atom): """C composed with right_inverse(C) gives identity for offsets in range.""" @@ -600,8 +654,9 @@ def test_right_inverse_c_layout(self, atom): for i in range(size(c_flat)): offset = c_flat(i) roundtrip = c_flat(rinv(offset)) - assert roundtrip == offset, \ + assert roundtrip == offset, ( f"{atom.name}: right_inverse roundtrip at offset {offset}: {roundtrip} != {offset}" + ) def test_logical_divide_c_layout(self, atom): """logical_divide factors C layout into (tile, rest).""" @@ -614,8 +669,9 @@ def test_logical_divide_c_layout(self, atom): c_flat_thr = flatten(mode(c, 0)) divided = logical_divide(c_flat_thr, tiler) # The divided layout must cover the same total size - assert size(divided) == size(c_flat_thr), \ + assert size(divided) == size(c_flat_thr), ( f"{atom.name}: logical_divide changed size: {size(divided)} != {size(c_flat_thr)}" + ) def test_logical_product(self, atom): """logical_product replicates a pattern across positions.""" @@ -625,12 +681,14 @@ def test_logical_product(self, atom): replicator = Layout(2, size(c_flat)) product = logical_product(c_flat, replicator) # Size should be original * 2 - assert size(product) == size(c_flat) * 2, \ + assert size(product) == size(c_flat) * 2, ( f"{atom.name}: logical_product size {size(product)} != {size(c_flat) * 2}" + ) # First half should match original for i in range(size(c_flat)): - assert product(i) == c_flat(i), \ + assert product(i) == c_flat(i), ( f"{atom.name}: logical_product first-half mismatch at {i}" + ) def test_idx2crd_crd2idx_roundtrip(self, atom): """idx2crd and crd2idx are inverses on the thread dimension shape.""" @@ -639,8 +697,9 @@ def test_idx2crd_crd2idx_roundtrip(self, atom): for i in range(size(thr_shape)): crd = idx2crd(i, thr_shape) idx = crd2idx(crd, thr_shape) - assert idx == i, \ + assert idx == i, ( f"{atom.name}: idx2crd/crd2idx roundtrip at {i}: {idx} != {i}" + ) def test_idx2crd_crd2idx_roundtrip_val(self, atom): """idx2crd/crd2idx roundtrip on value dimension.""" @@ -649,8 +708,9 @@ def test_idx2crd_crd2idx_roundtrip_val(self, atom): for i in range(size(val_shape)): crd = idx2crd(i, val_shape) idx = crd2idx(crd, val_shape) - assert idx == i, \ + assert idx == i, ( f"{atom.name}: val idx2crd/crd2idx roundtrip at {i}: {idx} != {i}" + ) def test_flatten_is_idempotent(self, atom): """flatten(flatten(L)) == flatten(L).""" @@ -658,8 +718,7 @@ def test_flatten_is_idempotent(self, atom): once = flatten(c) twice = flatten(once) for i in range(size(c)): - assert once(i) == twice(i), \ - f"{atom.name}: flatten not idempotent at {i}" + assert once(i) == twice(i), f"{atom.name}: flatten not idempotent at {i}" def test_coalesce_is_idempotent(self, atom): """coalesce(coalesce(L)) == coalesce(L).""" @@ -667,16 +726,14 @@ def test_coalesce_is_idempotent(self, atom): once = coalesce(c) twice = coalesce(once) for i in range(size(c)): - assert once(i) == twice(i), \ - f"{atom.name}: coalesce not idempotent at {i}" + assert once(i) == twice(i), f"{atom.name}: coalesce not idempotent at {i}" def test_flatten_then_coalesce(self, atom): """flatten then coalesce produces same mapping.""" c = atom.c_layout fc = coalesce(flatten(c)) for i in range(size(c)): - assert fc(i) == c(i), \ - f"{atom.name}: flatten+coalesce mismatch at {i}" + assert fc(i) == c(i), f"{atom.name}: flatten+coalesce mismatch at {i}" def test_compose_chain(self, atom): """compose(compose(L, A), B) == compose(L, compose(A, B)) (associativity).""" @@ -689,8 +746,9 @@ def test_compose_chain(self, atom): lhs = compose(compose(c_flat, a), b) rhs = compose(c_flat, compose(a, b)) for i in range(size(b)): - assert lhs(i) == rhs(i), \ + assert lhs(i) == rhs(i), ( f"{atom.name}: compose associativity failed at {i}: {lhs(i)} != {rhs(i)}" + ) def test_make_ordered_layout_flat_c_shape(self, atom): """make_ordered_layout on flattened C shape produces ordered strides.""" @@ -698,12 +756,14 @@ def test_make_ordered_layout_flat_c_shape(self, atom): c_flat = flatten(c) ordered = make_ordered_layout(c_flat.shape) # Same size - assert size(ordered) == size(c), \ + assert size(ordered) == size(c), ( f"{atom.name}: make_ordered_layout changed size" + ) # Ordered: strides should be increasing (column-major order) for i in range(1, size(ordered)): - assert ordered(i) > ordered(i - 1), \ + assert ordered(i) > ordered(i - 1), ( f"{atom.name}: make_ordered_layout not ordered at {i}" + ) # ============================================================================= @@ -730,16 +790,17 @@ def test_compute_tv_mapping_c(self, atom): """_compute_tv_mapping on c_layout covers every cell of the M x N grid.""" m, n, k = atom.shape_mnk c = atom.c_layout - tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, - col_major=True) + tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, col_major=True) # Every (row, col) in [0,M) x [0,N) should have an entry for row in range(m): for col in range(n): - assert (row, col) in tv_map, \ + assert (row, col) in tv_map, ( f"{atom.name}: C tv_map missing ({row},{col})" + ) phys_t, v_idx, logical_t = tv_map[(row, col)] - assert 0 <= phys_t < 64, \ + assert 0 <= phys_t < 64, ( f"{atom.name}: C invalid thread {phys_t} at ({row},{col})" + ) def test_compute_tv_mapping_a(self, atom): """_compute_tv_mapping on a_layout produces valid entries.""" @@ -756,8 +817,9 @@ def test_compute_tv_mapping_a(self, atom): for t in range(64): for v in range(num_v): offset = a(t, v) - assert 0 <= offset < a_cosize, \ + assert 0 <= offset < a_cosize, ( f"{atom.name}: A offset {offset} out of range [0, {a_cosize})" + ) def test_compute_tv_mapping_b(self, atom): """_compute_tv_mapping on b_layout produces valid entries.""" @@ -768,15 +830,15 @@ def test_compute_tv_mapping_b(self, atom): for t in range(64): for v in range(num_v): offset = b(t, v) - assert 0 <= offset < b_cosize, \ + assert 0 <= offset < b_cosize, ( f"{atom.name}: B offset {offset} out of range [0, {b_cosize})" + ) def test_compute_tv_mapping_c_threads_match(self, atom): """Thread IDs from tv_mapping match direct layout evaluation.""" m, n, k = atom.shape_mnk c = atom.c_layout - tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, - col_major=True) + tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, col_major=True) # Rebuild the forward map and compare num_v = _num_values(c) for t in range(64): @@ -784,41 +846,56 @@ def test_compute_tv_mapping_c_threads_match(self, atom): offset = c(t, v) row = offset % m col = offset // m - assert (row, col) in tv_map, \ + assert (row, col) in tv_map, ( f"{atom.name}: ({row},{col}) missing from tv_map" + ) phys_t, v_idx, logical_t = tv_map[(row, col)] - assert phys_t == t, \ + assert phys_t == t, ( f"{atom.name}: thread mismatch at ({row},{col}): {phys_t} != {t}" - assert v_idx == v, \ + ) + assert v_idx == v, ( f"{atom.name}: value mismatch at ({row},{col}): {v_idx} != {v}" + ) def test_draw_tv_layout_smoke(self, atom): """draw_tv_layout runs without error (output to tempfile).""" m, n, k = atom.shape_mnk c = atom.c_layout with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_tv_layout(c, filename=f.name, - grid_shape=(m, n), colorize=True) + draw_tv_layout(c, filename=f.name, grid_shape=(m, n), colorize=True) def test_draw_mma_layout_smoke(self, atom): """draw_mma_layout runs without error.""" m, n, k = atom.shape_mnk with tempfile.NamedTemporaryFile(suffix=".png") as f: if atom.name == "CDNA_4x4x4_F32F16F16_MFMA": - with pytest.raises(ValueError, match=r"A .*panel shape .*out of bounds"): - draw_mma_layout(atom.a_layout, atom.b_layout, atom.c_layout, - filename=f.name, tile_mnk=(m, n, k), - main_title=atom.name) + with pytest.raises( + ValueError, match=r"A .*panel shape .*out of bounds" + ): + draw_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + filename=f.name, + tile_mnk=(m, n, k), + main_title=atom.name, + ) else: - draw_mma_layout(atom.a_layout, atom.b_layout, atom.c_layout, - filename=f.name, tile_mnk=(m, n, k), - main_title=atom.name) + draw_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + filename=f.name, + tile_mnk=(m, n, k), + main_title=atom.name, + ) # ============================================================================= # Layout utils tests # ============================================================================= + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestLayoutUtils: """Test layout_utils functions on AMD atom layouts.""" @@ -842,5 +919,6 @@ def test_tile_to_shape_c(self, atom): # Tile to 2x the original shape target = (size(c.shape[0]) * 2, size(c.shape[1])) tiled = tile_to_shape(c, target) - assert size(tiled) == size(c) * 2, \ + assert size(tiled) == size(c) * 2, ( f"{atom.name}: tile_to_shape wrong size: {size(tiled)} != {size(c) * 2}" + ) diff --git a/tests/oracle_cute_cpp.py b/tests/oracle_cute_cpp.py index cdc92f9..3a3a135 100644 --- a/tests/oracle_cute_cpp.py +++ b/tests/oracle_cute_cpp.py @@ -31,10 +31,10 @@ from __future__ import annotations import importlib.util -from pathlib import Path import shutil import subprocess import tempfile +from pathlib import Path import pytest @@ -232,7 +232,9 @@ "logical_divide_unit_tile": lambda: logical_divide(Layout(2, 5), 1), "logical_divide_exact_division_unit_rest": lambda: logical_divide(Layout(4, 3), 4), "logical_divide_oversize_tile": lambda: logical_divide(Layout(2, 5), 4), - "logical_divide_exact_tuple": lambda: logical_divide(Layout((2, 3), (1, 2)), (2, 3)), + "logical_divide_exact_tuple": lambda: logical_divide( + Layout((2, 3), (1, 2)), (2, 3) + ), "logical_divide_nested_tuple_tiler": lambda: logical_divide( Layout(((2, 3), 8), ((1, 2), 6)), ((2, 3), 4), @@ -271,19 +273,25 @@ PYTHON_POINTWISE_CASES = { "compose_double_swizzle_offsets": lambda: ",".join( - str(compose(Swizzle(1, 0, 3), compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))))(i)) + str( + compose( + Swizzle(1, 0, 3), compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) + )(i) + ) for i in range(size(Layout((8, 8), (8, 1)))) ), "compose_outer_layout_swizzled_offsets": lambda: ",".join( - str(compose(Layout((4, 4), (4, 1)), compose(Swizzle(3, 0, 3), Layout(16, 1)))(i)) + str( + compose(Layout((4, 4), (4, 1)), compose(Swizzle(3, 0, 3), Layout(16, 1)))(i) + ) for i in range(16) ), "compose_layout_swizzle_exact_offsets": lambda: ",".join( - str(compose(Layout((4, 4), (4, 1)), Swizzle(2, 1, 3))(i)) - for i in range(16) + str(compose(Layout((4, 4), (4, 1)), Swizzle(2, 1, 3))(i)) for i in range(16) ), "compose_slice_row": lambda: ( - lambda sliced: f"{sliced[1]}|" + ",".join(str(sliced[0](i)) for i in range(size(sliced[0]))) + lambda sliced: f"{sliced[1]}|" + + ",".join(str(sliced[0](i)) for i in range(size(sliced[0]))) )( slice_and_offset( (2, None), @@ -294,18 +302,30 @@ ) ), "compose_layout_zero_preoffset_composed_offsets": lambda: ",".join( - str(compose(Layout(32, 2), ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0))(i)) + str( + compose( + Layout(32, 2), + ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0), + )(i) + ) for i in range(32) ), "compose_layout_nonzero_preoffset_composed_offsets": lambda: ",".join( - str(compose(Layout(32, 2), ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=4))(i)) + str( + compose( + Layout(32, 2), + ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=4), + )(i) + ) for i in range(32) ), "compose_recursive_chain_offsets": lambda: ",".join( str( compose( Layout(16, 3), - ComposedLayout(Layout(16, 2), compose(Swizzle(2, 0, 2), Layout(16, 1)), preoffset=0), + ComposedLayout( + Layout(16, 2), compose(Swizzle(2, 0, 2), Layout(16, 1)), preoffset=0 + ), )(i) ) for i in range(16) @@ -359,15 +379,40 @@ ) ), "right_inverse_swizzled_composed_offsets": lambda: ",".join( - str(right_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0))(i)) - for i in range(size(right_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0)))) + str( + right_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0))( + i + ) + ) + for i in range( + size( + right_inverse( + ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0) + ) + ) + ) ), "left_inverse_swizzled_composed_offsets": lambda: ",".join( - str(left_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0))(i)) - for i in range(size(left_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0)))) + str( + left_inverse(ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0))( + i + ) + ) + for i in range( + size( + left_inverse( + ComposedLayout(Swizzle(2, 1, 3), Layout(32, 1), preoffset=0) + ) + ) + ) ), "tensor_composed_values": lambda: ",".join( - str(Tensor(ComposedLayout(Swizzle(2, 1, 3), Layout(16, 1), preoffset=4), data=list(range(256)))[i]) + str( + Tensor( + ComposedLayout(Swizzle(2, 1, 3), Layout(16, 1), preoffset=4), + data=list(range(256)), + )[i] + ) for i in range(16) ), } @@ -424,7 +469,9 @@ def cute_cpp_oracle() -> dict[str, str]: include_dirs = _candidate_include_dirs() if not any((path / "cute/layout.hpp").exists() for path in include_dirs): - pytest.skip("CuTe headers not found; install CUTLASS headers to run this oracle") + pytest.skip( + "CuTe headers not found; install CUTLASS headers to run this oracle" + ) if not any((path / "cuda/std/utility").exists() for path in include_dirs): pytest.skip("CUDA C++ headers not found; CuTe C++ oracle cannot compile") @@ -437,12 +484,16 @@ def cute_cpp_oracle() -> dict[str, str]: for include_dir in include_dirs: command.extend(["-I", str(include_dir)]) - compile_result = subprocess.run(command, capture_output=True, text=True, check=False) + compile_result = subprocess.run( + command, capture_output=True, text=True, check=False + ) if compile_result.returncode != 0: stderr = compile_result.stderr.strip() or "unknown compiler error" pytest.skip(f"failed to compile CuTe oracle: {stderr}") - run_result = subprocess.run([str(exe_path)], capture_output=True, text=True, check=False) + run_result = subprocess.run( + [str(exe_path)], capture_output=True, text=True, check=False + ) if run_result.returncode != 0: stderr = run_result.stderr.strip() or "unknown runtime error" pytest.skip(f"failed to run CuTe oracle: {stderr}") @@ -459,7 +510,9 @@ def cute_cpp_oracle() -> dict[str, str]: @pytest.mark.parametrize("case_name", sorted(PYTHON_CASES)) def test_cute_cpp_oracle(case_name, cute_cpp_oracle): result = PYTHON_CASES[case_name]() - assert _normalize_layout_repr(str(result)) == _normalize_layout_repr(cute_cpp_oracle[case_name]) + assert _normalize_layout_repr(str(result)) == _normalize_layout_repr( + cute_cpp_oracle[case_name] + ) @pytest.mark.parametrize("case_name", sorted(PYTHON_POINTWISE_CASES)) diff --git a/tests/oracle_nv.py b/tests/oracle_nv.py index 290d714..980f9c2 100644 --- a/tests/oracle_nv.py +++ b/tests/oracle_nv.py @@ -31,22 +31,24 @@ """ from tensor_layouts import * -from tensor_layouts.analysis import is_injective, is_bijective, is_contiguous -from tensor_layouts.layout_utils import make_ordered_layout, tile_to_shape - import pytest +from tensor_layouts.analysis import is_bijective, is_contiguous, is_injective +from tensor_layouts.layout_utils import make_ordered_layout, tile_to_shape # Import pycute reference — skip all tests if unavailable. # Note: there is an unrelated "pycute" package on PyPI (statistics library). # We need NVIDIA's pycute from the CUTLASS source tree. try: import pycute - if not hasattr(pycute, 'Layout'): + + if not hasattr(pycute, "Layout"): pycute = None except ImportError: pycute = None -pytestmark = pytest.mark.skipif(pycute is None, reason="pycute (NVIDIA CUTLASS) not available") +pytestmark = pytest.mark.skipif( + pycute is None, reason="pycute (NVIDIA CUTLASS) not available" +) ############################################################################### @@ -116,33 +118,53 @@ def layouts_functionally_equal(our, ref, domain_size): # Standard test layouts: (shape, stride) pairs covering many patterns LAYOUT_CORPUS = [ # Trivial - (1, 0), (1, 1), + (1, 0), + (1, 1), # 1D - (4, 1), (4, 2), (8, 1), (8, 2), (12, 1), (12, 3), + (4, 1), + (4, 2), + (8, 1), + (8, 2), + (12, 1), + (12, 3), # Zero stride (broadcast) - (4, 0), (8, 0), + (4, 0), + (8, 0), # 2D col-major - ((2, 4), (1, 2)), ((4, 3), (1, 4)), ((8, 4), (1, 8)), + ((2, 4), (1, 2)), + ((4, 3), (1, 4)), + ((8, 4), (1, 8)), # 2D row-major - ((2, 4), (4, 1)), ((4, 3), (3, 1)), ((8, 4), (4, 1)), + ((2, 4), (4, 1)), + ((4, 3), (3, 1)), + ((8, 4), (4, 1)), # 2D with gaps - ((2, 4), (1, 4)), ((2, 4), (1, 6)), ((4, 2), (1, 10)), ((4, 2), (1, 16)), + ((2, 4), (1, 4)), + ((2, 4), (1, 6)), + ((4, 2), (1, 10)), + ((4, 2), (1, 16)), # 2D with broadcast - ((2, 4), (0, 2)), ((4, 2), (2, 0)), + ((2, 4), (0, 2)), + ((4, 2), (2, 0)), # 3D - ((2, 4, 6), (1, 2, 8)), ((2, 4, 6), (4, 1, 8)), - ((2, 3, 4), (1, 2, 6)), ((2, 4, 8), (8, 1, 64)), + ((2, 4, 6), (1, 2, 8)), + ((2, 4, 6), (4, 1, 8)), + ((2, 3, 4), (1, 2, 6)), + ((2, 4, 8), (8, 1, 64)), ((2, 4, 6), (24, 6, 1)), # 3D with broadcast - ((2, 4, 8), (8, 1, 0)), ((2, 4, 3), (1, 2, 0)), + ((2, 4, 8), (8, 1, 0)), + ((2, 4, 3), (1, 2, 0)), # Nested (hierarchical) (((2, 2), (2, 2)), ((1, 4), (8, 32))), ((2, (3, 4)), (3, (1, 6))), (((4, 2),), ((1, 16),)), # Auto-stride (col-major) - ((2, 4), None), ((4, 3), None), ((2, 4, 6), None), ((2, 3, 4), None), + ((2, 4), None), + ((4, 3), None), + ((2, 4, 6), None), + ((2, 3, 4), None), ((8, 8), None), - # ===== FROM C++ TESTS ===== # C++ inverse / complement tests: broadcast shapes (((3, 7),), ((0, 0),)), @@ -165,7 +187,6 @@ def layouts_functionally_equal(our, ref, domain_size): ((4, 10), (1, 10)), # C++ composition tests: transposed strides ((4, 3), (3, 1)), - # ===== EDGE CASES ===== # All-zero strides (pure broadcast) ((2, 3, 4), (0, 0, 0)), @@ -308,11 +329,22 @@ def test_oracle_composition(): # Composition pairs: (A_shape, A_stride, B_shape, B_stride) composition_pairs = [ # Simple - (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 1, 0), (1, 1, 1, 1), - (4, 1, 4, 1), (4, 2, 4, 1), (4, 1, 4, 2), (4, 0, 4, 1), - (4, 1, 4, 0), (1, 0, 4, 1), (4, 1, 1, 0), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + (4, 1, 4, 1), + (4, 2, 4, 1), + (4, 1, 4, 2), + (4, 0, 4, 1), + (4, 1, 4, 0), + (1, 0, 4, 1), + (4, 1, 1, 0), # Partial - (4, 1, 2, 1), (4, 2, 2, 1), (4, 1, 2, 2), (4, 2, 2, 2), + (4, 1, 2, 1), + (4, 2, 2, 1), + (4, 1, 2, 2), + (4, 2, 2, 2), # Multi-dim A, 1D B ((4, 3), (1, 4), 12, 1), ((4, 3), (1, 4), 6, 1), @@ -380,16 +412,32 @@ def test_oracle_shape_div(): """Cross-validate shape_div() against pycute.""" test_cases = [ # (shape, divisor) -- only cases valid for pycute (a%b==0 or b%a==0 at each level) - ((3, 4), 1), ((3, 4), 3), ((3, 4), 6), - ((3, 4), 12), ((3, 4), 36), - ((4, 3), 2), ((4, 3), 4), ((4, 3), 12), - ((6, 2), 2), ((6, 2), 3), ((6, 2), 6), ((6, 2), 12), + ((3, 4), 1), + ((3, 4), 3), + ((3, 4), 6), + ((3, 4), 12), + ((3, 4), 36), + ((4, 3), 2), + ((4, 3), 4), + ((4, 3), 12), + ((6, 2), 2), + ((6, 2), 3), + ((6, 2), 6), + ((6, 2), 12), # Nested - (((3, 4), 6), 1), (((3, 4), 6), 3), (((3, 4), 6), 12), - (((3, 4), 6), 36), (((3, 4), 6), 72), - ((6, (3, 4)), 6), ((6, (3, 4)), 36), + (((3, 4), 6), 1), + (((3, 4), 6), 3), + (((3, 4), 6), 12), + (((3, 4), 6), 36), + (((3, 4), 6), 72), + ((6, (3, 4)), 6), + ((6, (3, 4)), 36), # Scalars - (12, 1), (12, 3), (12, 4), (12, 6), (12, 12), + (12, 1), + (12, 3), + (12, 4), + (12, 6), + (12, 12), ] for shape, divisor in test_cases: @@ -405,11 +453,8 @@ def test_oracle_shape_div(): assert shapes_equal( ours_r if isinstance(ours_r, int) else ours_r, - ref_r if isinstance(ref_r, int) else ref_r - ), ( - f"shape_div({shape}, {divisor}): " - f"ours={ours_r} vs pycute={ref_r}" - ) + ref_r if isinstance(ref_r, int) else ref_r, + ), f"shape_div({shape}, {divisor}): ours={ours_r} vs pycute={ref_r}" def test_oracle_prefix_product(): @@ -431,10 +476,8 @@ def test_oracle_prefix_product(): assert shapes_equal( ours_r if isinstance(ours_r, int) else ours_r, - ref_r if isinstance(ref_r, int) else ref_r - ), ( - f"prefix_product({shape}): ours={ours_r} vs pycute={ref_r}" - ) + ref_r if isinstance(ref_r, int) else ref_r, + ), f"prefix_product({shape}): ours={ours_r} vs pycute={ref_r}" def test_oracle_inner_product(): @@ -528,12 +571,22 @@ def test_oracle_logical_divide(): """Cross-validate logical_divide() with Layout tilers against pycute.""" # (layout_shape, layout_stride, tiler_shape, tiler_stride) divide_cases = [ - (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 1, 0), (1, 1, 1, 1), - (6, 1, 2, 1), (6, 1, 2, 3), (6, 2, 2, 1), (6, 2, 2, 3), - (6, 1, (2, 3), (3, 1)), (6, 2, (2, 3), (3, 1)), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + (6, 1, 2, 1), + (6, 1, 2, 3), + (6, 2, 2, 1), + (6, 2, 2, 3), + (6, 1, (2, 3), (3, 1)), + (6, 2, (2, 3), (3, 1)), (32, 1, 2, 8), - (12, 1, 4, 1), (12, 1, 6, 1), (12, 2, 4, 1), - (48, 1, 32, 1), (96, 1, 32, 2), + (12, 1, 4, 1), + (12, 1, 6, 1), + (12, 2, 4, 1), + (48, 1, 32, 1), + (96, 1, 32, 2), ] for item in divide_cases: @@ -562,16 +615,26 @@ def test_oracle_logical_product(): cute_cpp_unsupported = {(4, 2, 3, 1)} product_cases = [ # (A_shape, A_stride, B_shape, B_stride) - (4, 1, 3, 1), (4, 2, 3, 1), (4, 1, 3, 2), + (4, 1, 3, 1), + (4, 2, 3, 1), + (4, 1, 3, 2), ((2, 4), (1, 2), 3, 1), - (8, 1, 4, 1), (8, 2, 4, 1), + (8, 1, 4, 1), + (8, 2, 4, 1), # === C++ test cases === # Trivial - (1, 0, 1, 0), (1, 1, 1, 0), (1, 0, 1, 1), (1, 1, 1, 1), + (1, 0, 1, 0), + (1, 1, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 1), # Broadcast - (3, 1, 4, 0), (3, 0, 4, 1), (3, 0, 4, 0), (3, 2, 4, 1), + (3, 1, 4, 0), + (3, 0, 4, 1), + (3, 0, 4, 0), + (3, 2, 4, 1), # 1D - (3, 1, (2, 4), None), ((2, 4), None, 3, 1), + (3, 1, (2, 4), None), + ((2, 4), None, 3, 1), # Hierarchical ((8, (2, 2)), None, 4, 2), ((2, 2), None, (3, 3), (3, 1)), @@ -670,8 +733,7 @@ def test_exhaustive_complement_disjointness(): la = layout(a) cb = c(b) assert (la != cb) or (la == 0 and cb == 0), ( - f"Disjointness violated for {layout}: " - f"L({a})={la} == C({b})={cb}" + f"Disjointness violated for {layout}: L({a})={la} == C({b})={cb}" ) @@ -682,7 +744,7 @@ def test_exhaustive_complement_ordering(): for i in range(1, size(c)): assert c(i - 1) < c(i), ( f"Complement not ordered for {layout}: " - f"C({i-1})={c(i-1)} >= C({i})={c(i)}" + f"C({i - 1})={c(i - 1)} >= C({i})={c(i)}" ) @@ -703,9 +765,7 @@ def test_exhaustive_coalesce_reduces_depth(): """Verify coalesce produces depth <= 1.""" for layout in _generate_small_layouts(): coal = coalesce(layout) - assert depth(coal) <= 1, ( - f"coalesce({layout}) has depth {depth(coal)} > 1" - ) + assert depth(coal) <= 1, f"coalesce({layout}) has depth {depth(coal)} > 1" def test_exhaustive_right_inverse_identity(): @@ -729,8 +789,6 @@ def _is_injective(layout): return True - - def test_exhaustive_left_inverse_identity(): """Verify R(L(i)) == i for small injective, contiguous layouts. @@ -794,8 +852,13 @@ def test_exhaustive_shape_div_mod_complementary(): so not all divisors of size(s) are valid. We try each and skip failures. """ shapes = [ - (2, 3), (3, 4), (2, 2, 3), (4, 3), - (6, 2), (2, 6), (3, 2, 4), + (2, 3), + (3, 4), + (2, 2, 3), + (4, 3), + (6, 2), + (2, 6), + (3, 2, 4), ] tested = 0 @@ -843,8 +906,7 @@ def test_exhaustive_logical_divide_preserves_mapping_when_domain_size_is_unchang for i in range(size(layout)): assert result(i) == layout(i), ( - f"logical_divide({layout}, {t})({i}) = " - f"{result(i)} != {layout(i)}" + f"logical_divide({layout}, {t})({i}) = {result(i)} != {layout(i)}" ) tested += 1 @@ -862,17 +924,13 @@ def test_exhaustive_inverse_roundtrip(): # right_inverse property: L(R(i)) == i (works for all layouts) for i in range(size(rinv)): - assert layout(rinv(i)) == i, ( - f"right_inverse({layout}): L(R({i})) != {i}" - ) + assert layout(rinv(i)) == i, f"right_inverse({layout}): L(R({i})) != {i}" # left_inverse property: R(L(i)) == i (only for contiguous layouts) if is_contiguous(layout): linv = left_inverse(layout) for i in range(size(layout)): - assert linv(layout(i)) == i, ( - f"left_inverse({layout}): R(L({i})) != {i}" - ) + assert linv(layout(i)) == i, f"left_inverse({layout}): R(L({i})) != {i}" ############################################################################### @@ -899,8 +957,8 @@ def test_oracle_tuple_max(): def test_oracle_elem_scale(): """Cross-validate elem_scale against pycute.""" cases = [ - (3, 4), # int x int - (2, (3, 4)), # int x tuple + (3, 4), # int x int + (2, (3, 4)), # int x tuple ((2, 3), (4, 5)), # tuple x tuple (1, (2, 3, 4)), # int x tuple ] @@ -1107,8 +1165,7 @@ def test_oracle_filter(): # Functional equivalence for i in range(size(ours_f)): assert ours_f(i) == ref_f(i), ( - f"filter({shape}:{stride})({i}): " - f"ours={ours_f(i)} vs pycute={ref_f(i)}" + f"filter({shape}:{stride})({i}): ours={ours_f(i)} vs pycute={ref_f(i)}" ) @@ -1366,10 +1423,16 @@ def test_exhaustive_filter_idempotent(): def test_exhaustive_blocked_product_size(): """Verify blocked_product(A, B) has size(A) * size(B).""" layouts_1d = [ - Layout(2, 1), Layout(3, 1), Layout(4, 1), Layout(2, 2), Layout(4, 2), + Layout(2, 1), + Layout(3, 1), + Layout(4, 1), + Layout(2, 2), + Layout(4, 2), ] layouts_2d = [ - Layout((2, 2), (1, 2)), Layout((2, 3), (3, 1)), Layout((3, 2)), + Layout((2, 2), (1, 2)), + Layout((2, 3), (3, 1)), + Layout((3, 2)), ] all_layouts = layouts_1d + layouts_2d @@ -1378,8 +1441,7 @@ def test_exhaustive_blocked_product_size(): result = blocked_product(a, b) expected_size = size(a) * size(b) assert size(result) == expected_size, ( - f"blocked_product({a}, {b}): " - f"size={size(result)} != {expected_size}" + f"blocked_product({a}, {b}): size={size(result)} != {expected_size}" ) @@ -1390,10 +1452,14 @@ def test_exhaustive_blocked_product_covers_offsets(): blocks: B's offsets are shifted by cosize(A) * i for each copy. """ compact_layouts = [ - Layout(2, 1), Layout(4, 1), Layout((2, 2)), + Layout(2, 1), + Layout(4, 1), + Layout((2, 2)), ] tiling_layouts = [ - Layout(2, 1), Layout(3, 1), Layout((2, 2)), + Layout(2, 1), + Layout(3, 1), + Layout((2, 2)), ] for a in compact_layouts: @@ -1436,7 +1502,11 @@ def test_exhaustive_flat_divide_preserves_mapping(): # For multi-mode layouts, only test tilers that divide evenly # within the first mode to avoid cross-mode reordering issues if r > 1: - first_mode_size = layout.shape[0] if isinstance(layout.shape[0], int) else size(Layout(layout.shape[0])) + first_mode_size = ( + layout.shape[0] + if isinstance(layout.shape[0], int) + else size(Layout(layout.shape[0])) + ) if t > first_mode_size or first_mode_size % t != 0: continue try: @@ -1446,8 +1516,7 @@ def test_exhaustive_flat_divide_preserves_mapping(): for i in range(s): assert result(i) == layout(i), ( - f"flat_divide({layout}, {t})({i}) = " - f"{result(i)} != {layout(i)}" + f"flat_divide({layout}, {t})({i}) = {result(i)} != {layout(i)}" ) tested += 1 @@ -1555,7 +1624,6 @@ def test_product_each_matches_pycute_size(): assert result == expected, f"product_each({shape}) = {result} != {expected}" - @pytest.mark.skipif(pycute is None, reason="pycute not installed") def test_oracle_idx2crd(): shapes = [ @@ -1677,7 +1745,10 @@ def test_oracle_left_inverse_padded(): if __name__ == "__main__": import traceback - test_funcs = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] + + test_funcs = [ + v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v) + ] passed = 0 failed = 0 errors = [] diff --git a/tests/oracle_rdna.py b/tests/oracle_rdna.py index 4004633..afd255f 100644 --- a/tests/oracle_rdna.py +++ b/tests/oracle_rdna.py @@ -7,15 +7,14 @@ """ import pytest - -from tensor_layouts import size, rank, cosize -from tensor_layouts.atoms_amd import ( - MMA_ATOMS_RDNA3, MMA_ATOMS_RDNA4, -) +from tensor_layouts import cosize, rank, size +from tensor_layouts.atoms_amd import MMA_ATOMS_RDNA3, MMA_ATOMS_RDNA4 def _num_threads(layout): - return size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + return ( + size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + ) def _num_values(layout): @@ -40,20 +39,24 @@ def test_c_layout_covers_all_elements(self, atom): for t in range(num_t): for v in range(num_v): offset = c(t, v) - assert 0 <= offset < m * n, \ - f"{atom.name}: offset {offset} out of range [0, {m*n})" - assert offset not in seen, \ + assert 0 <= offset < m * n, ( + f"{atom.name}: offset {offset} out of range [0, {m * n})" + ) + assert offset not in seen, ( f"{atom.name}: duplicate offset {offset} at t={t}, v={v}" + ) seen.add(offset) - assert len(seen) == m * n, \ - f"{atom.name}: covers {len(seen)} elements, expected {m*n}" + assert len(seen) == m * n, ( + f"{atom.name}: covers {len(seen)} elements, expected {m * n}" + ) def test_c_layout_thread_count(self, atom): """Thread dimension has exactly 32 elements (wave32).""" c = atom.c_layout - assert _num_threads(c) == 32, \ + assert _num_threads(c) == 32, ( f"{atom.name}: {_num_threads(c)} threads, expected 32" + ) def test_a_layout_thread_count(self, atom): a = atom.a_layout diff --git a/tests/oracle_xe.py b/tests/oracle_xe.py index aa3799d..6fd4f2b 100644 --- a/tests/oracle_xe.py +++ b/tests/oracle_xe.py @@ -7,13 +7,14 @@ """ import pytest - -from tensor_layouts import size, rank, cosize +from tensor_layouts import cosize, rank, size from tensor_layouts.atoms_xe import * def _num_threads(layout): - return size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + return ( + size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + ) def _num_values(layout): @@ -38,10 +39,12 @@ def test_c_layout_covers_all_elements(self, atom): for t in range(num_t): for v in range(num_v): offset = c(t, v) - assert 0 <= offset < m * n, \ - f"{atom.name}: offset {offset} out of range [0, {m*n})" - assert offset not in seen, \ + assert 0 <= offset < m * n, ( + f"{atom.name}: offset {offset} out of range [0, {m * n})" + ) + assert offset not in seen, ( f"{atom.name}: duplicate offset {offset} at t={t}, v={v}" + ) seen.add(offset) assert len(seen) == m * n @@ -66,8 +69,9 @@ def test_a_layout_broadcast(self, atom): """A layout broadcasts across subgroup (stride 0 on thread dim).""" a = atom.a_layout t_stride = a.stride[0] if isinstance(a.stride, tuple) else a.stride - assert t_stride == 0, \ + assert t_stride == 0, ( f"{atom.name}: A thread stride is {t_stride}, expected 0 (broadcast)" + ) def test_c_layout_cosize_equals_mn(self, atom): m, n, k = atom.shape_mnk @@ -102,10 +106,10 @@ def test_c_layout_column_ownership(self, atom): offset = c(t, v) col = offset // m # col-major: col = offset // M cols.add(col) - assert len(cols) == 1, \ + assert len(cols) == 1, ( f"{atom.name}: thread {t} touches columns {cols}, expected 1" - assert cols.pop() == t, \ - f"{atom.name}: thread {t} owns wrong column" + ) + assert cols.pop() == t, f"{atom.name}: thread {t} owns wrong column" def test_b_layout_column_ownership(self, atom): """Each thread owns exactly one N-position of B.""" diff --git a/tests/paper_examples.py b/tests/paper_examples.py index 74e5391..462f472 100644 --- a/tests/paper_examples.py +++ b/tests/paper_examples.py @@ -155,8 +155,13 @@ def _selected_cells(layout, tiler): def _expand_slice_component(spec, shape): if is_tuple(shape): if not is_tuple(spec): - raise ValueError(f"Nested shape {shape!r} requires nested spec, got {spec!r}") - parts = [_expand_slice_component(subspec, subshape) for subspec, subshape in zip(spec, shape)] + raise ValueError( + f"Nested shape {shape!r} requires nested spec, got {spec!r}" + ) + parts = [ + _expand_slice_component(subspec, subshape) + for subspec, subshape in zip(spec, shape) + ] return {coords for coords in product(*parts)} if spec is None: return set(range(shape)) @@ -169,7 +174,9 @@ def _infer_slice_component(selected, shape): selected = list(selected) if is_tuple(shape): if not selected: - raise ValueError(f"Cannot infer slice for empty selection in shape {shape!r}") + raise ValueError( + f"Cannot infer slice for empty selection in shape {shape!r}" + ) spec = tuple( _infer_slice_component([coord[i] for coord in selected], subshape) for i, subshape in enumerate(shape) @@ -187,7 +194,9 @@ def _infer_slice_component(selected, shape): return None if values == list(range(values[0], values[-1] + 1)): return slice(values[0], values[-1] + 1) - raise ValueError(f"Selection {values!r} is not representable as a flat slice of {shape}") + raise ValueError( + f"Selection {values!r} is not representable as a flat slice of {shape}" + ) def _common_subvector_slice_spec(layout, common): @@ -199,7 +208,10 @@ def _selected_coords_from_slice_spec(layout, slice_spec): row_spec, col_spec = slice_spec row_coords = _expand_slice_component(row_spec, mode(layout.shape, 0)) col_coords = _expand_slice_component(col_spec, mode(layout.shape, 1)) - return {(row_coord, col_coord) for row_coord, col_coord in product(row_coords, col_coords)} + return { + (row_coord, col_coord) + for row_coord, col_coord in product(row_coords, col_coords) + } def _fig1_tensor(layout=None): @@ -239,6 +251,7 @@ def _figure_path(name): FIG1_COLOR_LAYOUT_4X2 = Layout((4, 2), (1, 0)) FIG1_COLOR_LAYOUT_2X4 = Layout((2, (2, 2)), (1, (0, 2))) + def test_fig1_rank3_tensor(viz): """Figure 1, row 1: a 2×2×2 tensor with Shape (2,2,2) : Stride (2,1,4).""" L = FIG1_BASE_LAYOUT @@ -280,7 +293,9 @@ def test_fig1_fold_mode2_into_mode0(viz): assert C == Layout((4, 2), (2, 1)) assert functionally_equal(L, C) tensor = _fig1_tensor(C) - assert [tensor[row, col] for row in range(4) for col in range(2)] == list("abcdefgh") + assert [tensor[row, col] for row in range(4) for col in range(2)] == list( + "abcdefgh" + ) if not viz: return viz.draw_layout( @@ -498,7 +513,10 @@ def test_fig4a_identity_coordinate(viz): layout = Layout((4, 8), (1, 4)) tensor = Tensor( layout, - data=[_format_coord(idx2crd(offset, layout.shape)) for offset in range(size(layout))], + data=[ + _format_coord(idx2crd(offset, layout.shape)) + for offset in range(size(layout)) + ], ) assert tensor[0, 0] == "(0,0)" assert tensor[3, 7] == "(3,7)" @@ -518,7 +536,9 @@ def test_fig4b_transposed_block_coordinate(viz): layout = Layout((4, (2, 4)), (4, (24, 1))) tensor = Tensor( layout, - data=[_format_coord(idx2crd(offset, (4, 10))) for offset in range(cosize(layout))], + data=[ + _format_coord(idx2crd(offset, (4, 10))) for offset in range(cosize(layout)) + ], ) assert [tensor[0, j] for j in range(8)] == [ "(0,0)", @@ -570,6 +590,7 @@ def test_fig4c_binary_swizzle(viz): FIG4_BINARY_SWIZZLE_4X4 = compose(Swizzle(2, 0, 2), Layout((4, 4), (1, 4))) + def test_table1_integer_linear_form(): """Table 1: integer strides are columns of a 1×n Z-matrix.""" L = Layout(((2, 2), (4, 2)), ((1, 8), (2, 16))) @@ -1127,6 +1148,7 @@ def test_s3_3_3_apparent_violation_fail_right_raises(): THR_VAL_LAYOUT_C = Layout(((4, 8), 2), ((16, 1), 8)) + def test_fig6_thread_value_partitioning(viz): """Figure 6: inverse display of ThrValLayoutC over the 8×8 C-matrix.""" t_shape = mode(THR_VAL_LAYOUT_C.shape, 0) @@ -1155,6 +1177,7 @@ def test_fig6_thread_value_partitioning(viz): # §3.3.4 — Application: Partitioning (Table 4) # ============================================================================= + def test_table4_colmajor(): """Table 4: ColMajor (8,8):(1,8) composed with TV layout.""" data = Layout((8, 8), (1, 8)) @@ -1516,10 +1539,7 @@ def test_fig8a_two_element_common_subvector(viz): if not viz: return viz.draw_composite( - [ - ( src, { "slice_spec": src_spec } ), - ( dst, { "slice_spec": dst_spec } ) - ], + [(src, {"slice_spec": src_spec}), (dst, {"slice_spec": dst_spec})], filename=_figure_path("fig8a_two_element_common_subvector"), titles=["Source", "Destination"], main_title="Fig 8a: A 2-element common subvector", @@ -1546,10 +1566,7 @@ def test_fig8b_four_element_common_subvector(viz): if not viz: return viz.draw_composite( - [ - ( src, { "slice_spec": src_spec } ), - ( dst, { "slice_spec": dst_spec } ) - ], + [(src, {"slice_spec": src_spec}), (dst, {"slice_spec": dst_spec})], filename=_figure_path("fig8b_four_element_common_subvector"), titles=["Source", "Destination"], main_title="Fig 8b: A 4-element common subvector", @@ -1716,6 +1733,7 @@ def test_table7_complement_broadcast_even_stride(): def test_table7_complement_coordinate_identity_proxy(): """Table 7: (4,8):(e0,e1) is complemented by the proxy (1,1):(4e0,4e1).""" + def complement_proxy(coord): block_row, block_col = coord return (4 * block_row, 4 * block_col) @@ -1725,13 +1743,16 @@ def complement_proxy(coord): # 4-step coordinate basis directly, rather than asserting arbitrary finite # extensions are disjoint from the original 4×8 block. sample_shape = (2, 2) - sample_image = [complement_proxy(idx2crd(i, sample_shape)) for i in range(size(sample_shape))] + sample_image = [ + complement_proxy(idx2crd(i, sample_shape)) for i in range(size(sample_shape)) + ] assert sample_image == [(0, 0), (4, 0), (0, 4), (4, 4)] assert [crd2idx(coord, (8, 8)) for coord in sample_image] == [0, 4, 32, 36] def test_table7_complement_coordinate_blocked_proxy(): """Table 7: (4,(4,2)):(e1,(e0,12e1)) is complemented by (1,(3,1)):(4e0,(4e1,24e1)).""" + def coord_layout(coord): row, (col, block) = coord return (col, row + 12 * block) @@ -1741,10 +1762,14 @@ def complement_proxy(coord): return (4 * block_row, 4 * block_col + 24 * block_group) original_shape = (4, (4, 2)) - original_image = {coord_layout(idx2crd(i, original_shape)) for i in range(size(original_shape))} + original_image = { + coord_layout(idx2crd(i, original_shape)) for i in range(size(original_shape)) + } paper_shape = (1, (3, 1)) - assert [complement_proxy(idx2crd(i, paper_shape)) for i in range(size(paper_shape))] == [ + assert [ + complement_proxy(idx2crd(i, paper_shape)) for i in range(size(paper_shape)) + ] == [ (0, 0), (0, 4), (0, 8), @@ -1754,7 +1779,9 @@ def complement_proxy(coord): # larger compatible extension to check the paper's "disjoint and ordered" # claim concretely. sample_shape = (2, (3, 2)) - sample_image = [complement_proxy(idx2crd(i, sample_shape)) for i in range(size(sample_shape))] + sample_image = [ + complement_proxy(idx2crd(i, sample_shape)) for i in range(size(sample_shape)) + ] for coord in sample_image[1:]: assert coord not in original_image assert [crd2idx(coord, (8, 48)) for coord in sample_image] == sorted( diff --git a/tests/tensor.py b/tests/tensor.py index 801c13a..317ff7a 100644 --- a/tests/tensor.py +++ b/tests/tensor.py @@ -35,22 +35,21 @@ """ import pytest - from tensor_layouts import ( - Layout, - Swizzle, - compose, + coalesce, complement, + compose, + cosize, + flatten, + Layout, logical_divide, logical_product, + mode, rank, size, - cosize, - mode, - flatten, - coalesce, + Swizzle, + Tensor, ) -from tensor_layouts import Tensor # ============================================================================ @@ -190,7 +189,7 @@ def test_repr_with_offset(): def test_is_affine(): - from tensor_layouts import ComposedLayout, Swizzle, is_affine + from tensor_layouts import ComposedLayout, is_affine, Swizzle # Affine layout t1 = Tensor(Layout((4, 8), (1, 4))) @@ -1097,7 +1096,9 @@ def test_cute_mma_fragment_pattern(): assert isinstance(thread1_slice, Tensor) # Verify they access different starting positions - assert thread0_slice.offset != thread1_slice.offset or thread0_slice(0) != thread1_slice(0) + assert thread0_slice.offset != thread1_slice.offset or thread0_slice( + 0 + ) != thread1_slice(0) def test_hierarchical_slice_preserves_structure(): diff --git a/tests/viz.py b/tests/viz.py index a212558..0ff6ee0 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -20,23 +20,23 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from collections import defaultdict import tempfile +from collections import defaultdict import pytest from tensor_layouts import * -from tensor_layouts.tensor import Tensor from tensor_layouts.atoms_amd import ( - CDNA3P_16x16x32_F32F16F16_MFMA, CDNA3_32x32x16_F32F8F8_MFMA, + CDNA3P_16x16x32_F32F16F16_MFMA, ) from tensor_layouts.atoms_nv import ( + SM120_16x8x32_F32E4M3E4M3F32_TN, SM80_16x8x16_F16F16F16F16_TN, SM90_16x8x4_F64F64F64F64_TN, - SM120_16x8x32_F32E4M3E4M3F32_TN, ) from tensor_layouts.layout_utils import tile_mma_grid +from tensor_layouts.tensor import Tensor try: import matplotlib.figure @@ -57,25 +57,25 @@ _build_tiled_grid_figure, _build_tv_figure, _compute_tv_mapping, + _coord_levels, _draw_hierarchical_grid, _format_hierarchical_cell_lines, _format_nested_coord, - _coord_levels, - _get_slice_highlight_mask_2d, - _level_block_sizes, - _level_spans, + _get_color_indices_2d, _get_hierarchical_cell_coords_2d, _get_hierarchical_indices_2d, _get_indices_2d, - _get_color_indices_2d, + _get_slice_highlight_mask_2d, + _level_block_sizes, + _level_spans, ) + HAS_VIZ = True except ImportError: HAS_VIZ = False requires_viz = pytest.mark.skipif( - not HAS_VIZ, - reason="tensor_layouts.viz not available (needs matplotlib)" + not HAS_VIZ, reason="tensor_layouts.viz not available (needs matplotlib)" ) @@ -86,9 +86,13 @@ def _call_draw_hierarchical_grid(ax, layout, **kwargs): row_shape = mode(layout.shape, 0) col_shape = mode(layout.shape, 1) return _draw_hierarchical_grid( - ax, indices, rows, cols, + ax, + indices, + rows, + cols, cell_coords=cell_coords, - row_shape=row_shape, col_shape=col_shape, + row_shape=row_shape, + col_shape=col_shape, **kwargs, ) @@ -141,7 +145,6 @@ def _highlight_patch_positions(ax): ] - @requires_viz def test_draw_layout_returns_figure_without_raising(): """Smoke test for draw_layout helper.""" @@ -153,7 +156,6 @@ def test_draw_layout_returns_figure_without_raising(): plt.close(fig) - @requires_viz def test_draw_swizzle_returns_figure_without_raising(): """Regression test for draw_swizzle helper.""" @@ -165,7 +167,6 @@ def test_draw_swizzle_returns_figure_without_raising(): plt.close(fig) - @requires_viz def test_draw_layout_accepts_tensor(): """draw_layout/draw_layout accept Tensor and display offset-adjusted values.""" @@ -186,7 +187,6 @@ def test_draw_layout_accepts_tensor(): plt.close(fig) - @requires_viz def test_draw_layout_tensor_zero_offset(): """Tensor with offset=0 produces same values as bare Layout.""" @@ -195,18 +195,23 @@ def test_draw_layout_tensor_zero_offset(): fig_layout = _build_layout_figure(layout) fig_tensor = _build_layout_figure(tensor) try: + def _cell_values(fig): ax = fig.axes[0] return sorted( - [(t.get_position(), t.get_text()) for t in ax.texts if t.get_text().isdigit()], + [ + (t.get_position(), t.get_text()) + for t in ax.texts + if t.get_text().isdigit() + ], ) + assert _cell_values(fig_layout) == _cell_values(fig_tensor) finally: plt.close(fig_layout) plt.close(fig_tensor) - @requires_viz def test_draw_layout_swizzled_tensor(): """Swizzled Tensor renders without error.""" @@ -220,14 +225,12 @@ def test_draw_layout_swizzled_tensor(): plt.close(fig) - @requires_viz def test_draw_layout_smoke(): with tempfile.NamedTemporaryFile(suffix=".png") as f: draw_layout(Layout((8, 8), (8, 1)), filename=f.name) - @requires_viz def test_draw_layout_negative_offset_list_labels_do_not_wrap(): """Negative offsets must not wrap around explicit cell_labels lists.""" @@ -249,7 +252,6 @@ def test_draw_layout_negative_offset_list_labels_do_not_wrap(): plt.close(fig) - @requires_viz def test_draw_layout_hierarchical_negative_offset_labels_do_not_wrap(): """Hierarchical grids must also avoid Python negative-index wraparound.""" @@ -276,14 +278,14 @@ def test_draw_layout_hierarchical_negative_offset_labels_do_not_wrap(): plt.close(fig) - @requires_viz def test_color_by_row_matches_color_layout(): """color_by='row' produces the same color indices as the manual color_layout.""" layout = Layout((4, 8), (8, 1)) fig_by = _build_layout_figure(layout, color_by="row") - fig_manual = _build_layout_figure(layout, color_layout=Layout((4, 8), (1, 0)), - colorize=True) + fig_manual = _build_layout_figure( + layout, color_layout=Layout((4, 8), (1, 0)), colorize=True + ) try: # Both should have the same cell background colors patches_by = [p for p in fig_by.axes[0].patches] @@ -296,7 +298,6 @@ def test_color_by_row_matches_color_layout(): plt.close(fig_manual) - @requires_viz def test_color_by_column(): """color_by='column' renders without error.""" @@ -307,15 +308,13 @@ def test_color_by_column(): plt.close(fig) - @requires_viz def test_color_by_and_color_layout_exclusive(): """Providing both color_by and color_layout raises ValueError.""" with pytest.raises(ValueError, match="mutually exclusive"): - draw_layout(Layout((4, 4), (4, 1)), - color_by="row", - color_layout=Layout((4, 4), (1, 0))) - + draw_layout( + Layout((4, 4), (4, 1)), color_by="row", color_layout=Layout((4, 4), (1, 0)) + ) @requires_viz @@ -333,7 +332,6 @@ def test_rank3_layout_produces_multi_panel(): plt.close(fig) - @requires_viz def test_rank3_panel_values_match_layout(): """Each rank-3 panel shows correct offset values.""" @@ -341,6 +339,7 @@ def test_rank3_panel_values_match_layout(): divided = flat_divide(matrix, Layout(2, 1)) fig = _build_layout_figure(divided) try: + def _cell_val(ax, x, y): for t in ax.texts: tx = round(t.get_position()[0], 1) @@ -360,7 +359,6 @@ def _cell_val(ax, x, y): plt.close(fig) - @requires_viz def test_draw_layout_composed_values_match_layout(): composed = compose( @@ -431,7 +429,6 @@ def test_rank4_layout_renders(): plt.close(fig) - @requires_viz def test_rank4_composed_panel_values_match_layout(): composed = compose( @@ -485,14 +482,12 @@ def test_draw_swizzle_smoke(): draw_swizzle(Layout((8, 8), (8, 1)), Swizzle(3, 0, 3), filename=f.name) - @requires_viz def test_draw_slice_smoke(): with tempfile.NamedTemporaryFile(suffix=".png") as f: draw_slice(Layout((4, 8), (8, 1)), (2, None), filename=f.name) - @requires_viz def test_draw_slice_composed_uses_internal_preoffset_in_title_and_values(): composed = compose( @@ -532,7 +527,6 @@ def test_draw_tv_layout_smoke(atom): ) - @pytest.mark.parametrize("atom", MIXED_VIZ_ATOMS, ids=lambda a: a.name) @requires_viz def test_draw_mma_layout_smoke(atom): @@ -548,7 +542,6 @@ def test_draw_mma_layout_smoke(atom): ) - @requires_viz def test_draw_mma_layout_raises_for_incompatible_panel_shape(): a_layout = Layout((2, 2), (1, 2)) @@ -567,7 +560,6 @@ def test_draw_mma_layout_raises_for_incompatible_panel_shape(): plt.close("all") - @pytest.mark.parametrize("atom", MIXED_VIZ_ATOMS, ids=lambda a: a.name) @requires_viz def test_draw_tiled_grid_smoke(atom): @@ -583,7 +575,6 @@ def test_draw_tiled_grid_smoke(atom): ) - @requires_viz def test_draw_composite_smoke(): panels = [Layout((4, 4), (4, 1)), Layout((4, 4), (1, 4))] @@ -597,14 +588,13 @@ def test_draw_composite_smoke(): ) - @requires_viz def test_draw_composite_mixed_tv_and_offset(): """Composite figure with per-panel tv_mode: one offset grid, one TV grid.""" atom = SM80_16x8x16_F16F16F16F16_TN panels = [ Layout((4, 4), (4, 1)), # offset grid (default) - (atom.c_layout, {'tv_mode': True}), # TV grid + (atom.c_layout, {"tv_mode": True}), # TV grid ] fig = _build_composite_figure(panels, titles=["Offset", "TV"]) try: @@ -614,14 +604,13 @@ def test_draw_composite_mixed_tv_and_offset(): plt.close(fig) - @requires_viz def test_draw_composite_hierarchical_panel(): """Composite figure with flatten_hierarchical=False renders hierarchy lines.""" hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) flat = Layout((4, 4), (4, 1)) panels = [ - (hier, {'flatten_hierarchical': False}), + (hier, {"flatten_hierarchical": False}), flat, ] fig = _build_composite_figure(panels, titles=["Hierarchical", "Flat"]) @@ -637,7 +626,6 @@ def test_draw_composite_hierarchical_panel(): plt.close(fig) - @requires_viz def test_draw_composite_hierarchical_top_level_default(): """flatten_hierarchical=False as top-level default applies to all panels.""" @@ -650,7 +638,6 @@ def test_draw_composite_hierarchical_top_level_default(): plt.close(fig) - @requires_viz def test_draw_composite_warns_on_panel_truncation(): """Warning emitted when panels exceed grid capacity.""" @@ -660,23 +647,18 @@ def test_draw_composite_warns_on_panel_truncation(): plt.close(fig) - @requires_viz def test_draw_copy_layout_smoke(): src = Layout((4, 2), (2, 1)) dst = Layout((4, 2), (1, 4)) with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_copy_layout(src, dst, filename=f.name, - title="copy smoke", colorize=True) - + draw_copy_layout(src, dst, filename=f.name, title="copy smoke", colorize=True) @requires_viz def test_draw_copy_layout_rejects_rank1(): with pytest.raises(ValueError, match="rank 2"): - draw_copy_layout(Layout(8, 1), Layout((4, 2), (2, 1)), - filename="ignored.png") - + draw_copy_layout(Layout(8, 1), Layout((4, 2), (2, 1)), filename="ignored.png") @requires_viz @@ -690,22 +672,21 @@ def test_draw_copy_layout_returns_figure(): plt.close(fig) - @requires_viz def test_draw_copy_atom_smoke(): """draw_copy_atom handles the upcast from bit coordinates automatically.""" from tensor_layouts.atoms_nv import SM75_U32x1_LDSM_N + with tempfile.NamedTemporaryFile(suffix=".png") as f: draw_copy_atom(SM75_U32x1_LDSM_N, element_bits=16, filename=f.name) - @requires_viz def test_draw_copy_atom_returns_figure(): """draw_copy_atom renders without raising.""" from tensor_layouts.atoms_nv import SM90_U32x4_STSM_N - draw_copy_atom(SM90_U32x4_STSM_N, element_bits=16) + draw_copy_atom(SM90_U32x4_STSM_N, element_bits=16) @requires_viz @@ -717,7 +698,6 @@ def test_draw_tv_layout_returns_figure(): plt.close(fig) - @requires_viz def test_draw_tv_layout_negative_stride_returns_figure(): """Dense negative-stride TV layouts should render after rebasing.""" @@ -728,7 +708,6 @@ def test_draw_tv_layout_negative_stride_returns_figure(): plt.close(fig) - @requires_viz def test_draw_tv_layout_negative_stride_with_holes_renders_questions(): """Gappy negative-stride TV layouts should render holes rather than fail.""" @@ -741,24 +720,29 @@ def test_draw_tv_layout_negative_stride_with_holes_renders_questions(): plt.close(fig) - @requires_viz def test_draw_mma_layout_returns_figure(): from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + atom = SM80_16x8x16_F16F16F16F16_TN - fig = _build_mma_figure(atom.a_layout, atom.b_layout, atom.c_layout, - tile_mnk=atom.shape_mnk, colorize=True, - thr_id_layout=atom.thr_id) + fig = _build_mma_figure( + atom.a_layout, + atom.b_layout, + atom.c_layout, + tile_mnk=atom.shape_mnk, + colorize=True, + thr_id_layout=atom.thr_id, + ) try: assert isinstance(fig, matplotlib.figure.Figure) finally: plt.close(fig) - @requires_viz def test_draw_tiled_grid_returns_figure(): from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + atom = SM80_16x8x16_F16F16F16F16_TN atom_layout = Layout((2, 2), (1, 2)) grid, tile_shape = tile_mma_grid(atom, atom_layout, matrix="C") @@ -769,7 +753,6 @@ def test_draw_tiled_grid_returns_figure(): plt.close(fig) - @requires_viz def test_draw_slice_returns_figure(): fig = _build_slice_figure(Layout((4, 8), (8, 1)), (2, None)) @@ -779,7 +762,6 @@ def test_draw_slice_returns_figure(): plt.close(fig) - @requires_viz def test_draw_composite_returns_figure(): l1 = Layout((4, 4), (4, 1)) @@ -792,7 +774,6 @@ def test_draw_composite_returns_figure(): plt.close(fig) - @requires_viz def test_draw_composite_tensor_data_labels(): """Tensor panels show data values, not raw offsets, in cell text.""" @@ -800,28 +781,26 @@ def test_draw_composite_tensor_data_labels(): fig = _build_composite_figure([t], titles=["Data"]) try: ax = fig.axes[0] - cell_texts = [c.get_text() for c in ax.texts - if c.get_text() in ("W", "X", "Y", "Z")] + cell_texts = [ + c.get_text() for c in ax.texts if c.get_text() in ("W", "X", "Y", "Z") + ] assert len(cell_texts) == 4, f"expected data labels, got {cell_texts}" finally: plt.close(fig) - @requires_viz def test_draw_composite_layout_shows_offsets(): """Layout panels show offset integers, not data.""" fig = _build_composite_figure([Layout(4, 1)], titles=["Offsets"]) try: ax = fig.axes[0] - cell_texts = sorted(c.get_text() for c in ax.texts - if c.get_text().isdigit()) + cell_texts = sorted(c.get_text() for c in ax.texts if c.get_text().isdigit()) assert "0" in cell_texts and "3" in cell_texts finally: plt.close(fig) - @requires_viz def test_draw_composite_auto_panel_size_compact_for_1d(): """Auto panel_size produces compact height for rank-1 layouts.""" @@ -833,7 +812,6 @@ def test_draw_composite_auto_panel_size_compact_for_1d(): plt.close(fig) - @requires_viz def test_draw_composite_auto_panel_size_scales_with_rows(): """Auto panel_size grows taller for layouts with more rows.""" @@ -848,7 +826,6 @@ def test_draw_composite_auto_panel_size_scales_with_rows(): plt.close(fig_2d) - @requires_viz def test_draw_composite_explicit_panel_size_overrides_auto(): """Explicit panel_size takes precedence over auto-compute.""" @@ -860,7 +837,6 @@ def test_draw_composite_explicit_panel_size_overrides_auto(): plt.close(fig) - @requires_viz def test_draw_composite_cell_labels_offset_kwarg(): """cell_labels='offset' passed as kwarg forces offset display for Tensors.""" @@ -876,7 +852,6 @@ def test_draw_composite_cell_labels_offset_kwarg(): plt.close(fig) - @requires_viz def test_draw_composite_per_panel_override_wins(): """Per-panel option dict overrides top-level kwarg.""" @@ -888,16 +863,17 @@ def test_draw_composite_per_panel_override_wins(): ) try: ax = fig.axes[0] - cell_texts = [c.get_text() for c in ax.texts - if c.get_text() in ("W", "X", "Y", "Z")] + cell_texts = [ + c.get_text() for c in ax.texts if c.get_text() in ("W", "X", "Y", "Z") + ] assert len(cell_texts) == 4 finally: plt.close(fig) - # --- draw_gemm tests --- + @requires_viz def test_draw_gemm_smoke(): """draw_gemm produces a figure with 4 axes (empty + A + B^T + C).""" @@ -913,7 +889,6 @@ def test_draw_gemm_smoke(): plt.close(fig) - @requires_viz def test_draw_gemm_tensor_shows_data(): """Tensor operands display data values, not offsets.""" @@ -923,14 +898,16 @@ def test_draw_gemm_tensor_shows_data(): fig = _build_gemm_figure(A, B, C) try: # Check A panel (axes[2] = bottom-left) - a_texts = [c.get_text() for c in fig.axes[2].texts - if c.get_text() in ("A", "B", "C", "D")] + a_texts = [ + c.get_text() + for c in fig.axes[2].texts + if c.get_text() in ("A", "B", "C", "D") + ] assert len(a_texts) == 4, f"expected A data labels, got {a_texts}" finally: plt.close(fig) - @requires_viz def test_draw_gemm_b_transposed(): """B panel title shows transposed dimensions K×N.""" @@ -946,7 +923,6 @@ def test_draw_gemm_b_transposed(): plt.close(fig) - @requires_viz def test_draw_gemm_cell_labels_offset(): """cell_labels='offset' forces offset display for Tensor operands.""" @@ -962,7 +938,6 @@ def test_draw_gemm_cell_labels_offset(): plt.close(fig) - @requires_viz def test_draw_gemm_hierarchy_boundary_boxes(): """Hierarchical layouts get hierarchy boundary lines in draw_gemm.""" @@ -981,7 +956,6 @@ def test_draw_gemm_hierarchy_boundary_boxes(): plt.close(fig) - @requires_viz def test_draw_copy_layout_same_thread_colors_both_panels(): """Src and dst panels should use the same color for the same thread.""" @@ -1001,7 +975,6 @@ def test_draw_copy_layout_same_thread_colors_both_panels(): plt.close(fig) - @requires_viz def test_get_indices_2d_row_major_matches_logical_coordinates(): layout = Layout((4, 3), (3, 1)) @@ -1014,7 +987,6 @@ def test_get_indices_2d_row_major_matches_logical_coordinates(): ] - @requires_viz def test_get_indices_2d_column_major_matches_logical_coordinates(): layout = Layout((4, 3), (1, 4)) @@ -1027,7 +999,6 @@ def test_get_indices_2d_column_major_matches_logical_coordinates(): ] - @requires_viz def test_get_color_indices_2d_by_row_matches_logical_coordinates(): layout = Layout((4, 3), (3, 1)) @@ -1041,7 +1012,6 @@ def test_get_color_indices_2d_by_row_matches_logical_coordinates(): ] - @requires_viz def test_get_color_indices_2d_by_column_matches_logical_coordinates(): layout = Layout((4, 3), (3, 1)) @@ -1055,7 +1025,6 @@ def test_get_color_indices_2d_by_column_matches_logical_coordinates(): ] - @requires_viz def test_get_color_indices_2d_uniform_layout_is_uniform(): layout = Layout((4, 3), (3, 1)) @@ -1069,7 +1038,6 @@ def test_get_color_indices_2d_uniform_layout_is_uniform(): ] - @requires_viz def test_get_color_indices_2d_1d_layout_is_not_treated_as_uniform(): layout = Layout(4, 1) @@ -1078,7 +1046,6 @@ def test_get_color_indices_2d_1d_layout_is_not_treated_as_uniform(): assert color_indices.tolist() == [[0, 1, 2, 3]] - @requires_viz def test_get_hierarchical_cell_coords_2d_preserves_nested_coordinates(): layout = Layout(((2, 3), (2, 4)), ((1, 6), (2, 12))) @@ -1090,7 +1057,6 @@ def test_get_hierarchical_cell_coords_2d_preserves_nested_coordinates(): assert coords[0, 2] == ((0, 0), (0, 1)) - @requires_viz def test_format_nested_coord_formats_hierarchical_labels(): assert _format_nested_coord(3) == "3" @@ -1098,7 +1064,6 @@ def test_format_nested_coord_formats_hierarchical_labels(): assert _format_nested_coord(((1, 2), 3)) == "((1,2),3)" - @requires_viz def test_format_hierarchical_cell_lines_is_explicit_and_pedagogical(): assert _format_hierarchical_cell_lines((1, 2), (3, 4), 17) == ( @@ -1108,7 +1073,6 @@ def test_format_hierarchical_cell_lines_is_explicit_and_pedagogical(): ) - @requires_viz def test_coord_levels_flattens_nested_coordinates_for_axis_labels(): assert _coord_levels(3) == (3,) @@ -1116,19 +1080,16 @@ def test_coord_levels_flattens_nested_coordinates_for_axis_labels(): assert _coord_levels(((1, 2), 3)) == (1, 2, 3) - @requires_viz def test_level_spans_supports_three_level_hierarchy(): assert _level_spans((2, 3, 4)) == (2, 6, 24) - @requires_viz def test_level_block_sizes_supports_three_level_hierarchy(): assert _level_block_sizes((2, 3, 4)) == (1, 2, 6) - @requires_viz def test_draw_layout_nested_passes_color_indices_to_hierarchical_renderer(monkeypatch): layout = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) @@ -1142,7 +1103,9 @@ def fake_draw(ax, indices, rows, cols, **kwargs): seen["color_indices"] = kwargs.get("color_indices") monkeypatch.setattr(viz_mod, "_draw_hierarchical_grid", fake_draw) - monkeypatch.setattr(viz_mod, "_save_figure", lambda fig, filename, dpi=150: plt.close(fig)) + monkeypatch.setattr( + viz_mod, "_save_figure", lambda fig, filename, dpi=150: plt.close(fig) + ) draw_layout( layout, @@ -1157,7 +1120,6 @@ def fake_draw(ax, indices, rows, cols, **kwargs): assert seen["color_indices"].shape == (4, 4) - @requires_viz def test_draw_hierarchical_grid_uses_supplied_color_indices(): layout = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) @@ -1224,7 +1186,7 @@ def _label_bboxes(ax): def _has_bbox_overlap(boxes): """Return True if any pair of bounding boxes overlaps.""" for i, (_, bbox_i) in enumerate(boxes): - for _, bbox_j in boxes[i + 1:]: + for _, bbox_j in boxes[i + 1 :]: if Bbox.overlaps(bbox_i, bbox_j): return True return False @@ -1258,11 +1220,16 @@ def _cell_patch_bboxes(ax): renderer = fig.canvas.get_renderer() boxes = {} for patch in ax.patches: - if not all(hasattr(patch, attr) for attr in ("get_x", "get_y", "get_width", "get_height")): + if not all( + hasattr(patch, attr) + for attr in ("get_x", "get_y", "get_width", "get_height") + ): continue if patch.get_width() != 1.0 or patch.get_height() != 1.0: continue - boxes[(int(round(patch.get_y())), int(round(patch.get_x())))] = patch.get_window_extent(renderer=renderer) + boxes[(int(round(patch.get_y())), int(round(patch.get_x())))] = ( + patch.get_window_extent(renderer=renderer) + ) return boxes @@ -1282,7 +1249,6 @@ def _cell_text_bboxes(ax, rows: int, cols: int): return boxes - @requires_viz def test_draw_hierarchical_grid_draws_outer_perimeter_for_coarse_tiles(): layout = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) @@ -1298,15 +1264,15 @@ def test_draw_hierarchical_grid_draws_outer_perimeter_for_coarse_tiles(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_cecka_hier_col_margin_labels_do_not_overlap(): layout = Layout((4, (4, 2)), (4, (1, 16))) fig, ax = plt.subplots(figsize=(8 * 0.8 + 1, 4 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) row_boxes, col_boxes = _label_bboxes(ax) assert row_boxes assert col_boxes @@ -1316,15 +1282,15 @@ def test_draw_hierarchical_grid_cecka_hier_col_margin_labels_do_not_overlap(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_offset_values_clear_offset_equals_label(): layout = Layout((4, (4, 2)), (4, (1, 16))) fig, ax = plt.subplots(figsize=(8 * 0.8 + 1, 4 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) pairs = _offset_label_value_bboxes(ax) assert pairs min_gap = min(value_bbox.x0 - label_bbox.x1 for label_bbox, value_bbox in pairs) @@ -1337,7 +1303,6 @@ def test_draw_hierarchical_grid_offset_values_clear_offset_equals_label(): plt.close(fig) - @requires_viz def test_draw_layout_small_nested_hierarchy_keeps_text_inside_cells(monkeypatch): layout = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) @@ -1372,15 +1337,15 @@ def fake_save(fig, filename, dpi=150): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_leaves_corner_gap_between_axis_label_bands(): layout = Layout(((3, 2), ((2, 3), 2)), ((4, 1), ((2, 15), 100))) fig, ax = plt.subplots(figsize=(12 * 0.8 + 1, 6 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) row_boxes, col_boxes = _label_bboxes(ax) assert row_boxes assert col_boxes @@ -1398,11 +1363,9 @@ def test_draw_hierarchical_grid_leaves_corner_gap_between_axis_label_bands(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_draws_outer_perimeter_for_multiple_levels(): - layout = Layout(((3, 2, 2, 2), (4, 2, 2, 2)), - ((1, 3, 6, 12), (24, 96, 192, 384))) + layout = Layout(((3, 2, 2, 2), (4, 2, 2, 2)), ((1, 3, 6, 12), (24, 96, 192, 384))) fig, ax = plt.subplots() try: @@ -1425,7 +1388,6 @@ def test_draw_hierarchical_grid_draws_outer_perimeter_for_multiple_levels(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_closes_boxes_for_column_only_hierarchy(): layout = Layout((4, (4, 2)), (4, (1, 16))) @@ -1441,7 +1403,6 @@ def test_draw_hierarchical_grid_closes_boxes_for_column_only_hierarchy(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_closes_boxes_for_coarse_column_only_level(): layout = Layout(((3, 2), ((2, 3), 2)), ((4, 1), ((2, 15), 100))) @@ -1457,7 +1418,6 @@ def test_draw_hierarchical_grid_closes_boxes_for_coarse_column_only_level(): plt.close(fig) - @requires_viz def test_draw_hierarchical_grid_draws_coarser_lines_above_finer_lines(): layout = Layout( @@ -1479,7 +1439,6 @@ def test_draw_hierarchical_grid_draws_coarser_lines_above_finer_lines(): plt.close(fig) - @requires_viz def test_draw_slice_hierarchical_keeps_flat_grid_and_highlights_on_top(monkeypatch): layout = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) @@ -1509,7 +1468,6 @@ def fake_save(fig, filename, dpi=150): assert max(seen["base_zorders"]) < min(seen["highlight_zorders"]) - @requires_viz def test_slice_highlight_mask_tracks_logical_cells_not_offsets(): layout = Layout((2, 2), (0, 1)) @@ -1520,7 +1478,6 @@ def test_slice_highlight_mask_tracks_logical_cells_not_offsets(): ] - @requires_viz def test_slice_highlight_mask_1d_tuple_spec(): """1D layout with tuple slice_spec should highlight the correct elements.""" @@ -1530,7 +1487,6 @@ def test_slice_highlight_mask_1d_tuple_spec(): assert mask.tolist() == [[False, False, True, True, True, False, False, False]] - @requires_viz def test_slice_highlight_mask_1d_tuple_spec_rank1(): """Rank-1 layout with tuple slice_spec should highlight the correct elements.""" @@ -1540,7 +1496,6 @@ def test_slice_highlight_mask_1d_tuple_spec_rank1(): assert mask.tolist() == [[False, False, True, True, True, False, False, False]] - @requires_viz def test_slice_highlight_mask_1d_tuple_int_spec(): """1D layout with tuple (int,) slice_spec highlights a single element.""" @@ -1550,7 +1505,6 @@ def test_slice_highlight_mask_1d_tuple_int_spec(): assert mask.tolist() == [[False, False, False, True, False, False, False, False]] - @requires_viz def test_slice_highlight_mask_1d_tuple_none_spec(): """1D layout with tuple (None,) selects all elements.""" @@ -1559,7 +1513,6 @@ def test_slice_highlight_mask_1d_tuple_none_spec(): assert mask.tolist() == [[True, True, True, True]] - @requires_viz def test_slice_highlight_mask_1d_wrong_tuple_length_raises(): """1D layout with 2-element tuple slice_spec raises ValueError.""" @@ -1568,7 +1521,6 @@ def test_slice_highlight_mask_1d_wrong_tuple_length_raises(): _get_slice_highlight_mask_2d(layout, (1, 2)) - @requires_viz def test_compute_tv_mapping_uses_first_wins_for_duplicate_cells(): layout = Layout((2, 2), (0, 0)) @@ -1576,7 +1528,6 @@ def test_compute_tv_mapping_uses_first_wins_for_duplicate_cells(): assert tv_map == {(0, 0): (0, 0, 0)} - @requires_viz def test_compute_tv_mapping_rebases_negative_offsets(): """TV mapping should place shifted dense images into the rebased footprint.""" @@ -1594,7 +1545,6 @@ def test_compute_tv_mapping_rebases_negative_offsets(): } - @requires_viz def test_compute_tv_mapping_raises_for_out_of_bounds_grid(): layout = Layout((2, 2), (1, 2)) @@ -1602,7 +1552,6 @@ def test_compute_tv_mapping_raises_for_out_of_bounds_grid(): _compute_tv_mapping(layout, grid_rows=1, grid_cols=1) - @requires_viz def test_draw_swizzle_delegates_to_shared_builder(monkeypatch): layout = Layout((8, 64), (64, 1)) @@ -1630,7 +1579,6 @@ def fake_save(passed_fig, filename, dpi=150): plt.close(fig) - @requires_viz def test_draw_swizzle_saves_figure_from_shared_builder(monkeypatch): layout = Layout((8, 64), (64, 1)) @@ -1657,7 +1605,6 @@ def fake_save(passed_fig, filename, dpi=150): # ── draw_combined_mma_grid / draw_combined_mma_grid ────────────────────── - @requires_viz def test_draw_combined_mma_grid_smoke(): atom = SM80_16x8x16_F16F16F16F16_TN @@ -1673,9 +1620,9 @@ def test_draw_combined_mma_grid_smoke(): M, N, K = M_a * 2, N_a * 2, K_a with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, - filename=f.name, title="test") - + draw_combined_mma_grid( + a_grid, b_display, c_grid, M, N, K, filename=f.name, title="test" + ) @requires_viz @@ -1691,8 +1638,7 @@ def test_draw_combined_mma_grid_returns_figure(): M_a, N_a, K_a = atom.shape_mnk M, N, K = M_a * 2, N_a * 2, K_a - fig = _build_combined_grid_figure(a_grid, b_display, c_grid, M, N, K, - title="test") + fig = _build_combined_grid_figure(a_grid, b_display, c_grid, M, N, K, title="test") try: assert isinstance(fig, matplotlib.figure.Figure) finally: