diff --git a/benchmark/examples/benchmark_moe.py b/benchmark/examples/benchmark_moe.py index c3818fa2..6ffff367 100644 --- a/benchmark/examples/benchmark_moe.py +++ b/benchmark/examples/benchmark_moe.py @@ -144,6 +144,41 @@ def parse_args(): default=None, help="Override GEMM_SMS for WG-specialized variant (default: auto)", ) + + # WG kernel tuning sweep arguments. + parser.add_argument( + "--gemm_sms_values", + type=int, + nargs="+", + default=None, + metavar="N", + help=( + "List of GEMM_SMS values to sweep when --tune is active " + "(e.g. --gemm_sms_values 64 128 192). " + "Ignored unless --tune is set." + ), + ) + parser.add_argument( + "--block_m_values", + type=int, + nargs="+", + default=None, + metavar="N", + help=( + "List of BLOCK_M tile sizes to sweep when --tune is active " + "(e.g. --block_m_values 64 128 256). " + "Ignored unless --tune is set." + ), + ) + parser.add_argument( + "--tune", + action="store_true", + help=( + "Sweep all (gemm_sms, block_m) combinations for the WG fusion mode " + "across each bpe and report the best configuration per bpe. " + "Requires --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp." + ), + ) return parser.parse_args() @@ -170,6 +205,9 @@ def _run_dist_once( shmem, fusion_config, gemm_sms=None, + block_m=None, + block_n=None, + block_k=None, ): return mixture_of_expt_epsharded( x_dp_local, @@ -181,7 +219,88 @@ def _run_dist_once( shmem, fusion_config=fusion_config, gemm_sms=gemm_sms, + block_m=block_m, + block_n=block_n, + block_k=block_k, + ) + + +def _bench_dist(run_fn, shmem, heap_snapshot, n_warmup, n_repeat): + """Benchmark a single distributed run function and return mean latency in ms.""" + reset_heap = _make_heap_resetter(shmem.heap.allocator, heap_snapshot) + saved_refresh = shmem.heap.refresh_peer_access + shmem.heap.refresh_peer_access = lambda: None + ms = iris.do_bench( + run_fn, + barrier_fn=shmem.barrier, + preamble_fn=reset_heap, + n_warmup=n_warmup, + n_repeat=n_repeat, + return_mode="mean", ) + shmem.heap.refresh_peer_access = saved_refresh + reset_heap() + return float(ms) + + +def _tune_wg_configs( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + fusion_config, + heap_snapshot, + gemm_sms_values, + block_m_values, + rank, + n_warmup=5, + n_repeat=20, +): + """Sweep (gemm_sms, block_m) combinations and return the best config. + + Returns: + best_gemm_sms (int), best_block_m (int), tune_configs (list[dict]) + """ + tune_configs = [] + best_ms = float("inf") + best_gemm_sms = gemm_sms_values[0] + best_block_m = block_m_values[0] + + for gs in gemm_sms_values: + for bm in block_m_values: + run_fn = functools.partial( + _run_dist_once, + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + fusion_config, + gs, + bm, + ) + try: + ms = _bench_dist(run_fn, shmem, heap_snapshot, n_warmup, n_repeat) + except Exception as e: + if rank == 0: + print(f" [tune] gemm_sms={gs} block_m={bm} FAILED: {e}") + ms = float("inf") + + if rank == 0: + print(f" [tune] gemm_sms={gs:4d} block_m={bm:4d} ms={ms:.3f}") + tune_configs.append({"gemm_sms": gs, "block_m": bm, "ms": ms}) + + if ms < best_ms: + best_ms = ms + best_gemm_sms = gs + best_block_m = bm + + return best_gemm_sms, best_block_m, tune_configs def _worker(rank: int, world_size: int, init_url: str, args): @@ -205,6 +324,9 @@ def _worker(rank: int, world_size: int, init_url: str, args): if args.n_expts_tot % ws != 0: raise ValueError(f"n_expts_tot ({args.n_expts_tot}) must be divisible by world_size ({ws})") + if getattr(args, "tune", False) and args.fusion_mode != "wg_fused_grouped_matmul_convert_ep_to_dp": + raise ValueError("--tune requires --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp") + if args.batch_per_expt: sweep = args.batch_per_expt else: @@ -213,6 +335,27 @@ def _worker(rank: int, world_size: int, init_url: str, args): if rank == 0: os.makedirs(args.output_dir, exist_ok=True) + # Derive default sweep values for tune mode. + if getattr(args, "tune", False): + cu_count = torch.cuda.get_device_properties(device).multi_processor_count + num_sms = int(cu_count) + if getattr(args, "gemm_sms_values", None): + gemm_sms_sweep = args.gemm_sms_values + else: + # Default: quarter, half, and three-quarter of available SMs, + # each clamped to [1, num_sms - 1]. + gemm_sms_sweep = sorted( + { + max(1, min(num_sms // 4, num_sms - 1)), + max(1, min(num_sms // 2, num_sms - 1)), + max(1, min(3 * num_sms // 4, num_sms - 1)), + } + ) + block_m_sweep = getattr(args, "block_m_values", None) or [32, 64, 128, 256] + else: + gemm_sms_sweep = None + block_m_sweep = None + results: list[dict] = [] sweep_heap_base = shmem.heap.allocator.heap_offset @@ -247,6 +390,41 @@ def _worker(rank: int, world_size: int, init_url: str, args): w_ep_local = w_global[expt_assignment.expt_boolmask[rank]].contiguous() b_ep_local = b_global[expt_assignment.expt_boolmask[rank]].contiguous() + # --- Tune: sweep (gemm_sms, block_m) and pick the best config ------- + tune_result = {} + active_gemm_sms = args.gemm_sms + active_block_m = None + if getattr(args, "tune", False): + heap_snapshot = shmem.heap.allocator.heap_offset + if rank == 0: + print(f"[tune bpe={bpe}] sweeping gemm_sms={gemm_sms_sweep} block_m={block_m_sweep}") + best_gs, best_bm, tune_configs = _tune_wg_configs( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + args.n_expts_act, + shmem, + fusion_config, + heap_snapshot, + gemm_sms_sweep, + block_m_sweep, + rank, + n_warmup=5, + n_repeat=20, + ) + active_gemm_sms = best_gs + active_block_m = best_bm + tune_result = { + "tune_best_gemm_sms": best_gs, + "tune_best_block_m": best_bm, + "tune_configs": tune_configs, + } + if rank == 0: + print(f"[tune bpe={bpe}] best: gemm_sms={best_gs} block_m={best_bm}") + shmem.heap.allocator.heap_offset = sweep_heap_base + run_dist = functools.partial( _run_dist_once, x_dp_local, @@ -257,7 +435,8 @@ def _worker(rank: int, world_size: int, init_url: str, args): args.n_expts_act, shmem, fusion_config, - args.gemm_sms, + active_gemm_sms, + active_block_m, ) if args.validate or args.compare_single_gpu: @@ -284,7 +463,8 @@ def _worker(rank: int, world_size: int, init_url: str, args): shmem, fusion_config=fusion_config, timing_dict=td, - gemm_sms=args.gemm_sms, + gemm_sms=active_gemm_sms, + block_m=active_block_m, ) if rank == 0: for j in range(1, len(td)): @@ -310,6 +490,7 @@ def _worker(rank: int, world_size: int, init_url: str, args): "dtype": args.datatype, "fusion_mode": fusion_config.mode_name(), } + result.update(tune_result) if args.validate: diff = (y_ref.float() - y_tri.float()).abs() @@ -364,6 +545,12 @@ def run_ref(): else "" ) + (f" max_diff={result.get('validate_max_diff', 0.0):.4f}" if args.validate else "") + + ( + f" best_config=(gemm_sms={tune_result['tune_best_gemm_sms']}" + f" block_m={tune_result['tune_best_block_m']})" + if tune_result + else "" + ) ) results.append(result) @@ -393,8 +580,8 @@ def run_ref(): def main(): args = parse_args() - if not args.benchmark and not args.validate and not args.compare_single_gpu: - print("No mode selected. Use at least one of: --benchmark, --validate, --compare_single_gpu") + if not args.benchmark and not args.validate and not args.compare_single_gpu and not getattr(args, "tune", False): + print("No mode selected. Use at least one of: --benchmark, --validate, --compare_single_gpu, --tune") sys.exit(1) init_url = f"tcp://127.0.0.1:{args.init_port}" diff --git a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py index 4d5e1502..0283e4f8 100644 --- a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py +++ b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py @@ -19,8 +19,6 @@ Lock granularity: one lock per (expert, N-tile, M-tile) triple. """ -import math - import torch import triton import triton.language as tl @@ -189,6 +187,33 @@ def _wg_fused_exp_matmul_ep_to_dp_kernel( iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16)) +def _heuristic_wg_config(num_sms: int, avg_bpe: int) -> tuple[int, int]: + """Select (gemm_sms, block_m) heuristically based on avg tokens-per-expert. + + Heuristic derived from a tune sweep on MI300X (304 CUs, 8 ranks) across + bpe ∈ {64, 128, 256, 512, 1024} with block_m ∈ {32, 64, 128, 256} and + gemm_sms ∈ {¼, ½, ¾} × num_sms: + + bpe ≤ 64 → gemm_sms = num_sms // 2, block_m = 128 + bpe ≤ 128 → gemm_sms = 3 * num_sms // 4, block_m = 128 + bpe > 128 → gemm_sms = 3 * num_sms // 4, block_m = 256 + + Args: + num_sms: Total CU count on the device. + avg_bpe: Average number of tokens routed per local expert + (n_slots_per_rank // n_local_experts). + + Returns: + (gemm_sms, block_m) tuple. + """ + if avg_bpe <= 64: + return max(1, num_sms // 2), 128 + elif avg_bpe <= 128: + return max(1, 3 * num_sms // 4), 128 + else: + return max(1, 3 * num_sms // 4), 256 + + def wg_fused_exp_matmul_ep_to_dp( x_ep_local: torch.Tensor, w_ep_local: torch.Tensor, @@ -200,6 +225,9 @@ def wg_fused_exp_matmul_ep_to_dp( shmem, ragged_metadata: RaggedTensorMetadata | None = None, gemm_sms: int | None = None, + block_m: int | None = None, + block_n: int | None = None, + block_k: int | None = None, ) -> torch.Tensor: """WG-specialized fused expert matmul + EP->DP scatter. @@ -216,7 +244,14 @@ def wg_fused_exp_matmul_ep_to_dp( combine_indx: (n_total_slots,) col_sorted_indx. shmem: iris.Iris instance. ragged_metadata: local-expert-view ragged metadata. - gemm_sms: Number of CUs for GEMM path. Default: 2^floor(log2(cu_count)). + gemm_sms: Number of CUs for GEMM path. + Default: auto-selected by _heuristic_wg_config based on avg bpe. + block_m: GEMM tile size along the M (token) dimension. + Default: auto-selected by _heuristic_wg_config based on avg bpe. + block_n: GEMM tile size along the N (output) dimension. + Default: min(triton.next_power_of_2(N), 128). + block_k: GEMM tile size along the K (reduction) dimension. + Default: min(triton.next_power_of_2(K), 64). Returns: (n_slots_per_rank, d_model) DP-local combined output. @@ -228,9 +263,21 @@ def wg_fused_exp_matmul_ep_to_dp( K = d_model N = d_model - BLOCK_M = 128 - BLOCK_N = min(triton.next_power_of_2(N), 128) - BLOCK_K = min(triton.next_power_of_2(K), 64) + device = x_ep_local.device + num_sms = torch.cuda.get_device_properties(device).multi_processor_count + + # Derive heuristic defaults for gemm_sms / block_m when not specified. + if gemm_sms is None or block_m is None: + avg_bpe = n_slots_per_rank // max(n_local_experts, 1) + h_gemm_sms, h_block_m = _heuristic_wg_config(num_sms, avg_bpe) + if gemm_sms is None: + gemm_sms = h_gemm_sms + if block_m is None: + block_m = h_block_m + + BLOCK_M = block_m + BLOCK_N = block_n if block_n is not None else min(triton.next_power_of_2(N), 128) + BLOCK_K = block_k if block_k is not None else min(triton.next_power_of_2(K), 64) max_slice_size = int(ragged_metadata.slice_sizes.max().item()) max_m_tiles = triton.cdiv(max_slice_size, BLOCK_M) @@ -242,12 +289,6 @@ def wg_fused_exp_matmul_ep_to_dp( shmem.barrier() return dst_local - device = x_ep_local.device - cu_count = torch.cuda.get_device_properties(device).multi_processor_count - num_sms = cu_count - if gemm_sms is None: - gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 - y_buf = torch.zeros((n_total_slots, N), dtype=x_ep_local.dtype, device=device) dst_local = shmem.zeros((n_slots_per_rank, d_model), dtype=x_ep_local.dtype) n_locks = n_n_tiles * n_local_experts * max_m_tiles diff --git a/examples/31_expert_sharded_moe/moe.py b/examples/31_expert_sharded_moe/moe.py index 580fb4bd..270f267a 100644 --- a/examples/31_expert_sharded_moe/moe.py +++ b/examples/31_expert_sharded_moe/moe.py @@ -210,6 +210,9 @@ def mixture_of_expt_epsharded( fusion_config: MoeFusionConfig | None = None, timing_dict: dict | None = None, gemm_sms: int | None = None, + block_m: int | None = None, + block_n: int | None = None, + block_k: int | None = None, ): """Expert-parallel MoE forward using iris symmetric heap. @@ -221,6 +224,12 @@ def mixture_of_expt_epsharded( expt_assignment: ExptAssignment mapping experts to ranks. n_expts_act: k (experts per token). shmem: iris.Iris instance. + fusion_config: MoeFusionConfig controlling which stages are fused. + timing_dict: optional list to collect (label, cuda_event) pairs. + gemm_sms: GEMM CU count for WG-specialized kernel (default: auto). + block_m: M-dim tile size for WG kernel (default: 128). + block_n: N-dim tile size for WG kernel (default: auto from d_model). + block_k: K-dim tile size for WG kernel (default: auto from d_model). Returns: (n_tokens_local, d_model) output for this rank's tokens. @@ -344,6 +353,9 @@ def _tick(label): shmem, ragged_metadata=y_ep_local_metadata, gemm_sms=gemm_sms, + block_m=block_m, + block_n=block_n, + block_k=block_k, ) _tick("wg_fused_matmul_scatter") else: