From c40398c4996c1d5f8a1e5a1b12db2978d973ca8d Mon Sep 17 00:00:00 2001 From: Teddy Do Date: Thu, 14 May 2026 16:22:40 -0700 Subject: [PATCH 1/2] [JAX] Size autotuned Triton grids per config (#2975) * [JAX] Size autotuned Triton grids per config (3x perm-kernel speedup) The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32. * [JAX] Triton wrapper defaults match jax-triton (3.25ms speedup) num_warps default 32->4 and num_stages 1->3 in triton_call_lowering match Triton's own triton.Config defaults. Non-autotuned kernels (e.g. _make_chunk_sort_map_kernel) were running with 1024 threads/block, an 8x kernel slowdown. Also: tuple/callable grid assertion + comment trims. Signed-off-by: tdophung --- .../jax/triton_extensions/permutation.py | 25 +++-- .../jax/triton_extensions/utils.py | 93 +++++++++++++------ 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 98c54e52bb..22f983f078 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -589,10 +589,13 @@ def lowering( probs_stride_token = 0 probs_stride_expert = 0 - # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)) - # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_permute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) # Use input_output_aliases to alias pre-zeroed buffers to outputs. # This ensures padding positions contain zeros since the kernel only writes valid positions. @@ -997,9 +1000,13 @@ def lowering( unpermuted_probs_stride_token = num_experts unpermuted_probs_stride_expert = 1 - # Grid - use minimum BLOCK_SIZE from autotune configs + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_unpermute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) return triton_call_lowering( ctx, @@ -1720,9 +1727,13 @@ def lowering( probs_stride_token = 1 permuted_probs_stride_token = 1 - # Grid - use minimum BLOCK_SIZE from autotune configs + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_sort_chunks_by_map_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) # Declare input_output_aliases so XLA knows output slot 0 is claimed by # input 3 (output_buf). This prevents XLA from implicitly aliasing any diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2a86321c34..332bc6ddb7 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -390,7 +390,16 @@ def triton_call_lowering( ctx: MLIR lowering context kernel_fn: Triton kernel function *array_args: Input arrays (from ctx) - grid: Grid dimensions (int or tuple) + grid: Grid dimensions. May be either: + - an int or tuple (fixed grid for every config), or + - a callable ``meta -> int|tuple`` (evaluated per autotune config). + + Use the callable form for autotuned kernels whose grid depends on + ``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the + launch grid will not match the autotuner-selected config and the + kernel will either over-launch (waste) or under-cover. ``meta`` is + the merged dict ``{**constexprs, **config.kwargs}`` for the chosen + config — the same convention as jax-triton's ``triton_call``. input_output_aliases: Mapping of input to output aliases constexprs: Compile-time constants for the kernel. This includes both tl.constexpr arguments AND scalar runtime arguments (like @@ -404,13 +413,12 @@ def triton_call_lowering( def lowering(ctx, x, *, block_size): from ..triton_extensions import triton_call_lowering n = ctx.avals_in[0].size + + def grid(meta): + return (triton.cdiv(n, meta["BLOCK_SIZE"]),) + return triton_call_lowering( - ctx, my_kernel, x, - grid=(triton.cdiv(n, block_size),), - constexprs={ - "n_elements": n, # scalar arg (not tl.constexpr in kernel) - "BLOCK_SIZE": block_size, # tl.constexpr arg - }, + ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n}, ) """ # Get compute capability using gpu_triton @@ -431,22 +439,39 @@ def lowering(ctx, x, *, block_size): tensor_arg_names = [n for n in arg_names if n not in constexpr_names] signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)} - # Normalize grid to 3D - if isinstance(grid, int): - grid_tuple = (grid, 1, 1) - elif len(grid) == 1: - grid_tuple = (grid[0], 1, 1) - elif len(grid) == 2: - grid_tuple = (grid[0], grid[1], 1) - else: - grid_tuple = grid[:3] + assert callable(grid) or isinstance(grid, tuple), ( + "Argument 'grid' must be a tuple or a callable but received: " + f"type={type(grid)}, value={grid}" + ) - # Default values for the kernel + # Normalize grid to 3D. When `grid` is a callable, defer evaluation until + # we know the per-config meta (so each autotune config gets its own grid, + # matching jax-triton's behavior). + def _normalize_grid(grid_tuple): + if isinstance(grid_tuple, int): + return (grid_tuple, 1, 1) + if len(grid_tuple) == 1: + return (grid_tuple[0], 1, 1) + if len(grid_tuple) == 2: + return (grid_tuple[0], grid_tuple[1], 1) + return tuple(grid_tuple[:3]) + + grid_callable = grid if callable(grid) else None + if grid_callable is None: + grid_tuple = _normalize_grid(grid) + else: + grid_tuple = None # evaluated per-config below + + # Default kernel launch parameters. These apply to non-autotuned kernels + # and as a fallback when an autotuned config doesn't specify them. Values + # match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3, + # num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a + # larger default (e.g. num_warps=32) over-provisions threads per block, + # which slashes SM occupancy on non-autotuned kernels — measured as an 8× + # slowdown on `_make_chunk_sort_map_kernel` vs jax-triton. actual_kernel_fn = kernel_fn - num_warps = 32 - num_stages = ( - 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas - ) + num_warps = 4 + num_stages = 3 num_ctas = 1 kernel_constexprs = constexprs if constexprs is not None else {} @@ -510,11 +535,18 @@ def lowering(ctx, x, *, block_size): for _ in list(ctx.avals_in) + list(ctx.avals_out): config_params.append(gpu_triton.create_array_parameter(0, 16)) + # Per-config grid: evaluate `grid(meta)` if grid is a callable so + # the launch shape matches this config's BLOCK_SIZE (etc.). + if grid_callable is not None: + config_grid = _normalize_grid(grid_callable(config_constexprs)) + else: + config_grid = grid_tuple + config_call = gpu_triton.TritonKernelCall( config_kernel, - grid_tuple[0], - grid_tuple[1], - grid_tuple[2], + config_grid[0], + config_grid[1], + config_grid[2], config_params, ) @@ -571,11 +603,18 @@ def lowering(ctx, x, *, block_size): for _ in list(ctx.avals_in) + list(ctx.avals_out): kernel_params.append(gpu_triton.create_array_parameter(0, 16)) + # Non-autotuned dispatch: evaluate `grid(meta)` once with the merged + # constexprs (which already reflect the single config we'll launch). + if grid_callable is not None: + single_grid = _normalize_grid(grid_callable(kernel_constexprs)) + else: + single_grid = grid_tuple + kernel_call = gpu_triton.TritonKernelCall( kernel, - grid_tuple[0], - grid_tuple[1], - grid_tuple[2], + single_grid[0], + single_grid[1], + single_grid[2], kernel_params, ) From eca05d3b554e36d99106c368f45df2cc350ddebc Mon Sep 17 00:00:00 2001 From: Arpit Jain <3242828+arpitjain099@users.noreply.github.com> Date: Fri, 15 May 2026 08:53:54 +0900 Subject: [PATCH 2/2] ci: declare contents:read on Lint workflow (#2989) The Lint workflow runs cpplint and pylint against the checked-out tree. No cache, no GitHub API write. `permissions: contents: read` captures that and matches the per-job permissions blocks already used in deploy_nightly_docs.yml (pages:write + id-token:write) and upload-ci-logs.yml (statuses:write). build.yml is left out because it pulls mozilla-actions/sccache-action (which writes to the Actions cache) and easimon/maximize-build-space. A drive-by permissions block there would need actions:write for the sccache save path, which deserves a separate look. Signed-off-by: Arpit Jain --- .github/workflows/lint.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1d2fb272f8..016d2079d2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,6 +11,8 @@ concurrency: # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes) group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true +permissions: + contents: read jobs: pytorch_cpplint: name: 'PyTorch C++'