Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 136 additions & 52 deletions docs/generate_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@

import matplotlib.patches as patches
import matplotlib.pyplot as plt

from tensor_layouts import Layout, Swizzle, compose
from tensor_layouts import compose, Layout, Swizzle
from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN
from tensor_layouts.layout_utils import tile_mma_grid
from tensor_layouts.viz import (
draw_composite,
draw_layout,
draw_mma_layout,
draw_slice,
draw_swizzle,
draw_tiled_grid,
draw_tv_layout,
draw_mma_layout,
)

IMAGES = Path(__file__).resolve().parent / "images"
Expand All @@ -65,8 +64,10 @@ def _generate_intile_oftile(path: Path) -> None:
M, K = 4, 8
tm, tk = 2, 4
tile_colors = {
(0, 0): "#DBEAFE", (0, 1): "#FEE2E2",
(1, 0): "#D1FAE5", (1, 1): "#EDE9FE",
(0, 0): "#DBEAFE",
(0, 1): "#FEE2E2",
(1, 0): "#D1FAE5",
(1, 1): "#EDE9FE",
}

fig, axes = plt.subplots(1, 3, figsize=(18, 5.2))
Expand All @@ -79,40 +80,71 @@ def _draw_grid(ax, cell_text_fn, title, subtitle):
ok = c // tk
y = M - 1 - r
rect = patches.Rectangle(
(c, y), 1, 1,
(c, y),
1,
1,
facecolor=tile_colors[(om, ok)],
edgecolor="#D1D5DB", linewidth=0.5,
edgecolor="#D1D5DB",
linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
c + 0.5, y + 0.5, cell_text_fn(r, c),
ha="center", va="center", fontsize=9,
color="#374151", family="monospace",
c + 0.5,
y + 0.5,
cell_text_fn(r, c),
ha="center",
va="center",
fontsize=9,
color="#374151",
family="monospace",
)
# thick tile borders
for i in range(0, M + 1, tm):
ax.plot([0, K], [i, i], color="#1F2937", lw=2.5,
solid_capstyle="butt")
ax.plot([0, K], [i, i], color="#1F2937", lw=2.5, solid_capstyle="butt")
for j in range(0, K + 1, tk):
ax.plot([j, j], [0, M], color="#1F2937", lw=2.5,
solid_capstyle="butt")
ax.plot([j, j], [0, M], color="#1F2937", lw=2.5, solid_capstyle="butt")
# oftile margin labels
for om in range(M // tm):
y_c = M - om * tm - tm / 2
ax.text(-0.3, y_c, f"oftile\u2080={om}", ha="right", va="center",
fontsize=8, fontweight="bold", color="#7C3AED",
family="monospace")
ax.text(
-0.3,
y_c,
f"oftile\u2080={om}",
ha="right",
va="center",
fontsize=8,
fontweight="bold",
color="#7C3AED",
family="monospace",
)
for ok in range(K // tk):
x_c = ok * tk + tk / 2
ax.text(x_c, M + 0.15, f"oftile\u2081={ok}", ha="center",
va="bottom", fontsize=8, fontweight="bold", color="#7C3AED",
family="monospace")
ax.text(
x_c,
M + 0.15,
f"oftile\u2081={ok}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
color="#7C3AED",
family="monospace",
)
ax.set_xlim(-2.5, K + 0.5)
ax.set_ylim(-1.8, M + 0.8)
ax.axis("off")
ax.set_title(title, fontsize=10.5, fontweight="bold", pad=10)
ax.text(K / 2, -0.3, subtitle, ha="center", va="top", fontsize=8,
color="#6B7280", family="monospace", linespacing=1.6)
ax.text(
K / 2,
-0.3,
subtitle,
ha="center",
va="top",
fontsize=8,
color="#6B7280",
family="monospace",
linespacing=1.6,
)

# ── Panel 1: linear index ────────────────────────────────────
_draw_grid(
Expand All @@ -129,8 +161,7 @@ def _draw_grid(ax, cell_text_fn, title, subtitle):
axes[1],
lambda r, c: f"{r},{c}",
"2D index (row, col)",
"intile = (row % 2, col % 4)\n"
"oftile = (row // 2, col // 4)",
"intile = (row % 2, col % 4)\noftile = (row // 2, col // 4)",
)

# ── Panel 3: layout algebra ──────────────────────────────────
Expand All @@ -142,9 +173,17 @@ def _draw_grid(ax, cell_text_fn, title, subtitle):
"mode 0: (intile\u2080, oftile\u2080)\n"
"mode 1: (intile\u2081, oftile\u2081)",
)
axes[2].text(K / 2, -1.45, "cells show (intile\u2080, intile\u2081)",
ha="center", va="top", fontsize=9, fontweight="bold",
color="#2563EB", family="monospace")
axes[2].text(
K / 2,
-1.45,
"cells show (intile\u2080, intile\u2081)",
ha="center",
va="top",
fontsize=9,
fontweight="bold",
color="#2563EB",
family="monospace",
)

plt.tight_layout()
fig.savefig(path, dpi=150, bbox_inches="tight")
Expand Down Expand Up @@ -188,7 +227,7 @@ def _generate_im2col(
window_colors_rgb.append((r, g, b))

def rgb_to_hex(rgb):
return f"#{int(rgb[0]*255):02X}{int(rgb[1]*255):02X}{int(rgb[2]*255):02X}"
return f"#{int(rgb[0] * 255):02X}{int(rgb[1] * 255):02X}{int(rgb[2] * 255):02X}"

window_colors = [rgb_to_hex(c) for c in window_colors_rgb]

Expand All @@ -205,7 +244,9 @@ def blend(win_indices):

# -- figure --
fig, axes = plt.subplots(
1, 2, figsize=(4 + n_taps * 1.2, max(H, n_windows) * 0.8 + 1.5),
1,
2,
figsize=(4 + n_taps * 1.2, max(H, n_windows) * 0.8 + 1.5),
gridspec_kw={"width_ratios": [W, n_taps * 1.1]},
)

Expand All @@ -217,19 +258,33 @@ def blend(win_indices):
y = H - 1 - row
color = blend(cell_windows[idx])
rect = patches.Rectangle(
(col, y), 1, 1,
facecolor=color, edgecolor="#D1D5DB", linewidth=0.5,
(col, y),
1,
1,
facecolor=color,
edgecolor="#D1D5DB",
linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
col + 0.5, y + 0.5, labels[idx],
ha="center", va="center", fontsize=13,
color="#374151", family="monospace", fontweight="bold",
col + 0.5,
y + 0.5,
labels[idx],
ha="center",
va="center",
fontsize=13,
color="#374151",
family="monospace",
fontweight="bold",
)
# thick outer border
rect = patches.Rectangle(
(0, 0), W, H,
facecolor="none", edgecolor="#1F2937", linewidth=2.5,
(0, 0),
W,
H,
facecolor="none",
edgecolor="#1F2937",
linewidth=2.5,
)
ax.add_patch(rect)
ax.set_xlim(-0.5, W + 0.5)
Expand All @@ -245,27 +300,45 @@ def blend(win_indices):
for col_idx in range(n_taps):
cell_label = labels[im2col_rows[row_idx][col_idx]]
rect = patches.Rectangle(
(col_idx, y), 1, 1,
(col_idx, y),
1,
1,
facecolor=window_colors[row_idx],
edgecolor="#D1D5DB", linewidth=0.5,
edgecolor="#D1D5DB",
linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
col_idx + 0.5, y + 0.5, cell_label,
ha="center", va="center", fontsize=12,
color="#374151", family="monospace", fontweight="bold",
col_idx + 0.5,
y + 0.5,
cell_label,
ha="center",
va="center",
fontsize=12,
color="#374151",
family="monospace",
fontweight="bold",
)
# row label: window position (p, q)
p, q = divmod(row_idx, Q)
ax.text(
-0.2, y + 0.5, f"({p},{q})",
ha="right", va="center", fontsize=8,
color="#6B7280", family="monospace",
-0.2,
y + 0.5,
f"({p},{q})",
ha="right",
va="center",
fontsize=8,
color="#6B7280",
family="monospace",
)
# thick outer border
rect = patches.Rectangle(
(0, 0), n_taps, n_windows,
facecolor="none", edgecolor="#1F2937", linewidth=2.5,
(0, 0),
n_taps,
n_windows,
facecolor="none",
edgecolor="#1F2937",
linewidth=2.5,
)
ax.add_patch(rect)
ax.set_xlim(-1.5, n_taps + 0.5)
Expand All @@ -274,21 +347,32 @@ def blend(win_indices):
ax.axis("off")
ax.set_title(
f"im2col output ({n_windows}\u00d7{n_taps})",
fontsize=11, fontweight="bold", pad=10,
fontsize=11,
fontweight="bold",
pad=10,
)

# ── Arrow connecting the panels ─────────────────────────────────
arrow = patches.FancyArrowPatch(
(0.44, 0.5), (0.52, 0.5),
(0.44, 0.5),
(0.52, 0.5),
transform=fig.transFigure,
arrowstyle="->,head_width=6,head_length=5",
color="#374151", linewidth=2,
color="#374151",
linewidth=2,
)
fig.patches.append(arrow)
fig.text(
0.48, 0.54, f"im2col({R}\u00d7{S})",
ha="center", va="bottom", fontsize=11, fontweight="bold",
color="#374151", family="monospace", transform=fig.transFigure,
0.48,
0.54,
f"im2col({R}\u00d7{S})",
ha="center",
va="bottom",
fontsize=11,
fontweight="bold",
color="#374151",
family="monospace",
transform=fig.transFigure,
)

plt.tight_layout()
Expand Down
15 changes: 12 additions & 3 deletions examples/composed.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def example_fast_path() -> None:


def example_exact_fallback() -> LayoutExpr:
_banner("2. Exact fallback: compositions that do not collapse return ComposedLayout")
_banner(
"2. Exact fallback: compositions that do not collapse return ComposedLayout"
)

base = Layout((4, 4), (4, 1))
swizzled = compose(Swizzle(2, 0, 2), base)
Expand Down Expand Up @@ -142,7 +144,12 @@ def maybe_draw(exact: LayoutExpr, outdir: Path | None) -> None:
from tensor_layouts.viz import draw_layout, draw_slice

outdir.mkdir(parents=True, exist_ok=True)
draw_layout(exact, outdir / "composed_exact.png", title="Exact composed layout", colorize=True)
draw_layout(
exact,
outdir / "composed_exact.png",
title="Exact composed layout",
colorize=True,
)
draw_slice(
exact,
(None, 1),
Expand All @@ -155,7 +162,9 @@ def maybe_draw(exact: LayoutExpr, outdir: Path | None) -> None:

def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--draw", type=Path, default=None, help="Optional directory for PNG output")
parser.add_argument(
"--draw", type=Path, default=None, help="Optional directory for PNG output"
)
args = parser.parse_args()

example_fast_path()
Expand Down
Loading
Loading